strawberry-graphql 0.255.0__py3-none-any.whl → 0.256.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155) hide show
  1. strawberry/__init__.py +9 -9
  2. strawberry/aiohttp/test/client.py +10 -8
  3. strawberry/aiohttp/views.py +5 -7
  4. strawberry/annotation.py +12 -15
  5. strawberry/asgi/__init__.py +3 -6
  6. strawberry/asgi/test/client.py +9 -8
  7. strawberry/chalice/views.py +4 -2
  8. strawberry/channels/__init__.py +1 -1
  9. strawberry/channels/handlers/base.py +3 -7
  10. strawberry/channels/handlers/http_handler.py +5 -6
  11. strawberry/channels/handlers/ws_handler.py +3 -4
  12. strawberry/channels/testing.py +5 -9
  13. strawberry/cli/commands/codegen.py +9 -9
  14. strawberry/cli/commands/upgrade/__init__.py +2 -3
  15. strawberry/cli/commands/upgrade/_run_codemod.py +7 -5
  16. strawberry/codegen/exceptions.py +2 -2
  17. strawberry/codegen/plugins/print_operation.py +6 -6
  18. strawberry/codegen/plugins/python.py +6 -6
  19. strawberry/codegen/plugins/typescript.py +3 -3
  20. strawberry/codegen/query_codegen.py +29 -34
  21. strawberry/codegen/types.py +35 -34
  22. strawberry/codemods/annotated_unions.py +5 -2
  23. strawberry/dataloader.py +13 -20
  24. strawberry/directive.py +12 -5
  25. strawberry/django/test/client.py +4 -4
  26. strawberry/django/views.py +4 -5
  27. strawberry/exceptions/__init__.py +24 -24
  28. strawberry/exceptions/conflicting_arguments.py +2 -2
  29. strawberry/exceptions/duplicated_type_name.py +3 -3
  30. strawberry/exceptions/handler.py +7 -7
  31. strawberry/exceptions/invalid_union_type.py +2 -2
  32. strawberry/exceptions/missing_arguments_annotations.py +2 -2
  33. strawberry/exceptions/missing_field_annotation.py +2 -2
  34. strawberry/exceptions/object_is_not_an_enum.py +2 -2
  35. strawberry/exceptions/private_strawberry_field.py +2 -2
  36. strawberry/exceptions/syntax.py +4 -4
  37. strawberry/exceptions/utils/source_finder.py +7 -6
  38. strawberry/experimental/pydantic/__init__.py +3 -3
  39. strawberry/experimental/pydantic/_compat.py +14 -14
  40. strawberry/experimental/pydantic/conversion.py +2 -2
  41. strawberry/experimental/pydantic/conversion_types.py +3 -3
  42. strawberry/experimental/pydantic/error_type.py +18 -16
  43. strawberry/experimental/pydantic/exceptions.py +5 -5
  44. strawberry/experimental/pydantic/fields.py +2 -13
  45. strawberry/experimental/pydantic/object_type.py +20 -22
  46. strawberry/experimental/pydantic/utils.py +6 -10
  47. strawberry/ext/dataclasses/dataclasses.py +3 -3
  48. strawberry/ext/mypy_plugin.py +6 -9
  49. strawberry/extensions/__init__.py +7 -8
  50. strawberry/extensions/add_validation_rules.py +5 -3
  51. strawberry/extensions/base_extension.py +4 -4
  52. strawberry/extensions/context.py +15 -14
  53. strawberry/extensions/directives.py +2 -2
  54. strawberry/extensions/disable_validation.py +1 -1
  55. strawberry/extensions/field_extension.py +2 -1
  56. strawberry/extensions/mask_errors.py +3 -2
  57. strawberry/extensions/max_aliases.py +2 -2
  58. strawberry/extensions/max_tokens.py +1 -1
  59. strawberry/extensions/parser_cache.py +2 -1
  60. strawberry/extensions/pyinstrument.py +5 -2
  61. strawberry/extensions/query_depth_limiter.py +13 -13
  62. strawberry/extensions/runner.py +7 -7
  63. strawberry/extensions/tracing/apollo.py +11 -9
  64. strawberry/extensions/tracing/datadog.py +3 -1
  65. strawberry/extensions/tracing/opentelemetry.py +7 -10
  66. strawberry/extensions/utils.py +3 -3
  67. strawberry/extensions/validation_cache.py +2 -1
  68. strawberry/fastapi/context.py +3 -3
  69. strawberry/fastapi/router.py +9 -14
  70. strawberry/federation/__init__.py +4 -4
  71. strawberry/federation/argument.py +2 -1
  72. strawberry/federation/enum.py +8 -8
  73. strawberry/federation/field.py +25 -28
  74. strawberry/federation/object_type.py +24 -26
  75. strawberry/federation/scalar.py +7 -8
  76. strawberry/federation/schema.py +30 -36
  77. strawberry/federation/schema_directive.py +5 -5
  78. strawberry/federation/schema_directives.py +14 -14
  79. strawberry/federation/union.py +3 -2
  80. strawberry/field_extensions/input_mutation.py +1 -2
  81. strawberry/file_uploads/utils.py +4 -3
  82. strawberry/flask/views.py +3 -2
  83. strawberry/http/__init__.py +6 -6
  84. strawberry/http/async_base_view.py +9 -14
  85. strawberry/http/base.py +5 -4
  86. strawberry/http/ides.py +1 -1
  87. strawberry/http/parse_content_type.py +1 -2
  88. strawberry/http/sync_base_view.py +3 -5
  89. strawberry/http/temporal_response.py +1 -2
  90. strawberry/http/types.py +3 -2
  91. strawberry/litestar/controller.py +8 -14
  92. strawberry/parent.py +1 -2
  93. strawberry/permission.py +6 -8
  94. strawberry/printer/ast_from_value.py +2 -1
  95. strawberry/printer/printer.py +50 -30
  96. strawberry/quart/views.py +3 -3
  97. strawberry/relay/exceptions.py +4 -4
  98. strawberry/relay/fields.py +22 -24
  99. strawberry/relay/types.py +29 -27
  100. strawberry/relay/utils.py +4 -4
  101. strawberry/sanic/utils.py +4 -4
  102. strawberry/sanic/views.py +5 -7
  103. strawberry/scalars.py +2 -2
  104. strawberry/schema/base.py +16 -11
  105. strawberry/schema/compat.py +4 -4
  106. strawberry/schema/execute.py +6 -10
  107. strawberry/schema/name_converter.py +3 -3
  108. strawberry/schema/schema.py +37 -25
  109. strawberry/schema/schema_converter.py +22 -24
  110. strawberry/schema/subscribe.py +4 -3
  111. strawberry/schema/types/base_scalars.py +1 -1
  112. strawberry/schema/types/concrete_type.py +2 -2
  113. strawberry/schema/types/scalar.py +3 -4
  114. strawberry/schema_codegen/__init__.py +4 -4
  115. strawberry/schema_directive.py +8 -8
  116. strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +8 -9
  117. strawberry/subscriptions/protocols/graphql_transport_ws/types.py +16 -16
  118. strawberry/subscriptions/protocols/graphql_ws/handlers.py +6 -5
  119. strawberry/subscriptions/protocols/graphql_ws/types.py +13 -13
  120. strawberry/test/__init__.py +1 -1
  121. strawberry/test/client.py +21 -19
  122. strawberry/tools/create_type.py +4 -3
  123. strawberry/tools/merge_types.py +1 -2
  124. strawberry/types/__init__.py +1 -1
  125. strawberry/types/arguments.py +10 -12
  126. strawberry/types/auto.py +2 -2
  127. strawberry/types/base.py +17 -21
  128. strawberry/types/enum.py +3 -5
  129. strawberry/types/execution.py +8 -12
  130. strawberry/types/field.py +26 -31
  131. strawberry/types/fields/resolver.py +15 -17
  132. strawberry/types/graphql.py +2 -2
  133. strawberry/types/info.py +5 -9
  134. strawberry/types/lazy_type.py +3 -5
  135. strawberry/types/mutation.py +25 -28
  136. strawberry/types/nodes.py +11 -9
  137. strawberry/types/object_type.py +14 -16
  138. strawberry/types/private.py +1 -2
  139. strawberry/types/scalar.py +2 -2
  140. strawberry/types/type_resolver.py +5 -5
  141. strawberry/types/union.py +8 -11
  142. strawberry/types/unset.py +3 -3
  143. strawberry/utils/aio.py +3 -8
  144. strawberry/utils/await_maybe.py +3 -2
  145. strawberry/utils/debug.py +2 -2
  146. strawberry/utils/deprecations.py +2 -2
  147. strawberry/utils/inspect.py +3 -5
  148. strawberry/utils/str_converters.py +1 -1
  149. strawberry/utils/typing.py +38 -67
  150. {strawberry_graphql-0.255.0.dist-info → strawberry_graphql-0.256.1.dist-info}/METADATA +3 -6
  151. strawberry_graphql-0.256.1.dist-info/RECORD +236 -0
  152. strawberry_graphql-0.255.0.dist-info/RECORD +0 -236
  153. {strawberry_graphql-0.255.0.dist-info → strawberry_graphql-0.256.1.dist-info}/LICENSE +0 -0
  154. {strawberry_graphql-0.255.0.dist-info → strawberry_graphql-0.256.1.dist-info}/WHEEL +0 -0
  155. {strawberry_graphql-0.255.0.dist-info → strawberry_graphql-0.256.1.dist-info}/entry_points.txt +0 -0
