strawberry-graphql 0.255.0__py3-none-any.whl → 0.256.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- strawberry/__init__.py +9 -9
- strawberry/aiohttp/test/client.py +10 -8
- strawberry/aiohttp/views.py +5 -7
- strawberry/annotation.py +12 -15
- strawberry/asgi/__init__.py +3 -6
- strawberry/asgi/test/client.py +9 -8
- strawberry/chalice/views.py +4 -2
- strawberry/channels/__init__.py +1 -1
- strawberry/channels/handlers/base.py +3 -7
- strawberry/channels/handlers/http_handler.py +5 -6
- strawberry/channels/handlers/ws_handler.py +3 -4
- strawberry/channels/testing.py +5 -9
- strawberry/cli/commands/codegen.py +9 -9
- strawberry/cli/commands/upgrade/__init__.py +2 -3
- strawberry/cli/commands/upgrade/_run_codemod.py +7 -5
- strawberry/codegen/exceptions.py +2 -2
- strawberry/codegen/plugins/print_operation.py +6 -6
- strawberry/codegen/plugins/python.py +6 -6
- strawberry/codegen/plugins/typescript.py +3 -3
- strawberry/codegen/query_codegen.py +29 -34
- strawberry/codegen/types.py +35 -34
- strawberry/codemods/annotated_unions.py +5 -2
- strawberry/dataloader.py +13 -20
- strawberry/directive.py +12 -5
- strawberry/django/test/client.py +4 -4
- strawberry/django/views.py +4 -5
- strawberry/exceptions/__init__.py +24 -24
- strawberry/exceptions/conflicting_arguments.py +2 -2
- strawberry/exceptions/duplicated_type_name.py +3 -3
- strawberry/exceptions/handler.py +7 -7
- strawberry/exceptions/invalid_union_type.py +2 -2
- strawberry/exceptions/missing_arguments_annotations.py +2 -2
- strawberry/exceptions/missing_field_annotation.py +2 -2
- strawberry/exceptions/object_is_not_an_enum.py +2 -2
- strawberry/exceptions/private_strawberry_field.py +2 -2
- strawberry/exceptions/syntax.py +4 -4
- strawberry/exceptions/utils/source_finder.py +7 -6
- strawberry/experimental/pydantic/__init__.py +3 -3
- strawberry/experimental/pydantic/_compat.py +14 -14
- strawberry/experimental/pydantic/conversion.py +2 -2
- strawberry/experimental/pydantic/conversion_types.py +3 -3
- strawberry/experimental/pydantic/error_type.py +18 -16
- strawberry/experimental/pydantic/exceptions.py +5 -5
- strawberry/experimental/pydantic/fields.py +2 -13
- strawberry/experimental/pydantic/object_type.py +20 -22
- strawberry/experimental/pydantic/utils.py +6 -10
- strawberry/ext/dataclasses/dataclasses.py +3 -3
- strawberry/ext/mypy_plugin.py +6 -9
- strawberry/extensions/__init__.py +7 -8
- strawberry/extensions/add_validation_rules.py +5 -3
- strawberry/extensions/base_extension.py +4 -4
- strawberry/extensions/context.py +15 -14
- strawberry/extensions/directives.py +2 -2
- strawberry/extensions/disable_validation.py +1 -1
- strawberry/extensions/field_extension.py +2 -1
- strawberry/extensions/mask_errors.py +3 -2
- strawberry/extensions/max_aliases.py +2 -2
- strawberry/extensions/max_tokens.py +1 -1
- strawberry/extensions/parser_cache.py +2 -1
- strawberry/extensions/pyinstrument.py +5 -2
- strawberry/extensions/query_depth_limiter.py +13 -13
- strawberry/extensions/runner.py +7 -7
- strawberry/extensions/tracing/apollo.py +11 -9
- strawberry/extensions/tracing/datadog.py +3 -1
- strawberry/extensions/tracing/opentelemetry.py +7 -10
- strawberry/extensions/utils.py +3 -3
- strawberry/extensions/validation_cache.py +2 -1
- strawberry/fastapi/context.py +3 -3
- strawberry/fastapi/router.py +9 -14
- strawberry/federation/__init__.py +4 -4
- strawberry/federation/argument.py +2 -1
- strawberry/federation/enum.py +8 -8
- strawberry/federation/field.py +25 -28
- strawberry/federation/object_type.py +24 -26
- strawberry/federation/scalar.py +7 -8
- strawberry/federation/schema.py +30 -36
- strawberry/federation/schema_directive.py +5 -5
- strawberry/federation/schema_directives.py +14 -14
- strawberry/federation/union.py +3 -2
- strawberry/field_extensions/input_mutation.py +1 -2
- strawberry/file_uploads/utils.py +4 -3
- strawberry/flask/views.py +3 -2
- strawberry/http/__init__.py +6 -6
- strawberry/http/async_base_view.py +9 -14
- strawberry/http/base.py +5 -4
- strawberry/http/ides.py +1 -1
- strawberry/http/parse_content_type.py +1 -2
- strawberry/http/sync_base_view.py +3 -5
- strawberry/http/temporal_response.py +1 -2
- strawberry/http/types.py +3 -2
- strawberry/litestar/controller.py +8 -14
- strawberry/parent.py +1 -2
- strawberry/permission.py +6 -8
- strawberry/printer/ast_from_value.py +2 -1
- strawberry/printer/printer.py +50 -30
- strawberry/quart/views.py +3 -3
- strawberry/relay/exceptions.py +4 -4
- strawberry/relay/fields.py +22 -24
- strawberry/relay/types.py +29 -27
- strawberry/relay/utils.py +4 -4
- strawberry/sanic/utils.py +4 -4
- strawberry/sanic/views.py +5 -7
- strawberry/scalars.py +2 -2
- strawberry/schema/base.py +16 -11
- strawberry/schema/compat.py +4 -4
- strawberry/schema/execute.py +6 -10
- strawberry/schema/name_converter.py +3 -3
- strawberry/schema/schema.py +37 -25
- strawberry/schema/schema_converter.py +22 -24
- strawberry/schema/subscribe.py +4 -3
- strawberry/schema/types/base_scalars.py +1 -1
- strawberry/schema/types/concrete_type.py +2 -2
- strawberry/schema/types/scalar.py +3 -4
- strawberry/schema_codegen/__init__.py +4 -4
- strawberry/schema_directive.py +8 -8
- strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +8 -9
- strawberry/subscriptions/protocols/graphql_transport_ws/types.py +16 -16
- strawberry/subscriptions/protocols/graphql_ws/handlers.py +6 -5
- strawberry/subscriptions/protocols/graphql_ws/types.py +13 -13
- strawberry/test/__init__.py +1 -1
- strawberry/test/client.py +21 -19
- strawberry/tools/create_type.py +4 -3
- strawberry/tools/merge_types.py +1 -2
- strawberry/types/__init__.py +1 -1
- strawberry/types/arguments.py +10 -12
- strawberry/types/auto.py +2 -2
- strawberry/types/base.py +17 -21
- strawberry/types/enum.py +3 -5
- strawberry/types/execution.py +8 -12
- strawberry/types/field.py +26 -31
- strawberry/types/fields/resolver.py +15 -17
- strawberry/types/graphql.py +2 -2
- strawberry/types/info.py +5 -9
- strawberry/types/lazy_type.py +3 -5
- strawberry/types/mutation.py +25 -28
- strawberry/types/nodes.py +11 -9
- strawberry/types/object_type.py +14 -16
- strawberry/types/private.py +1 -2
- strawberry/types/scalar.py +2 -2
- strawberry/types/type_resolver.py +5 -5
- strawberry/types/union.py +8 -11
- strawberry/types/unset.py +3 -3
- strawberry/utils/aio.py +3 -8
- strawberry/utils/await_maybe.py +3 -2
- strawberry/utils/debug.py +2 -2
- strawberry/utils/deprecations.py +2 -2
- strawberry/utils/inspect.py +3 -5
- strawberry/utils/str_converters.py +1 -1
- strawberry/utils/typing.py +38 -67
- {strawberry_graphql-0.255.0.dist-info → strawberry_graphql-0.256.1.dist-info}/METADATA +3 -6
- strawberry_graphql-0.256.1.dist-info/RECORD +236 -0
- strawberry_graphql-0.255.0.dist-info/RECORD +0 -236
- {strawberry_graphql-0.255.0.dist-info → strawberry_graphql-0.256.1.dist-info}/LICENSE +0 -0
- {strawberry_graphql-0.255.0.dist-info → strawberry_graphql-0.256.1.dist-info}/WHEEL +0 -0
- {strawberry_graphql-0.255.0.dist-info → strawberry_graphql-0.256.1.dist-info}/entry_points.txt +0 -0
@@ -6,14 +6,9 @@ from functools import partial, reduce
|
|
6
6
|
from typing import (
|
7
7
|
TYPE_CHECKING,
|
8
8
|
Any,
|
9
|
-
Awaitable,
|
10
9
|
Callable,
|
11
|
-
Dict,
|
12
10
|
Generic,
|
13
|
-
List,
|
14
11
|
Optional,
|
15
|
-
Tuple,
|
16
|
-
Type,
|
17
12
|
TypeVar,
|
18
13
|
Union,
|
19
14
|
cast,
|
@@ -76,6 +71,8 @@ from . import compat
|
|
76
71
|
from .types.concrete_type import ConcreteType
|
77
72
|
|
78
73
|
if TYPE_CHECKING:
|
74
|
+
from collections.abc import Awaitable
|
75
|
+
|
79
76
|
from graphql import (
|
80
77
|
GraphQLInputType,
|
81
78
|
GraphQLNullableType,
|
@@ -111,8 +108,8 @@ def _get_thunk_mapping(
|
|
111
108
|
type_definition: StrawberryObjectDefinition,
|
112
109
|
name_converter: Callable[[StrawberryField], str],
|
113
110
|
field_converter: FieldConverterProtocol[FieldType],
|
114
|
-
get_fields: Callable[[StrawberryObjectDefinition],
|
115
|
-
) ->
|
111
|
+
get_fields: Callable[[StrawberryObjectDefinition], list[StrawberryField]],
|
112
|
+
) -> dict[str, FieldType]:
|
116
113
|
"""Create a GraphQL core `ThunkMapping` mapping of field names to field types.
|
117
114
|
|
118
115
|
This method filters out remaining `strawberry.Private` annotated fields that
|
@@ -124,7 +121,7 @@ def _get_thunk_mapping(
|
|
124
121
|
Raises:
|
125
122
|
TypeError: If the type of a field in ``fields`` is `UNRESOLVED`
|
126
123
|
"""
|
127
|
-
thunk_mapping:
|
124
|
+
thunk_mapping: dict[str, FieldType] = {}
|
128
125
|
|
129
126
|
fields = get_fields(type_definition)
|
130
127
|
|
@@ -173,7 +170,7 @@ class CustomGraphQLEnumType(GraphQLEnumType):
|
|
173
170
|
return self.wrapped_cls(super().parse_value(input_value))
|
174
171
|
|
175
172
|
def parse_literal(
|
176
|
-
self, value_node: ValueNode, _variables: Optional[
|
173
|
+
self, value_node: ValueNode, _variables: Optional[dict[str, Any]] = None
|
177
174
|
) -> Any:
|
178
175
|
return self.wrapped_cls(super().parse_literal(value_node, _variables))
|
179
176
|
|
@@ -185,8 +182,8 @@ def get_arguments(
|
|
185
182
|
info: Info,
|
186
183
|
kwargs: Any,
|
187
184
|
config: StrawberryConfig,
|
188
|
-
scalar_registry:
|
189
|
-
) ->
|
185
|
+
scalar_registry: dict[object, Union[ScalarWrapper, ScalarDefinition]],
|
186
|
+
) -> tuple[list[Any], dict[str, Any]]:
|
190
187
|
# TODO: An extension might have changed the resolver arguments,
|
191
188
|
# but we need them here since we are calling it.
|
192
189
|
# This is a bit of a hack, but it's the easiest way to get the arguments
|
@@ -242,10 +239,10 @@ class GraphQLCoreConverter:
|
|
242
239
|
def __init__(
|
243
240
|
self,
|
244
241
|
config: StrawberryConfig,
|
245
|
-
scalar_registry:
|
246
|
-
get_fields: Callable[[StrawberryObjectDefinition],
|
242
|
+
scalar_registry: dict[object, Union[ScalarWrapper, ScalarDefinition]],
|
243
|
+
get_fields: Callable[[StrawberryObjectDefinition], list[StrawberryField]],
|
247
244
|
) -> None:
|
248
|
-
self.type_map:
|
245
|
+
self.type_map: dict[str, ConcreteType] = {}
|
249
246
|
self.config = config
|
250
247
|
self.scalar_registry = scalar_registry
|
251
248
|
self.get_fields = get_fields
|
@@ -329,13 +326,14 @@ class GraphQLCoreConverter:
|
|
329
326
|
},
|
330
327
|
)
|
331
328
|
|
332
|
-
def from_schema_directive(self, cls:
|
329
|
+
def from_schema_directive(self, cls: type) -> GraphQLDirective:
|
333
330
|
strawberry_directive = cast(
|
334
|
-
"StrawberrySchemaDirective",
|
331
|
+
"StrawberrySchemaDirective",
|
332
|
+
cls.__strawberry_directive__, # type: ignore[attr-defined]
|
335
333
|
)
|
336
334
|
module = sys.modules[cls.__module__]
|
337
335
|
|
338
|
-
args:
|
336
|
+
args: dict[str, GraphQLArgument] = {}
|
339
337
|
for field in strawberry_directive.fields:
|
340
338
|
default = field.default
|
341
339
|
if default == dataclasses.MISSING:
|
@@ -436,7 +434,7 @@ class GraphQLCoreConverter:
|
|
436
434
|
|
437
435
|
def get_graphql_fields(
|
438
436
|
self, type_definition: StrawberryObjectDefinition
|
439
|
-
) ->
|
437
|
+
) -> dict[str, GraphQLField]:
|
440
438
|
return _get_thunk_mapping(
|
441
439
|
type_definition=type_definition,
|
442
440
|
name_converter=self.config.name_converter.from_field,
|
@@ -446,7 +444,7 @@ class GraphQLCoreConverter:
|
|
446
444
|
|
447
445
|
def get_graphql_input_fields(
|
448
446
|
self, type_definition: StrawberryObjectDefinition
|
449
|
-
) ->
|
447
|
+
) -> dict[str, GraphQLInputField]:
|
450
448
|
return _get_thunk_mapping(
|
451
449
|
type_definition=type_definition,
|
452
450
|
name_converter=self.config.name_converter.from_field,
|
@@ -672,8 +670,8 @@ class GraphQLCoreConverter:
|
|
672
670
|
def _get_result(
|
673
671
|
_source: Any,
|
674
672
|
info: Info,
|
675
|
-
field_args:
|
676
|
-
field_kwargs:
|
673
|
+
field_args: list[Any],
|
674
|
+
field_kwargs: dict[str, Any],
|
677
675
|
) -> Any:
|
678
676
|
return field.get_result(
|
679
677
|
_source, info=info, args=field_args, kwargs=field_kwargs
|
@@ -762,7 +760,7 @@ class GraphQLCoreConverter:
|
|
762
760
|
_resolver._is_default = not field.base_resolver # type: ignore
|
763
761
|
return _resolver
|
764
762
|
|
765
|
-
def from_scalar(self, scalar:
|
763
|
+
def from_scalar(self, scalar: type) -> GraphQLScalarType:
|
766
764
|
scalar_definition: ScalarDefinition
|
767
765
|
|
768
766
|
if scalar in self.scalar_registry:
|
@@ -773,7 +771,7 @@ class GraphQLCoreConverter:
|
|
773
771
|
else:
|
774
772
|
scalar_definition = _scalar_definition
|
775
773
|
else:
|
776
|
-
scalar_definition = scalar._scalar_definition
|
774
|
+
scalar_definition = scalar._scalar_definition # type: ignore[attr-defined]
|
777
775
|
|
778
776
|
scalar_name = self.config.name_converter.from_type(scalar_definition)
|
779
777
|
|
@@ -864,7 +862,7 @@ class GraphQLCoreConverter:
|
|
864
862
|
assert isinstance(graphql_union, GraphQLUnionType) # For mypy
|
865
863
|
return graphql_union
|
866
864
|
|
867
|
-
graphql_types:
|
865
|
+
graphql_types: list[GraphQLObjectType] = []
|
868
866
|
for type_ in union.types:
|
869
867
|
graphql_type = self.from_type(type_)
|
870
868
|
|
strawberry/schema/subscribe.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from
|
3
|
+
from collections.abc import AsyncGenerator, AsyncIterator
|
4
|
+
from typing import TYPE_CHECKING, Optional, Union
|
4
5
|
|
5
6
|
from graphql import (
|
6
7
|
ExecutionResult as OriginalExecutionResult,
|
@@ -44,7 +45,7 @@ async def _subscribe(
|
|
44
45
|
extensions_runner: SchemaExtensionsRunner,
|
45
46
|
process_errors: ProcessErrors,
|
46
47
|
middleware_manager: MiddlewareManager,
|
47
|
-
execution_context_class: Optional[
|
48
|
+
execution_context_class: Optional[type[GraphQLExecutionContext]] = None,
|
48
49
|
) -> AsyncGenerator[Union[PreExecutionError, ExecutionResult], None]:
|
49
50
|
async with extensions_runner.operation():
|
50
51
|
if initial_error := await _parse_and_validate_async(
|
@@ -128,7 +129,7 @@ async def subscribe(
|
|
128
129
|
extensions_runner: SchemaExtensionsRunner,
|
129
130
|
process_errors: ProcessErrors,
|
130
131
|
middleware_manager: MiddlewareManager,
|
131
|
-
execution_context_class: Optional[
|
132
|
+
execution_context_class: Optional[type[GraphQLExecutionContext]] = None,
|
132
133
|
) -> SubscriptionResult:
|
133
134
|
asyncgen = _subscribe(
|
134
135
|
schema,
|
@@ -1,7 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import dataclasses
|
4
|
-
from typing import TYPE_CHECKING,
|
4
|
+
from typing import TYPE_CHECKING, Union
|
5
5
|
|
6
6
|
from graphql import GraphQLField, GraphQLInputField, GraphQLType
|
7
7
|
|
@@ -22,7 +22,7 @@ class ConcreteType:
|
|
22
22
|
implementation: GraphQLType
|
23
23
|
|
24
24
|
|
25
|
-
TypeMap =
|
25
|
+
TypeMap = dict[str, ConcreteType]
|
26
26
|
|
27
27
|
|
28
28
|
__all__ = ["ConcreteType", "Field", "GraphQLType", "TypeMap"]
|
@@ -1,6 +1,5 @@
|
|
1
1
|
import datetime
|
2
2
|
import decimal
|
3
|
-
from typing import Dict, Type
|
4
3
|
from uuid import UUID
|
5
4
|
|
6
5
|
from graphql import (
|
@@ -45,11 +44,11 @@ def _make_scalar_definition(scalar_type: GraphQLScalarType) -> ScalarDefinition:
|
|
45
44
|
)
|
46
45
|
|
47
46
|
|
48
|
-
def _get_scalar_definition(scalar:
|
49
|
-
return scalar._scalar_definition
|
47
|
+
def _get_scalar_definition(scalar: type) -> ScalarDefinition:
|
48
|
+
return scalar._scalar_definition # type: ignore[attr-defined]
|
50
49
|
|
51
50
|
|
52
|
-
DEFAULT_SCALAR_REGISTRY:
|
51
|
+
DEFAULT_SCALAR_REGISTRY: dict[object, ScalarDefinition] = {
|
53
52
|
type(None): _get_scalar_definition(base_scalars.Void),
|
54
53
|
None: _get_scalar_definition(base_scalars.Void),
|
55
54
|
str: _make_scalar_definition(GraphQLString),
|
@@ -3,11 +3,11 @@ from __future__ import annotations
|
|
3
3
|
import dataclasses
|
4
4
|
import keyword
|
5
5
|
from collections import defaultdict
|
6
|
-
from
|
6
|
+
from graphlib import TopologicalSorter
|
7
|
+
from typing import TYPE_CHECKING, Union
|
7
8
|
from typing_extensions import Protocol, TypeAlias
|
8
9
|
|
9
10
|
import libcst as cst
|
10
|
-
from graphlib import TopologicalSorter
|
11
11
|
from graphql import (
|
12
12
|
EnumTypeDefinitionNode,
|
13
13
|
EnumValueDefinitionNode,
|
@@ -42,7 +42,7 @@ if TYPE_CHECKING:
|
|
42
42
|
|
43
43
|
|
44
44
|
class HasDirectives(Protocol):
|
45
|
-
directives:
|
45
|
+
directives: tuple[ConstDirectiveNode, ...]
|
46
46
|
|
47
47
|
|
48
48
|
_SCALAR_MAP = {
|
@@ -256,7 +256,7 @@ def _get_field(
|
|
256
256
|
)
|
257
257
|
|
258
258
|
|
259
|
-
ArgumentValue: TypeAlias = Union[str, bool,
|
259
|
+
ArgumentValue: TypeAlias = Union[str, bool, list["ArgumentValue"]]
|
260
260
|
|
261
261
|
|
262
262
|
def _get_argument_value(argument_value: ConstValueNode) -> ArgumentValue:
|
strawberry/schema_directive.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
import dataclasses
|
2
2
|
from enum import Enum
|
3
|
-
from typing import Callable,
|
3
|
+
from typing import Callable, Optional, TypeVar
|
4
4
|
from typing_extensions import dataclass_transform
|
5
5
|
|
6
6
|
from strawberry.types.field import StrawberryField, field
|
@@ -28,15 +28,15 @@ class Location(Enum):
|
|
28
28
|
class StrawberrySchemaDirective:
|
29
29
|
python_name: str
|
30
30
|
graphql_name: Optional[str]
|
31
|
-
locations:
|
32
|
-
fields:
|
31
|
+
locations: list[Location]
|
32
|
+
fields: list["StrawberryField"]
|
33
33
|
description: Optional[str] = None
|
34
34
|
repeatable: bool = False
|
35
35
|
print_definition: bool = True
|
36
|
-
origin: Optional[
|
36
|
+
origin: Optional[type] = None
|
37
37
|
|
38
38
|
|
39
|
-
T = TypeVar("T", bound=
|
39
|
+
T = TypeVar("T", bound=type)
|
40
40
|
|
41
41
|
|
42
42
|
@dataclass_transform(
|
@@ -46,17 +46,17 @@ T = TypeVar("T", bound=Type)
|
|
46
46
|
)
|
47
47
|
def schema_directive(
|
48
48
|
*,
|
49
|
-
locations:
|
49
|
+
locations: list[Location],
|
50
50
|
description: Optional[str] = None,
|
51
51
|
name: Optional[str] = None,
|
52
52
|
repeatable: bool = False,
|
53
53
|
print_definition: bool = True,
|
54
|
-
) -> Callable[
|
54
|
+
) -> Callable[[T], T]:
|
55
55
|
def _wrap(cls: T) -> T:
|
56
56
|
cls = _wrap_dataclass(cls) # type: ignore
|
57
57
|
fields = _get_fields(cls, {})
|
58
58
|
|
59
|
-
cls.__strawberry_directive__ = StrawberrySchemaDirective(
|
59
|
+
cls.__strawberry_directive__ = StrawberrySchemaDirective( # type: ignore[attr-defined]
|
60
60
|
python_name=cls.__name__,
|
61
61
|
graphql_name=name,
|
62
62
|
locations=locations,
|
@@ -2,14 +2,12 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import asyncio
|
4
4
|
import logging
|
5
|
+
from collections.abc import Awaitable
|
5
6
|
from contextlib import suppress
|
6
7
|
from typing import (
|
7
8
|
TYPE_CHECKING,
|
8
9
|
Any,
|
9
|
-
Awaitable,
|
10
|
-
Dict,
|
11
10
|
Generic,
|
12
|
-
List,
|
13
11
|
Optional,
|
14
12
|
cast,
|
15
13
|
)
|
@@ -40,6 +38,7 @@ from strawberry.utils.debug import pretty_print_graphql_operation
|
|
40
38
|
from strawberry.utils.operation import get_operation_type
|
41
39
|
|
42
40
|
if TYPE_CHECKING:
|
41
|
+
from collections.abc import Awaitable
|
43
42
|
from datetime import timedelta
|
44
43
|
|
45
44
|
from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncWebSocketAdapter
|
@@ -71,8 +70,8 @@ class BaseGraphQLTransportWSHandler(Generic[Context, RootValue]):
|
|
71
70
|
self.connection_init_received = False
|
72
71
|
self.connection_acknowledged = False
|
73
72
|
self.connection_timed_out = False
|
74
|
-
self.operations:
|
75
|
-
self.completed_tasks:
|
73
|
+
self.operations: dict[str, Operation[Context, RootValue]] = {}
|
74
|
+
self.completed_tasks: list[asyncio.Task] = []
|
76
75
|
|
77
76
|
async def handle(self) -> None:
|
78
77
|
self.on_request_accepted()
|
@@ -343,14 +342,14 @@ class Operation(Generic[Context, RootValue]):
|
|
343
342
|
"""A class encapsulating a single operation with its id. Helps enforce protocol state transition."""
|
344
343
|
|
345
344
|
__slots__ = [
|
345
|
+
"completed",
|
346
346
|
"handler",
|
347
347
|
"id",
|
348
|
+
"operation_name",
|
348
349
|
"operation_type",
|
349
350
|
"query",
|
350
|
-
"variables",
|
351
|
-
"operation_name",
|
352
|
-
"completed",
|
353
351
|
"task",
|
352
|
+
"variables",
|
354
353
|
]
|
355
354
|
|
356
355
|
def __init__(
|
@@ -359,7 +358,7 @@ class Operation(Generic[Context, RootValue]):
|
|
359
358
|
id: str,
|
360
359
|
operation_type: OperationType,
|
361
360
|
query: str,
|
362
|
-
variables: Optional[
|
361
|
+
variables: Optional[dict[str, object]],
|
363
362
|
operation_name: Optional[str],
|
364
363
|
) -> None:
|
365
364
|
self.handler = handler
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import TypedDict, Union
|
2
2
|
from typing_extensions import Literal, NotRequired
|
3
3
|
|
4
4
|
from graphql import GraphQLFormattedError
|
@@ -8,35 +8,35 @@ class ConnectionInitMessage(TypedDict):
|
|
8
8
|
"""Direction: Client -> Server."""
|
9
9
|
|
10
10
|
type: Literal["connection_init"]
|
11
|
-
payload: NotRequired[Union[
|
11
|
+
payload: NotRequired[Union[dict[str, object], None]]
|
12
12
|
|
13
13
|
|
14
14
|
class ConnectionAckMessage(TypedDict):
|
15
15
|
"""Direction: Server -> Client."""
|
16
16
|
|
17
17
|
type: Literal["connection_ack"]
|
18
|
-
payload: NotRequired[Union[
|
18
|
+
payload: NotRequired[Union[dict[str, object], None]]
|
19
19
|
|
20
20
|
|
21
21
|
class PingMessage(TypedDict):
|
22
22
|
"""Direction: bidirectional."""
|
23
23
|
|
24
24
|
type: Literal["ping"]
|
25
|
-
payload: NotRequired[Union[
|
25
|
+
payload: NotRequired[Union[dict[str, object], None]]
|
26
26
|
|
27
27
|
|
28
28
|
class PongMessage(TypedDict):
|
29
29
|
"""Direction: bidirectional."""
|
30
30
|
|
31
31
|
type: Literal["pong"]
|
32
|
-
payload: NotRequired[Union[
|
32
|
+
payload: NotRequired[Union[dict[str, object], None]]
|
33
33
|
|
34
34
|
|
35
35
|
class SubscribeMessagePayload(TypedDict):
|
36
36
|
operationName: NotRequired[Union[str, None]]
|
37
37
|
query: str
|
38
|
-
variables: NotRequired[Union[
|
39
|
-
extensions: NotRequired[Union[
|
38
|
+
variables: NotRequired[Union[dict[str, object], None]]
|
39
|
+
extensions: NotRequired[Union[dict[str, object], None]]
|
40
40
|
|
41
41
|
|
42
42
|
class SubscribeMessage(TypedDict):
|
@@ -48,9 +48,9 @@ class SubscribeMessage(TypedDict):
|
|
48
48
|
|
49
49
|
|
50
50
|
class NextMessagePayload(TypedDict):
|
51
|
-
errors: NotRequired[
|
52
|
-
data: NotRequired[Union[
|
53
|
-
extensions: NotRequired[
|
51
|
+
errors: NotRequired[list[GraphQLFormattedError]]
|
52
|
+
data: NotRequired[Union[dict[str, object], None]]
|
53
|
+
extensions: NotRequired[dict[str, object]]
|
54
54
|
|
55
55
|
|
56
56
|
class NextMessage(TypedDict):
|
@@ -66,7 +66,7 @@ class ErrorMessage(TypedDict):
|
|
66
66
|
|
67
67
|
id: str
|
68
68
|
type: Literal["error"]
|
69
|
-
payload:
|
69
|
+
payload: list[GraphQLFormattedError]
|
70
70
|
|
71
71
|
|
72
72
|
class CompleteMessage(TypedDict):
|
@@ -89,13 +89,13 @@ Message = Union[
|
|
89
89
|
|
90
90
|
|
91
91
|
__all__ = [
|
92
|
-
"
|
92
|
+
"CompleteMessage",
|
93
93
|
"ConnectionAckMessage",
|
94
|
+
"ConnectionInitMessage",
|
95
|
+
"ErrorMessage",
|
96
|
+
"Message",
|
97
|
+
"NextMessage",
|
94
98
|
"PingMessage",
|
95
99
|
"PongMessage",
|
96
100
|
"SubscribeMessage",
|
97
|
-
"NextMessage",
|
98
|
-
"ErrorMessage",
|
99
|
-
"CompleteMessage",
|
100
|
-
"Message",
|
101
101
|
]
|
@@ -1,12 +1,11 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import asyncio
|
4
|
+
from collections.abc import AsyncGenerator
|
4
5
|
from contextlib import suppress
|
5
6
|
from typing import (
|
6
7
|
TYPE_CHECKING,
|
7
8
|
Any,
|
8
|
-
AsyncGenerator,
|
9
|
-
Dict,
|
10
9
|
Generic,
|
11
10
|
Optional,
|
12
11
|
cast,
|
@@ -28,6 +27,8 @@ from strawberry.types.unset import UnsetType
|
|
28
27
|
from strawberry.utils.debug import pretty_print_graphql_operation
|
29
28
|
|
30
29
|
if TYPE_CHECKING:
|
30
|
+
from collections.abc import AsyncGenerator
|
31
|
+
|
31
32
|
from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncWebSocketAdapter
|
32
33
|
from strawberry.schema import BaseSchema
|
33
34
|
|
@@ -56,8 +57,8 @@ class BaseGraphQLWSHandler(Generic[Context, RootValue]):
|
|
56
57
|
self.keep_alive = keep_alive
|
57
58
|
self.keep_alive_interval = keep_alive_interval
|
58
59
|
self.keep_alive_task: Optional[asyncio.Task] = None
|
59
|
-
self.subscriptions:
|
60
|
-
self.tasks:
|
60
|
+
self.subscriptions: dict[str, AsyncGenerator] = {}
|
61
|
+
self.tasks: dict[str, asyncio.Task] = {}
|
61
62
|
|
62
63
|
async def handle(self) -> None:
|
63
64
|
try:
|
@@ -164,7 +165,7 @@ class BaseGraphQLWSHandler(Generic[Context, RootValue]):
|
|
164
165
|
operation_id: str,
|
165
166
|
query: str,
|
166
167
|
operation_name: Optional[str],
|
167
|
-
variables: Optional[
|
168
|
+
variables: Optional[dict[str, object]],
|
168
169
|
) -> None:
|
169
170
|
try:
|
170
171
|
agen_or_err = await self.schema.subscribe(
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import TypedDict, Union
|
2
2
|
from typing_extensions import Literal, NotRequired
|
3
3
|
|
4
4
|
from graphql import GraphQLFormattedError
|
@@ -6,12 +6,12 @@ from graphql import GraphQLFormattedError
|
|
6
6
|
|
7
7
|
class ConnectionInitMessage(TypedDict):
|
8
8
|
type: Literal["connection_init"]
|
9
|
-
payload: NotRequired[
|
9
|
+
payload: NotRequired[dict[str, object]]
|
10
10
|
|
11
11
|
|
12
12
|
class StartMessagePayload(TypedDict):
|
13
13
|
query: str
|
14
|
-
variables: NotRequired[
|
14
|
+
variables: NotRequired[dict[str, object]]
|
15
15
|
operationName: NotRequired[str]
|
16
16
|
|
17
17
|
|
@@ -32,20 +32,20 @@ class ConnectionTerminateMessage(TypedDict):
|
|
32
32
|
|
33
33
|
class ConnectionErrorMessage(TypedDict):
|
34
34
|
type: Literal["connection_error"]
|
35
|
-
payload: NotRequired[
|
35
|
+
payload: NotRequired[dict[str, object]]
|
36
36
|
|
37
37
|
|
38
38
|
class ConnectionAckMessage(TypedDict):
|
39
39
|
type: Literal["connection_ack"]
|
40
|
-
payload: NotRequired[
|
40
|
+
payload: NotRequired[dict[str, object]]
|
41
41
|
|
42
42
|
|
43
43
|
class DataMessagePayload(TypedDict):
|
44
44
|
data: object
|
45
|
-
errors: NotRequired[
|
45
|
+
errors: NotRequired[list[GraphQLFormattedError]]
|
46
46
|
|
47
47
|
# Non-standard field:
|
48
|
-
extensions: NotRequired[
|
48
|
+
extensions: NotRequired[dict[str, object]]
|
49
49
|
|
50
50
|
|
51
51
|
class DataMessage(TypedDict):
|
@@ -84,15 +84,15 @@ OperationMessage = Union[
|
|
84
84
|
|
85
85
|
|
86
86
|
__all__ = [
|
87
|
+
"CompleteMessage",
|
88
|
+
"ConnectionAckMessage",
|
89
|
+
"ConnectionErrorMessage",
|
87
90
|
"ConnectionInitMessage",
|
88
|
-
"
|
89
|
-
"StopMessage",
|
91
|
+
"ConnectionKeepAliveMessage",
|
90
92
|
"ConnectionTerminateMessage",
|
91
|
-
"ConnectionErrorMessage",
|
92
|
-
"ConnectionAckMessage",
|
93
93
|
"DataMessage",
|
94
94
|
"ErrorMessage",
|
95
|
-
"CompleteMessage",
|
96
|
-
"ConnectionKeepAliveMessage",
|
97
95
|
"OperationMessage",
|
96
|
+
"StartMessage",
|
97
|
+
"StopMessage",
|
98
98
|
]
|
strawberry/test/__init__.py
CHANGED
strawberry/test/client.py
CHANGED
@@ -4,23 +4,25 @@ import json
|
|
4
4
|
import warnings
|
5
5
|
from abc import ABC, abstractmethod
|
6
6
|
from dataclasses import dataclass
|
7
|
-
from typing import TYPE_CHECKING, Any,
|
7
|
+
from typing import TYPE_CHECKING, Any, Optional, Union
|
8
8
|
from typing_extensions import Literal, TypedDict
|
9
9
|
|
10
10
|
if TYPE_CHECKING:
|
11
|
+
from collections.abc import Coroutine, Mapping
|
12
|
+
|
11
13
|
from graphql import GraphQLFormattedError
|
12
14
|
|
13
15
|
|
14
16
|
@dataclass
|
15
17
|
class Response:
|
16
|
-
errors: Optional[
|
17
|
-
data: Optional[
|
18
|
-
extensions: Optional[
|
18
|
+
errors: Optional[list[GraphQLFormattedError]]
|
19
|
+
data: Optional[dict[str, object]]
|
20
|
+
extensions: Optional[dict[str, object]]
|
19
21
|
|
20
22
|
|
21
23
|
class Body(TypedDict, total=False):
|
22
24
|
query: str
|
23
|
-
variables: Optional[
|
25
|
+
variables: Optional[dict[str, object]]
|
24
26
|
|
25
27
|
|
26
28
|
class BaseGraphQLTestClient(ABC):
|
@@ -35,10 +37,10 @@ class BaseGraphQLTestClient(ABC):
|
|
35
37
|
def query(
|
36
38
|
self,
|
37
39
|
query: str,
|
38
|
-
variables: Optional[
|
39
|
-
headers: Optional[
|
40
|
+
variables: Optional[dict[str, Mapping]] = None,
|
41
|
+
headers: Optional[dict[str, object]] = None,
|
40
42
|
asserts_errors: Optional[bool] = None,
|
41
|
-
files: Optional[
|
43
|
+
files: Optional[dict[str, object]] = None,
|
42
44
|
assert_no_errors: Optional[bool] = True,
|
43
45
|
) -> Union[Coroutine[Any, Any, Response], Response]:
|
44
46
|
body = self._build_body(query, variables, files)
|
@@ -71,19 +73,19 @@ class BaseGraphQLTestClient(ABC):
|
|
71
73
|
@abstractmethod
|
72
74
|
def request(
|
73
75
|
self,
|
74
|
-
body:
|
75
|
-
headers: Optional[
|
76
|
-
files: Optional[
|
76
|
+
body: dict[str, object],
|
77
|
+
headers: Optional[dict[str, object]] = None,
|
78
|
+
files: Optional[dict[str, object]] = None,
|
77
79
|
) -> Any:
|
78
80
|
raise NotImplementedError
|
79
81
|
|
80
82
|
def _build_body(
|
81
83
|
self,
|
82
84
|
query: str,
|
83
|
-
variables: Optional[
|
84
|
-
files: Optional[
|
85
|
-
) ->
|
86
|
-
body:
|
85
|
+
variables: Optional[dict[str, Mapping]] = None,
|
86
|
+
files: Optional[dict[str, object]] = None,
|
87
|
+
) -> dict[str, object]:
|
88
|
+
body: dict[str, object] = {"query": query}
|
87
89
|
|
88
90
|
if variables:
|
89
91
|
body["variables"] = variables
|
@@ -103,8 +105,8 @@ class BaseGraphQLTestClient(ABC):
|
|
103
105
|
|
104
106
|
@staticmethod
|
105
107
|
def _build_multipart_file_map(
|
106
|
-
variables:
|
107
|
-
) ->
|
108
|
+
variables: dict[str, Mapping], files: dict[str, object]
|
109
|
+
) -> dict[str, list[str]]:
|
108
110
|
"""Creates the file mapping between the variables and the files objects passed as key arguments.
|
109
111
|
|
110
112
|
Args:
|
@@ -158,7 +160,7 @@ class BaseGraphQLTestClient(ABC):
|
|
158
160
|
# }
|
159
161
|
```
|
160
162
|
"""
|
161
|
-
map:
|
163
|
+
map: dict[str, list[str]] = {}
|
162
164
|
for key, values in variables.items():
|
163
165
|
reference = key
|
164
166
|
variable_values = values
|
@@ -195,4 +197,4 @@ class BaseGraphQLTestClient(ABC):
|
|
195
197
|
return response.json()
|
196
198
|
|
197
199
|
|
198
|
-
__all__ = ["BaseGraphQLTestClient", "
|
200
|
+
__all__ = ["BaseGraphQLTestClient", "Body", "Response"]
|