@@ -6,14 +6,9 @@ from functools import partial, reduce
6
6
  from typing import (
7
7
  TYPE_CHECKING,
8
8
  Any,
9
- Awaitable,
10
9
  Callable,
11
- Dict,
12
10
  Generic,
13
- List,
14
11
  Optional,
15
- Tuple,
16
- Type,
17
12
  TypeVar,
18
13
  Union,
19
14
  cast,
@@ -76,6 +71,8 @@ from . import compat
76
71
  from .types.concrete_type import ConcreteType
77
72
 
78
73
  if TYPE_CHECKING:
74
+ from collections.abc import Awaitable
75
+
79
76
  from graphql import (
80
77
  GraphQLInputType,
81
78
  GraphQLNullableType,
@@ -111,8 +108,8 @@ def _get_thunk_mapping(
111
108
  type_definition: StrawberryObjectDefinition,
112
109
  name_converter: Callable[[StrawberryField], str],
113
110
  field_converter: FieldConverterProtocol[FieldType],
114
- get_fields: Callable[[StrawberryObjectDefinition], List[StrawberryField]],
115
- ) -> Dict[str, FieldType]:
111
+ get_fields: Callable[[StrawberryObjectDefinition], list[StrawberryField]],
112
+ ) -> dict[str, FieldType]:
116
113
  """Create a GraphQL core `ThunkMapping` mapping of field names to field types.
117
114
 
118
115
  This method filters out remaining `strawberry.Private` annotated fields that
@@ -124,7 +121,7 @@ def _get_thunk_mapping(
124
121
  Raises:
125
122
  TypeError: If the type of a field in ``fields`` is `UNRESOLVED`
126
123
  """
127
- thunk_mapping: Dict[str, FieldType] = {}
124
+ thunk_mapping: dict[str, FieldType] = {}
128
125
 
129
126
  fields = get_fields(type_definition)
130
127
 
@@ -173,7 +170,7 @@ class CustomGraphQLEnumType(GraphQLEnumType):
173
170
  return self.wrapped_cls(super().parse_value(input_value))
174
171
 
175
172
  def parse_literal(
176
- self, value_node: ValueNode, _variables: Optional[Dict[str, Any]] = None
173
+ self, value_node: ValueNode, _variables: Optional[dict[str, Any]] = None
177
174
  ) -> Any:
178
175
  return self.wrapped_cls(super().parse_literal(value_node, _variables))
179
176
 
@@ -185,8 +182,8 @@ def get_arguments(
185
182
  info: Info,
186
183
  kwargs: Any,
187
184
  config: StrawberryConfig,
188
- scalar_registry: Dict[object, Union[ScalarWrapper, ScalarDefinition]],
189
- ) -> Tuple[List[Any], Dict[str, Any]]:
185
+ scalar_registry: dict[object, Union[ScalarWrapper, ScalarDefinition]],
186
+ ) -> tuple[list[Any], dict[str, Any]]:
190
187
  # TODO: An extension might have changed the resolver arguments,
191
188
  # but we need them here since we are calling it.
192
189
  # This is a bit of a hack, but it's the easiest way to get the arguments
@@ -242,10 +239,10 @@ class GraphQLCoreConverter:
242
239
  def __init__(
243
240
  self,
244
241
  config: StrawberryConfig,
245
- scalar_registry: Dict[object, Union[ScalarWrapper, ScalarDefinition]],
246
- get_fields: Callable[[StrawberryObjectDefinition], List[StrawberryField]],
242
+ scalar_registry: dict[object, Union[ScalarWrapper, ScalarDefinition]],
243
+ get_fields: Callable[[StrawberryObjectDefinition], list[StrawberryField]],
247
244
  ) -> None:
248
- self.type_map: Dict[str, ConcreteType] = {}
245
+ self.type_map: dict[str, ConcreteType] = {}
249
246
  self.config = config
250
247
  self.scalar_registry = scalar_registry
251
248
  self.get_fields = get_fields
@@ -329,13 +326,14 @@ class GraphQLCoreConverter:
329
326
  },
330
327
  )
331
328
 
332
- def from_schema_directive(self, cls: Type) -> GraphQLDirective:
329
+ def from_schema_directive(self, cls: type) -> GraphQLDirective:
333
330
  strawberry_directive = cast(
334
- "StrawberrySchemaDirective", cls.__strawberry_directive__
331
+ "StrawberrySchemaDirective",
332
+ cls.__strawberry_directive__, # type: ignore[attr-defined]
335
333
  )
336
334
  module = sys.modules[cls.__module__]
337
335
 
338
- args: Dict[str, GraphQLArgument] = {}
336
+ args: dict[str, GraphQLArgument] = {}
339
337
  for field in strawberry_directive.fields:
340
338
  default = field.default
341
339
  if default == dataclasses.MISSING:
@@ -436,7 +434,7 @@ class GraphQLCoreConverter:
436
434
 
437
435
  def get_graphql_fields(
438
436
  self, type_definition: StrawberryObjectDefinition
439
- ) -> Dict[str, GraphQLField]:
437
+ ) -> dict[str, GraphQLField]:
440
438
  return _get_thunk_mapping(
441
439
  type_definition=type_definition,
442
440
  name_converter=self.config.name_converter.from_field,
@@ -446,7 +444,7 @@ class GraphQLCoreConverter:
446
444
 
447
445
  def get_graphql_input_fields(
448
446
  self, type_definition: StrawberryObjectDefinition
449
- ) -> Dict[str, GraphQLInputField]:
447
+ ) -> dict[str, GraphQLInputField]:
450
448
  return _get_thunk_mapping(
451
449
  type_definition=type_definition,
452
450
  name_converter=self.config.name_converter.from_field,
@@ -672,8 +670,8 @@ class GraphQLCoreConverter:
672
670
  def _get_result(
673
671
  _source: Any,
674
672
  info: Info,
675
- field_args: List[Any],
676
- field_kwargs: Dict[str, Any],
673
+ field_args: list[Any],
674
+ field_kwargs: dict[str, Any],
677
675
  ) -> Any:
678
676
  return field.get_result(
679
677
  _source, info=info, args=field_args, kwargs=field_kwargs
@@ -762,7 +760,7 @@ class GraphQLCoreConverter:
762
760
  _resolver._is_default = not field.base_resolver # type: ignore
763
761
  return _resolver
764
762
 
765
- def from_scalar(self, scalar: Type) -> GraphQLScalarType:
763
+ def from_scalar(self, scalar: type) -> GraphQLScalarType:
766
764
  scalar_definition: ScalarDefinition
767
765
 
768
766
  if scalar in self.scalar_registry:
@@ -773,7 +771,7 @@ class GraphQLCoreConverter:
773
771
  else:
774
772
  scalar_definition = _scalar_definition
775
773
  else:
776
- scalar_definition = scalar._scalar_definition
774
+ scalar_definition = scalar._scalar_definition # type: ignore[attr-defined]
777
775
 
778
776
  scalar_name = self.config.name_converter.from_type(scalar_definition)
779
777
 
@@ -864,7 +862,7 @@ class GraphQLCoreConverter:
864
862
  assert isinstance(graphql_union, GraphQLUnionType) # For mypy
865
863
  return graphql_union
866
864
 
867
- graphql_types: List[GraphQLObjectType] = []
865
+ graphql_types: list[GraphQLObjectType] = []
868
866
  for type_ in union.types:
869
867
  graphql_type = self.from_type(type_)
870
868
 
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Optional, Type, Union
3
+ from collections.abc import AsyncGenerator, AsyncIterator
4
+ from typing import TYPE_CHECKING, Optional, Union
4
5
 
5
6
  from graphql import (
6
7
  ExecutionResult as OriginalExecutionResult,
@@ -44,7 +45,7 @@ async def _subscribe(
44
45
  extensions_runner: SchemaExtensionsRunner,
45
46
  process_errors: ProcessErrors,
46
47
  middleware_manager: MiddlewareManager,
47
- execution_context_class: Optional[Type[GraphQLExecutionContext]] = None,
48
+ execution_context_class: Optional[type[GraphQLExecutionContext]] = None,
48
49
  ) -> AsyncGenerator[Union[PreExecutionError, ExecutionResult], None]:
49
50
  async with extensions_runner.operation():
50
51
  if initial_error := await _parse_and_validate_async(
@@ -128,7 +129,7 @@ async def subscribe(
128
129
  extensions_runner: SchemaExtensionsRunner,
129
130
  process_errors: ProcessErrors,
130
131
  middleware_manager: MiddlewareManager,
131
- execution_context_class: Optional[Type[GraphQLExecutionContext]] = None,
132
+ execution_context_class: Optional[type[GraphQLExecutionContext]] = None,
132
133
  ) -> SubscriptionResult:
133
134
  asyncgen = _subscribe(
134
135
  schema,
@@ -81,4 +81,4 @@ Void = scalar(
81
81
  description="Represents NULL values",
82
82
  )
83
83
 
84
- __all__ = ["Date", "DateTime", "Time", "Decimal", "UUID", "Void"]
84
+ __all__ = ["UUID", "Date", "DateTime", "Decimal", "Time", "Void"]
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import dataclasses
4
- from typing import TYPE_CHECKING, Dict, Union
4
+ from typing import TYPE_CHECKING, Union
5
5
 
6
6
  from graphql import GraphQLField, GraphQLInputField, GraphQLType
7
7
 
@@ -22,7 +22,7 @@ class ConcreteType:
22
22
  implementation: GraphQLType
23
23
 
24
24
 
25
- TypeMap = Dict[str, ConcreteType]
25
+ TypeMap = dict[str, ConcreteType]
26
26
 
27
27
 
28
28
  __all__ = ["ConcreteType", "Field", "GraphQLType", "TypeMap"]
@@ -1,6 +1,5 @@
1
1
  import datetime
2
2
  import decimal
3
- from typing import Dict, Type
4
3
  from uuid import UUID
5
4
 
6
5
  from graphql import (
@@ -45,11 +44,11 @@ def _make_scalar_definition(scalar_type: GraphQLScalarType) -> ScalarDefinition:
45
44
  )
46
45
 
47
46
 
48
- def _get_scalar_definition(scalar: Type) -> ScalarDefinition:
49
- return scalar._scalar_definition
47
+ def _get_scalar_definition(scalar: type) -> ScalarDefinition:
48
+ return scalar._scalar_definition # type: ignore[attr-defined]
50
49
 
51
50
 
52
- DEFAULT_SCALAR_REGISTRY: Dict[object, ScalarDefinition] = {
51
+ DEFAULT_SCALAR_REGISTRY: dict[object, ScalarDefinition] = {
53
52
  type(None): _get_scalar_definition(base_scalars.Void),
54
53
  None: _get_scalar_definition(base_scalars.Void),
55
54
  str: _make_scalar_definition(GraphQLString),
@@ -3,11 +3,11 @@ from __future__ import annotations
3
3
  import dataclasses
4
4
  import keyword
5
5
  from collections import defaultdict
6
- from typing import TYPE_CHECKING, List, Tuple, Union
6
+ from graphlib import TopologicalSorter
7
+ from typing import TYPE_CHECKING, Union
7
8
  from typing_extensions import Protocol, TypeAlias
8
9
 
9
10
  import libcst as cst
10
- from graphlib import TopologicalSorter
11
11
  from graphql import (
12
12
  EnumTypeDefinitionNode,
13
13
  EnumValueDefinitionNode,
@@ -42,7 +42,7 @@ if TYPE_CHECKING:
42
42
 
43
43
 
44
44
  class HasDirectives(Protocol):
45
- directives: Tuple[ConstDirectiveNode, ...]
45
+ directives: tuple[ConstDirectiveNode, ...]
46
46
 
47
47
 
48
48
  _SCALAR_MAP = {
@@ -256,7 +256,7 @@ def _get_field(
256
256
  )
257
257
 
258
258
 
259
- ArgumentValue: TypeAlias = Union[str, bool, List["ArgumentValue"]]
259
+ ArgumentValue: TypeAlias = Union[str, bool, list["ArgumentValue"]]
260
260
 
261
261
 
262
262
  def _get_argument_value(argument_value: ConstValueNode) -> ArgumentValue:
@@ -1,6 +1,6 @@
1
1
  import dataclasses
2
2
  from enum import Enum
3
- from typing import Callable, List, Optional, Type, TypeVar
3
+ from typing import Callable, Optional, TypeVar
4
4
  from typing_extensions import dataclass_transform
5
5
 
6
6
  from strawberry.types.field import StrawberryField, field
@@ -28,15 +28,15 @@ class Location(Enum):
28
28
  class StrawberrySchemaDirective:
29
29
  python_name: str
30
30
  graphql_name: Optional[str]
31
- locations: List[Location]
32
- fields: List["StrawberryField"]
31
+ locations: list[Location]
32
+ fields: list["StrawberryField"]
33
33
  description: Optional[str] = None
34
34
  repeatable: bool = False
35
35
  print_definition: bool = True
36
- origin: Optional[Type] = None
36
+ origin: Optional[type] = None
37
37
 
38
38
 
39
- T = TypeVar("T", bound=Type)
39
+ T = TypeVar("T", bound=type)
40
40
 
41
41
 
42
42
  @dataclass_transform(
@@ -46,17 +46,17 @@ T = TypeVar("T", bound=Type)
46
46
  )
47
47
  def schema_directive(
48
48
  *,
49
- locations: List[Location],
49
+ locations: list[Location],
50
50
  description: Optional[str] = None,
51
51
  name: Optional[str] = None,
52
52
  repeatable: bool = False,
53
53
  print_definition: bool = True,
54
- ) -> Callable[..., T]:
54
+ ) -> Callable[[T], T]:
55
55
  def _wrap(cls: T) -> T:
56
56
  cls = _wrap_dataclass(cls) # type: ignore
57
57
  fields = _get_fields(cls, {})
58
58
 
59
- cls.__strawberry_directive__ = StrawberrySchemaDirective(
59
+ cls.__strawberry_directive__ = StrawberrySchemaDirective( # type: ignore[attr-defined]
60
60
  python_name=cls.__name__,
61
61
  graphql_name=name,
62
62
  locations=locations,
@@ -2,14 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  import asyncio
4
4
  import logging
5
+ from collections.abc import Awaitable
5
6
  from contextlib import suppress
6
7
  from typing import (
7
8
  TYPE_CHECKING,
8
9
  Any,
9
- Awaitable,
10
- Dict,
11
10
  Generic,
12
- List,
13
11
  Optional,
14
12
  cast,
15
13
  )
@@ -40,6 +38,7 @@ from strawberry.utils.debug import pretty_print_graphql_operation
40
38
  from strawberry.utils.operation import get_operation_type
41
39
 
42
40
  if TYPE_CHECKING:
41
+ from collections.abc import Awaitable
43
42
  from datetime import timedelta
44
43
 
45
44
  from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncWebSocketAdapter
@@ -71,8 +70,8 @@ class BaseGraphQLTransportWSHandler(Generic[Context, RootValue]):
71
70
  self.connection_init_received = False
72
71
  self.connection_acknowledged = False
73
72
  self.connection_timed_out = False
74
- self.operations: Dict[str, Operation[Context, RootValue]] = {}
75
- self.completed_tasks: List[asyncio.Task] = []
73
+ self.operations: dict[str, Operation[Context, RootValue]] = {}
74
+ self.completed_tasks: list[asyncio.Task] = []
76
75
 
77
76
  async def handle(self) -> None:
78
77
  self.on_request_accepted()
@@ -343,14 +342,14 @@ class Operation(Generic[Context, RootValue]):
343
342
  """A class encapsulating a single operation with its id. Helps enforce protocol state transition."""
344
343
 
345
344
  __slots__ = [
345
+ "completed",
346
346
  "handler",
347
347
  "id",
348
+ "operation_name",
348
349
  "operation_type",
349
350
  "query",
350
- "variables",
351
- "operation_name",
352
- "completed",
353
351
  "task",
352
+ "variables",
354
353
  ]
355
354
 
356
355
  def __init__(
@@ -359,7 +358,7 @@ class Operation(Generic[Context, RootValue]):
359
358
  id: str,
360
359
  operation_type: OperationType,
361
360
  query: str,
362
- variables: Optional[Dict[str, object]],
361
+ variables: Optional[dict[str, object]],
363
362
  operation_name: Optional[str],
364
363
  ) -> None:
365
364
  self.handler = handler
@@ -1,4 +1,4 @@
1
- from typing import Dict, List, TypedDict, Union
1
+ from typing import TypedDict, Union
2
2
  from typing_extensions import Literal, NotRequired
3
3
 
4
4
  from graphql import GraphQLFormattedError
@@ -8,35 +8,35 @@ class ConnectionInitMessage(TypedDict):
8
8
  """Direction: Client -> Server."""
9
9
 
10
10
  type: Literal["connection_init"]
11
- payload: NotRequired[Union[Dict[str, object], None]]
11
+ payload: NotRequired[Union[dict[str, object], None]]
12
12
 
13
13
 
14
14
  class ConnectionAckMessage(TypedDict):
15
15
  """Direction: Server -> Client."""
16
16
 
17
17
  type: Literal["connection_ack"]
18
- payload: NotRequired[Union[Dict[str, object], None]]
18
+ payload: NotRequired[Union[dict[str, object], None]]
19
19
 
20
20
 
21
21
  class PingMessage(TypedDict):
22
22
  """Direction: bidirectional."""
23
23
 
24
24
  type: Literal["ping"]
25
- payload: NotRequired[Union[Dict[str, object], None]]
25
+ payload: NotRequired[Union[dict[str, object], None]]
26
26
 
27
27
 
28
28
  class PongMessage(TypedDict):
29
29
  """Direction: bidirectional."""
30
30
 
31
31
  type: Literal["pong"]
32
- payload: NotRequired[Union[Dict[str, object], None]]
32
+ payload: NotRequired[Union[dict[str, object], None]]
33
33
 
34
34
 
35
35
  class SubscribeMessagePayload(TypedDict):
36
36
  operationName: NotRequired[Union[str, None]]
37
37
  query: str
38
- variables: NotRequired[Union[Dict[str, object], None]]
39
- extensions: NotRequired[Union[Dict[str, object], None]]
38
+ variables: NotRequired[Union[dict[str, object], None]]
39
+ extensions: NotRequired[Union[dict[str, object], None]]
40
40
 
41
41
 
42
42
  class SubscribeMessage(TypedDict):
@@ -48,9 +48,9 @@ class SubscribeMessage(TypedDict):
48
48
 
49
49
 
50
50
  class NextMessagePayload(TypedDict):
51
- errors: NotRequired[List[GraphQLFormattedError]]
52
- data: NotRequired[Union[Dict[str, object], None]]
53
- extensions: NotRequired[Dict[str, object]]
51
+ errors: NotRequired[list[GraphQLFormattedError]]
52
+ data: NotRequired[Union[dict[str, object], None]]
53
+ extensions: NotRequired[dict[str, object]]
54
54
 
55
55
 
56
56
  class NextMessage(TypedDict):
@@ -66,7 +66,7 @@ class ErrorMessage(TypedDict):
66
66
 
67
67
  id: str
68
68
  type: Literal["error"]
69
- payload: List[GraphQLFormattedError]
69
+ payload: list[GraphQLFormattedError]
70
70
 
71
71
 
72
72
  class CompleteMessage(TypedDict):
@@ -89,13 +89,13 @@ Message = Union[
89
89
 
90
90
 
91
91
  __all__ = [
92
- "ConnectionInitMessage",
92
+ "CompleteMessage",
93
93
  "ConnectionAckMessage",
94
+ "ConnectionInitMessage",
95
+ "ErrorMessage",
96
+ "Message",
97
+ "NextMessage",
94
98
  "PingMessage",
95
99
  "PongMessage",
96
100
  "SubscribeMessage",
97
- "NextMessage",
98
- "ErrorMessage",
99
- "CompleteMessage",
100
- "Message",
101
101
  ]
@@ -1,12 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import asyncio
4
+ from collections.abc import AsyncGenerator
4
5
  from contextlib import suppress
5
6
  from typing import (
6
7
  TYPE_CHECKING,
7
8
  Any,
8
- AsyncGenerator,
9
- Dict,
10
9
  Generic,
11
10
  Optional,
12
11
  cast,
@@ -28,6 +27,8 @@ from strawberry.types.unset import UnsetType
28
27
  from strawberry.utils.debug import pretty_print_graphql_operation
29
28
 
30
29
  if TYPE_CHECKING:
30
+ from collections.abc import AsyncGenerator
31
+
31
32
  from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncWebSocketAdapter
32
33
  from strawberry.schema import BaseSchema
33
34
 
@@ -56,8 +57,8 @@ class BaseGraphQLWSHandler(Generic[Context, RootValue]):
56
57
  self.keep_alive = keep_alive
57
58
  self.keep_alive_interval = keep_alive_interval
58
59
  self.keep_alive_task: Optional[asyncio.Task] = None
59
- self.subscriptions: Dict[str, AsyncGenerator] = {}
60
- self.tasks: Dict[str, asyncio.Task] = {}
60
+ self.subscriptions: dict[str, AsyncGenerator] = {}
61
+ self.tasks: dict[str, asyncio.Task] = {}
61
62
 
62
63
  async def handle(self) -> None:
63
64
  try:
@@ -164,7 +165,7 @@ class BaseGraphQLWSHandler(Generic[Context, RootValue]):
164
165
  operation_id: str,
165
166
  query: str,
166
167
  operation_name: Optional[str],
167
- variables: Optional[Dict[str, object]],
168
+ variables: Optional[dict[str, object]],
168
169
  ) -> None:
169
170
  try:
170
171
  agen_or_err = await self.schema.subscribe(
@@ -1,4 +1,4 @@
1
- from typing import Dict, List, TypedDict, Union
1
+ from typing import TypedDict, Union
2
2
  from typing_extensions import Literal, NotRequired
3
3
 
4
4
  from graphql import GraphQLFormattedError
@@ -6,12 +6,12 @@ from graphql import GraphQLFormattedError
6
6
 
7
7
  class ConnectionInitMessage(TypedDict):
8
8
  type: Literal["connection_init"]
9
- payload: NotRequired[Dict[str, object]]
9
+ payload: NotRequired[dict[str, object]]
10
10
 
11
11
 
12
12
  class StartMessagePayload(TypedDict):
13
13
  query: str
14
- variables: NotRequired[Dict[str, object]]
14
+ variables: NotRequired[dict[str, object]]
15
15
  operationName: NotRequired[str]
16
16
 
17
17
 
@@ -32,20 +32,20 @@ class ConnectionTerminateMessage(TypedDict):
32
32
 
33
33
  class ConnectionErrorMessage(TypedDict):
34
34
  type: Literal["connection_error"]
35
- payload: NotRequired[Dict[str, object]]
35
+ payload: NotRequired[dict[str, object]]
36
36
 
37
37
 
38
38
  class ConnectionAckMessage(TypedDict):
39
39
  type: Literal["connection_ack"]
40
- payload: NotRequired[Dict[str, object]]
40
+ payload: NotRequired[dict[str, object]]
41
41
 
42
42
 
43
43
  class DataMessagePayload(TypedDict):
44
44
  data: object
45
- errors: NotRequired[List[GraphQLFormattedError]]
45
+ errors: NotRequired[list[GraphQLFormattedError]]
46
46
 
47
47
  # Non-standard field:
48
- extensions: NotRequired[Dict[str, object]]
48
+ extensions: NotRequired[dict[str, object]]
49
49
 
50
50
 
51
51
  class DataMessage(TypedDict):
@@ -84,15 +84,15 @@ OperationMessage = Union[
84
84
 
85
85
 
86
86
  __all__ = [
87
+ "CompleteMessage",
88
+ "ConnectionAckMessage",
89
+ "ConnectionErrorMessage",
87
90
  "ConnectionInitMessage",
88
- "StartMessage",
89
- "StopMessage",
91
+ "ConnectionKeepAliveMessage",
90
92
  "ConnectionTerminateMessage",
91
- "ConnectionErrorMessage",
92
- "ConnectionAckMessage",
93
93
  "DataMessage",
94
94
  "ErrorMessage",
95
- "CompleteMessage",
96
- "ConnectionKeepAliveMessage",
97
95
  "OperationMessage",
96
+ "StartMessage",
97
+ "StopMessage",
98
98
  ]
@@ -1,3 +1,3 @@
1
1
  from .client import BaseGraphQLTestClient, Body, Response
2
2
 
3
- __all__ = ["Body", "Response", "BaseGraphQLTestClient"]
3
+ __all__ = ["BaseGraphQLTestClient", "Body", "Response"]
strawberry/test/client.py CHANGED
@@ -4,23 +4,25 @@ import json
4
4
  import warnings
5
5
  from abc import ABC, abstractmethod
6
6
  from dataclasses import dataclass
7
- from typing import TYPE_CHECKING, Any, Coroutine, Dict, List, Mapping, Optional, Union
7
+ from typing import TYPE_CHECKING, Any, Optional, Union
8
8
  from typing_extensions import Literal, TypedDict
9
9
 
10
10
  if TYPE_CHECKING:
11
+ from collections.abc import Coroutine, Mapping
12
+
11
13
  from graphql import GraphQLFormattedError
12
14
 
13
15
 
14
16
  @dataclass
15
17
  class Response:
16
- errors: Optional[List[GraphQLFormattedError]]
17
- data: Optional[Dict[str, object]]
18
- extensions: Optional[Dict[str, object]]
18
+ errors: Optional[list[GraphQLFormattedError]]
19
+ data: Optional[dict[str, object]]
20
+ extensions: Optional[dict[str, object]]
19
21
 
20
22
 
21
23
  class Body(TypedDict, total=False):
22
24
  query: str
23
- variables: Optional[Dict[str, object]]
25
+ variables: Optional[dict[str, object]]
24
26
 
25
27
 
26
28
  class BaseGraphQLTestClient(ABC):
@@ -35,10 +37,10 @@ class BaseGraphQLTestClient(ABC):
35
37
  def query(
36
38
  self,
37
39
  query: str,
38
- variables: Optional[Dict[str, Mapping]] = None,
39
- headers: Optional[Dict[str, object]] = None,
40
+ variables: Optional[dict[str, Mapping]] = None,
41
+ headers: Optional[dict[str, object]] = None,
40
42
  asserts_errors: Optional[bool] = None,
41
- files: Optional[Dict[str, object]] = None,
43
+ files: Optional[dict[str, object]] = None,
42
44
  assert_no_errors: Optional[bool] = True,
43
45
  ) -> Union[Coroutine[Any, Any, Response], Response]:
44
46
  body = self._build_body(query, variables, files)
@@ -71,19 +73,19 @@ class BaseGraphQLTestClient(ABC):
71
73
  @abstractmethod
72
74
  def request(
73
75
  self,
74
- body: Dict[str, object],
75
- headers: Optional[Dict[str, object]] = None,
76
- files: Optional[Dict[str, object]] = None,
76
+ body: dict[str, object],
77
+ headers: Optional[dict[str, object]] = None,
78
+ files: Optional[dict[str, object]] = None,
77
79
  ) -> Any:
78
80
  raise NotImplementedError
79
81
 
80
82
  def _build_body(
81
83
  self,
82
84
  query: str,
83
- variables: Optional[Dict[str, Mapping]] = None,
84
- files: Optional[Dict[str, object]] = None,
85
- ) -> Dict[str, object]:
86
- body: Dict[str, object] = {"query": query}
85
+ variables: Optional[dict[str, Mapping]] = None,
86
+ files: Optional[dict[str, object]] = None,
87
+ ) -> dict[str, object]:
88
+ body: dict[str, object] = {"query": query}
87
89
 
88
90
  if variables:
89
91
  body["variables"] = variables
@@ -103,8 +105,8 @@ class BaseGraphQLTestClient(ABC):
103
105
 
104
106
  @staticmethod
105
107
  def _build_multipart_file_map(
106
- variables: Dict[str, Mapping], files: Dict[str, object]
107
- ) -> Dict[str, List[str]]:
108
+ variables: dict[str, Mapping], files: dict[str, object]
109
+ ) -> dict[str, list[str]]:
108
110
  """Creates the file mapping between the variables and the files objects passed as key arguments.
109
111
 
110
112
  Args:
@@ -158,7 +160,7 @@ class BaseGraphQLTestClient(ABC):
158
160
  # }
159
161
  ```
160
162
  """
161
- map: Dict[str, List[str]] = {}
163
+ map: dict[str, list[str]] = {}
162
164
  for key, values in variables.items():
163
165
  reference = key
164
166
  variable_values = values
@@ -195,4 +197,4 @@ class BaseGraphQLTestClient(ABC):
195
197
  return response.json()
196
198
 
197
199
 
198
- __all__ = ["BaseGraphQLTestClient", "Response", "Body"]
200
+ __all__ = ["BaseGraphQLTestClient", "Body", "Response"]