strawberry-graphql 0.224.1__py3-none-any.whl → 0.225.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97):
  1. strawberry/aiohttp/handlers/graphql_transport_ws_handler.py +1 -1
  2. strawberry/aiohttp/handlers/graphql_ws_handler.py +1 -1
  3. strawberry/aiohttp/views.py +2 -2
  4. strawberry/annotation.py +1 -1
  5. strawberry/arguments.py +1 -1
  6. strawberry/asgi/__init__.py +1 -1
  7. strawberry/asgi/handlers/graphql_transport_ws_handler.py +1 -1
  8. strawberry/asgi/handlers/graphql_ws_handler.py +1 -1
  9. strawberry/asgi/test/client.py +1 -1
  10. strawberry/auto.py +5 -5
  11. strawberry/chalice/views.py +2 -2
  12. strawberry/channels/handlers/base.py +1 -1
  13. strawberry/channels/handlers/graphql_transport_ws_handler.py +1 -1
  14. strawberry/channels/handlers/graphql_ws_handler.py +1 -1
  15. strawberry/channels/handlers/http_handler.py +2 -2
  16. strawberry/channels/handlers/ws_handler.py +1 -1
  17. strawberry/channels/router.py +1 -1
  18. strawberry/channels/testing.py +1 -1
  19. strawberry/codegen/query_codegen.py +2 -2
  20. strawberry/custom_scalar.py +2 -2
  21. strawberry/django/__init__.py +3 -1
  22. strawberry/django/context.py +1 -1
  23. strawberry/django/views.py +3 -3
  24. strawberry/exceptions/__init__.py +12 -12
  25. strawberry/exceptions/conflicting_arguments.py +1 -1
  26. strawberry/exceptions/duplicated_type_name.py +1 -1
  27. strawberry/exceptions/handler.py +1 -1
  28. strawberry/exceptions/invalid_argument_type.py +1 -1
  29. strawberry/exceptions/missing_arguments_annotations.py +5 -1
  30. strawberry/exceptions/missing_field_annotation.py +1 -1
  31. strawberry/exceptions/missing_return_annotation.py +5 -1
  32. strawberry/exceptions/not_a_strawberry_enum.py +1 -1
  33. strawberry/exceptions/object_is_not_a_class.py +1 -1
  34. strawberry/exceptions/object_is_not_an_enum.py +1 -1
  35. strawberry/exceptions/permission_fail_silently_requires_optional.py +1 -1
  36. strawberry/exceptions/private_strawberry_field.py +1 -1
  37. strawberry/exceptions/scalar_already_registered.py +1 -1
  38. strawberry/exceptions/unresolved_field_type.py +1 -1
  39. strawberry/experimental/pydantic/_compat.py +1 -1
  40. strawberry/experimental/pydantic/conversion_types.py +1 -1
  41. strawberry/experimental/pydantic/exceptions.py +9 -4
  42. strawberry/ext/mypy_plugin.py +1 -1
  43. strawberry/extensions/__init__.py +2 -1
  44. strawberry/extensions/add_validation_rules.py +1 -1
  45. strawberry/extensions/base_extension.py +1 -1
  46. strawberry/extensions/context.py +67 -61
  47. strawberry/extensions/disable_validation.py +1 -1
  48. strawberry/extensions/mask_errors.py +1 -1
  49. strawberry/extensions/max_aliases.py +2 -2
  50. strawberry/extensions/max_tokens.py +1 -1
  51. strawberry/extensions/parser_cache.py +1 -1
  52. strawberry/extensions/query_depth_limiter.py +2 -2
  53. strawberry/extensions/runner.py +1 -1
  54. strawberry/extensions/tracing/__init__.py +2 -2
  55. strawberry/extensions/tracing/apollo.py +1 -1
  56. strawberry/extensions/tracing/datadog.py +2 -2
  57. strawberry/extensions/tracing/opentelemetry.py +1 -1
  58. strawberry/extensions/tracing/sentry.py +2 -2
  59. strawberry/extensions/validation_cache.py +1 -1
  60. strawberry/fastapi/router.py +47 -5
  61. strawberry/federation/schema.py +1 -1
  62. strawberry/federation/schema_directives.py +1 -1
  63. strawberry/field.py +2 -2
  64. strawberry/flask/views.py +3 -3
  65. strawberry/http/exceptions.py +1 -1
  66. strawberry/lazy_type.py +16 -3
  67. strawberry/litestar/controller.py +2 -2
  68. strawberry/litestar/handlers/graphql_transport_ws_handler.py +1 -1
  69. strawberry/litestar/handlers/graphql_ws_handler.py +1 -1
  70. strawberry/object_type.py +3 -3
  71. strawberry/permission.py +1 -1
  72. strawberry/quart/views.py +2 -2
  73. strawberry/relay/exceptions.py +3 -3
  74. strawberry/relay/types.py +2 -2
  75. strawberry/sanic/views.py +2 -2
  76. strawberry/schema/config.py +1 -1
  77. strawberry/schema/exceptions.py +1 -1
  78. strawberry/schema/schema.py +1 -1
  79. strawberry/schema/schema_converter.py +7 -2
  80. strawberry/schema_codegen/__init__.py +2 -2
  81. strawberry/starlite/controller.py +6 -6
  82. strawberry/starlite/handlers/graphql_transport_ws_handler.py +1 -1
  83. strawberry/starlite/handlers/graphql_ws_handler.py +1 -1
  84. strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +14 -4
  85. strawberry/subscriptions/protocols/graphql_ws/handlers.py +1 -1
  86. strawberry/test/client.py +6 -2
  87. strawberry/type.py +3 -3
  88. strawberry/types/execution.py +1 -1
  89. strawberry/types/fields/resolver.py +3 -3
  90. strawberry/types/types.py +1 -1
  91. strawberry/union.py +2 -2
  92. strawberry/unset.py +2 -2
  93. {strawberry_graphql-0.224.1.dist-info → strawberry_graphql-0.225.0.dist-info}/METADATA +1 -1
  94. {strawberry_graphql-0.224.1.dist-info → strawberry_graphql-0.225.0.dist-info}/RECORD +97 -97
  95. {strawberry_graphql-0.224.1.dist-info → strawberry_graphql-0.225.0.dist-info}/LICENSE +0 -0
  96. {strawberry_graphql-0.224.1.dist-info → strawberry_graphql-0.225.0.dist-info}/WHEEL +0 -0
  97. {strawberry_graphql-0.224.1.dist-info → strawberry_graphql-0.225.0.dist-info}/entry_points.txt +0 -0
@@ -1,4 +1,5 @@
1
1
  import warnings
2
+ from typing import Type
2
3
 
3
4
  from .add_validation_rules import AddValidationRules
4
5
  from .base_extension import LifecycleStep, SchemaExtension
@@ -12,7 +13,7 @@ from .query_depth_limiter import IgnoreContext, QueryDepthLimiter
12
13
  from .validation_cache import ValidationCache
13
14
 
14
15
 
15
- def __getattr__(name: str):
16
+ def __getattr__(name: str) -> Type[SchemaExtension]:
16
17
  if name == "Extension":
17
18
  warnings.warn(
18
19
  (
@@ -38,7 +38,7 @@ class AddValidationRules(SchemaExtension):
38
38
 
39
39
  validation_rules: List[Type[ASTValidationRule]]
40
40
 
41
- def __init__(self, validation_rules: List[Type[ASTValidationRule]]):
41
+ def __init__(self, validation_rules: List[Type[ASTValidationRule]]) -> None:
42
42
  self.validation_rules = validation_rules
43
43
 
44
44
  def on_operation(self) -> Iterator[None]:
@@ -21,7 +21,7 @@ class LifecycleStep(Enum):
21
21
  class SchemaExtension:
22
22
  execution_context: ExecutionContext
23
23
 
24
- def __init__(self, *, execution_context: ExecutionContext):
24
+ def __init__(self, *, execution_context: ExecutionContext) -> None:
25
25
  self.execution_context = execution_context
26
26
 
27
27
  def on_operation(
@@ -2,13 +2,16 @@ from __future__ import annotations
2
2
 
3
3
  import contextlib
4
4
  import inspect
5
+ import types
5
6
  import warnings
6
7
  from asyncio import iscoroutinefunction
7
8
  from typing import (
8
9
  TYPE_CHECKING,
9
10
  Any,
11
+ AsyncContextManager,
10
12
  AsyncIterator,
11
13
  Callable,
14
+ ContextManager,
12
15
  Iterator,
13
16
  List,
14
17
  NamedTuple,
@@ -28,14 +31,20 @@ if TYPE_CHECKING:
28
31
 
29
32
  class WrappedHook(NamedTuple):
30
33
  extension: SchemaExtension
31
- initialized_hook: Union[AsyncIterator[None], Iterator[None]]
34
+ hook: Callable[..., Union[AsyncContextManager[None], ContextManager[None]]]
32
35
  is_async: bool
33
36
 
34
37
 
35
38
  class ExtensionContextManagerBase:
36
- __slots__ = ("hooks", "deprecation_message", "default_hook")
37
-
38
- def __init_subclass__(cls):
39
+ __slots__ = (
40
+ "hooks",
41
+ "deprecation_message",
42
+ "default_hook",
43
+ "async_exit_stack",
44
+ "exit_stack",
45
+ )
46
+
47
+ def __init_subclass__(cls) -> None:
39
48
  cls.DEPRECATION_MESSAGE = (
40
49
  f"Event driven styled extensions for "
41
50
  f"{cls.LEGACY_ENTER} or {cls.LEGACY_EXIT}"
@@ -47,7 +56,7 @@ class ExtensionContextManagerBase:
47
56
  LEGACY_ENTER: str
48
57
  LEGACY_EXIT: str
49
58
 
50
- def __init__(self, extensions: List[SchemaExtension]):
59
+ def __init__(self, extensions: List[SchemaExtension]) -> None:
51
60
  self.hooks: List[WrappedHook] = []
52
61
  self.default_hook: Hook = getattr(SchemaExtension, self.HOOK_NAME)
53
62
  for extension in extensions:
@@ -73,10 +82,20 @@ class ExtensionContextManagerBase:
73
82
 
74
83
  if hook_fn:
75
84
  if inspect.isgeneratorfunction(hook_fn):
76
- return WrappedHook(extension, hook_fn(extension), False)
85
+ context_manager = contextlib.contextmanager(
86
+ types.MethodType(hook_fn, extension)
87
+ )
88
+ return WrappedHook(
89
+ extension=extension, hook=context_manager, is_async=False
90
+ )
77
91
 
78
92
  if inspect.isasyncgenfunction(hook_fn):
79
- return WrappedHook(extension, hook_fn(extension), True)
93
+ context_manager_async = contextlib.asynccontextmanager(
94
+ types.MethodType(hook_fn, extension)
95
+ )
96
+ return WrappedHook(
97
+ extension=extension, hook=context_manager_async, is_async=True
98
+ )
80
99
 
81
100
  if callable(hook_fn):
82
101
  return self.from_callable(extension, hook_fn)
@@ -96,27 +115,31 @@ class ExtensionContextManagerBase:
96
115
  ) -> WrappedHook:
97
116
  if iscoroutinefunction(on_start) or iscoroutinefunction(on_end):
98
117
 
99
- async def iterator():
118
+ @contextlib.asynccontextmanager
119
+ async def iterator() -> AsyncIterator:
100
120
  if on_start:
101
121
  await await_maybe(on_start())
122
+
102
123
  yield
124
+
103
125
  if on_end:
104
126
  await await_maybe(on_end())
105
127
 
106
- hook = iterator()
107
- return WrappedHook(extension, hook, True)
128
+ return WrappedHook(extension=extension, hook=iterator, is_async=True)
108
129
 
109
130
  else:
110
131
 
111
- def iterator():
132
+ @contextlib.contextmanager
133
+ def iterator_async() -> Iterator[None]:
112
134
  if on_start:
113
135
  on_start()
136
+
114
137
  yield
138
+
115
139
  if on_end:
116
140
  on_end()
117
141
 
118
- hook = iterator()
119
- return WrappedHook(extension, hook, False)
142
+ return WrappedHook(extension=extension, hook=iterator_async, is_async=False)
120
143
 
121
144
  @staticmethod
122
145
  def from_callable(
@@ -125,78 +148,61 @@ class ExtensionContextManagerBase:
125
148
  ) -> WrappedHook:
126
149
  if iscoroutinefunction(func):
127
150
 
128
- async def async_iterator():
151
+ @contextlib.asynccontextmanager
152
+ async def iterator() -> AsyncIterator[None]:
129
153
  await func(extension)
130
154
  yield
131
155
 
132
- hook = async_iterator()
133
- return WrappedHook(extension, hook, True)
156
+ return WrappedHook(extension=extension, hook=iterator, is_async=True)
134
157
  else:
135
158
 
136
- def iterator():
159
+ @contextlib.contextmanager
160
+ def iterator() -> Iterator[None]:
137
161
  func(extension)
138
162
  yield
139
163
 
140
- hook = iterator()
141
- return WrappedHook(extension, hook, False)
164
+ return WrappedHook(extension=extension, hook=iterator, is_async=False)
142
165
 
143
- def run_hooks_sync(self, is_exit: bool = False) -> None:
144
- """Run extensions synchronously."""
145
- ctx = (
146
- contextlib.suppress(StopIteration, StopAsyncIteration)
147
- if is_exit
148
- else contextlib.nullcontext()
149
- )
150
- for hook in self.hooks:
151
- with ctx:
152
- if hook.is_async:
153
- raise RuntimeError(
154
- f"SchemaExtension hook {hook.extension}.{self.HOOK_NAME} "
155
- "failed to complete synchronously."
156
- )
157
- else:
158
- hook.initialized_hook.__next__() # type: ignore[union-attr]
159
-
160
- async def run_hooks_async(self, is_exit: bool = False) -> None:
161
- """Run extensions asynchronously with support for sync lifecycle hooks.
162
-
163
- The ``is_exit`` flag is required as a `StopIteration` cannot be raised from
164
- within a coroutine.
165
- """
166
- ctx = (
167
- contextlib.suppress(StopIteration, StopAsyncIteration)
168
- if is_exit
169
- else contextlib.nullcontext()
170
- )
166
+ def __enter__(self) -> None:
167
+ self.exit_stack = contextlib.ExitStack()
171
168
 
172
- for hook in self.hooks:
173
- with ctx:
174
- if hook.is_async:
175
- await hook.initialized_hook.__anext__() # type: ignore[union-attr]
176
- else:
177
- hook.initialized_hook.__next__() # type: ignore[union-attr]
169
+ self.exit_stack.__enter__()
178
170
 
179
- def __enter__(self):
180
- self.run_hooks_sync()
171
+ for hook in self.hooks:
172
+ if hook.is_async:
173
+ raise RuntimeError(
174
+ f"SchemaExtension hook {hook.extension}.{self.HOOK_NAME} "
175
+ "failed to complete synchronously."
176
+ )
177
+ else:
178
+ self.exit_stack.enter_context(hook.hook()) # type: ignore
181
179
 
182
180
  def __exit__(
183
181
  self,
184
182
  exc_type: Optional[Type[BaseException]],
185
183
  exc_val: Optional[BaseException],
186
184
  exc_tb: Optional[TracebackType],
187
- ):
188
- self.run_hooks_sync(is_exit=True)
185
+ ) -> None:
186
+ self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
189
187
 
190
- async def __aenter__(self):
191
- await self.run_hooks_async()
188
+ async def __aenter__(self) -> None:
189
+ self.async_exit_stack = contextlib.AsyncExitStack()
190
+
191
+ await self.async_exit_stack.__aenter__()
192
+
193
+ for hook in self.hooks:
194
+ if hook.is_async:
195
+ await self.async_exit_stack.enter_async_context(hook.hook()) # type: ignore
196
+ else:
197
+ self.async_exit_stack.enter_context(hook.hook()) # type: ignore
192
198
 
193
199
  async def __aexit__(
194
200
  self,
195
201
  exc_type: Optional[Type[BaseException]],
196
202
  exc_val: Optional[BaseException],
197
203
  exc_tb: Optional[TracebackType],
198
- ):
199
- await self.run_hooks_async(is_exit=True)
204
+ ) -> None:
205
+ await self.async_exit_stack.__aexit__(exc_type, exc_val, exc_tb)
200
206
 
201
207
 
202
208
  class OperationContextManager(ExtensionContextManagerBase):
@@ -21,7 +21,7 @@ class DisableValidation(SchemaExtension):
21
21
 
22
22
  """
23
23
 
24
- def __init__(self):
24
+ def __init__(self) -> None:
25
25
  # There aren't any arguments to this extension yet but we might add
26
26
  # some in the future
27
27
  pass
@@ -18,7 +18,7 @@ class MaskErrors(SchemaExtension):
18
18
  self,
19
19
  should_mask_error: Callable[[GraphQLError], bool] = default_should_mask_error,
20
20
  error_message: str = "Unexpected error.",
21
- ):
21
+ ) -> None:
22
22
  self.should_mask_error = should_mask_error
23
23
  self.error_message = error_message
24
24
 
@@ -37,14 +37,14 @@ class MaxAliasesLimiter(AddValidationRules):
37
37
  def __init__(
38
38
  self,
39
39
  max_alias_count: int,
40
- ):
40
+ ) -> None:
41
41
  validator = create_validator(max_alias_count)
42
42
  super().__init__([validator])
43
43
 
44
44
 
45
45
  def create_validator(max_alias_count: int) -> Type[ValidationRule]:
46
46
  class MaxAliasesValidator(ValidationRule):
47
- def __init__(self, validation_context: ValidationContext):
47
+ def __init__(self, validation_context: ValidationContext) -> None:
48
48
  document = validation_context.document
49
49
  def_that_can_contain_alias = (
50
50
  def_
@@ -36,7 +36,7 @@ class MaxTokensLimiter(SchemaExtension):
36
36
  def __init__(
37
37
  self,
38
38
  max_token_count: int,
39
- ):
39
+ ) -> None:
40
40
  self.max_token_count = max_token_count
41
41
 
42
42
  def on_operation(self) -> Iterator[None]:
@@ -30,7 +30,7 @@ class ParserCache(SchemaExtension):
30
30
 
31
31
  """
32
32
 
33
- def __init__(self, maxsize: Optional[int] = None):
33
+ def __init__(self, maxsize: Optional[int] = None) -> None:
34
34
  self.cached_parse_document = lru_cache(maxsize=maxsize)(parse_document)
35
35
 
36
36
  def on_parse(self) -> Iterator[None]:
@@ -114,7 +114,7 @@ class QueryDepthLimiter(AddValidationRules):
114
114
  max_depth: int,
115
115
  callback: Optional[Callable[[Dict[str, int]], None]] = None,
116
116
  should_ignore: Optional[ShouldIgnoreType] = None,
117
- ):
117
+ ) -> None:
118
118
  if should_ignore is not None and not callable(should_ignore):
119
119
  raise TypeError(
120
120
  "The `should_ignore` argument to "
@@ -130,7 +130,7 @@ def create_validator(
130
130
  callback: Optional[Callable[[Dict[str, int]], None]] = None,
131
131
  ) -> Type[ValidationRule]:
132
132
  class DepthLimitValidator(ValidationRule):
133
- def __init__(self, validation_context: ValidationContext):
133
+ def __init__(self, validation_context: ValidationContext) -> None:
134
134
  document = validation_context.document
135
135
  definitions = document.definitions
136
136
 
@@ -28,7 +28,7 @@ class SchemaExtensionsRunner:
28
28
  extensions: Optional[
29
29
  List[Union[Type[SchemaExtension], SchemaExtension]]
30
30
  ] = None,
31
- ):
31
+ ) -> None:
32
32
  self.execution_context = execution_context
33
33
 
34
34
  if not extensions:
@@ -1,5 +1,5 @@
1
1
  import importlib
2
- from typing import TYPE_CHECKING
2
+ from typing import TYPE_CHECKING, Any
3
3
 
4
4
  if TYPE_CHECKING:
5
5
  from .apollo import ApolloTracingExtension, ApolloTracingExtensionSync
@@ -22,7 +22,7 @@ __all__ = [
22
22
  ]
23
23
 
24
24
 
25
- def __getattr__(name: str):
25
+ def __getattr__(name: str) -> Any:
26
26
  if name in {"DatadogTracingExtension", "DatadogTracingExtensionSync"}:
27
27
  return getattr(importlib.import_module(".datadog", __name__), name)
28
28
 
@@ -80,7 +80,7 @@ class ApolloTracingStats:
80
80
 
81
81
 
82
82
  class ApolloTracingExtension(SchemaExtension):
83
- def __init__(self, execution_context: ExecutionContext):
83
+ def __init__(self, execution_context: ExecutionContext) -> None:
84
84
  self._resolver_stats: List[ApolloResolverStats] = []
85
85
  self.execution_context = execution_context
86
86
 
@@ -21,12 +21,12 @@ class DatadogTracingExtension(SchemaExtension):
21
21
  self,
22
22
  *,
23
23
  execution_context: Optional[ExecutionContext] = None,
24
- ):
24
+ ) -> None:
25
25
  if execution_context:
26
26
  self.execution_context = execution_context
27
27
 
28
28
  @cached_property
29
- def _resource_name(self):
29
+ def _resource_name(self) -> str:
30
30
  assert self.execution_context.query
31
31
 
32
32
  query_hash = self.hash_query(self.execution_context.query)
@@ -45,7 +45,7 @@ class OpenTelemetryExtension(SchemaExtension):
45
45
  *,
46
46
  execution_context: Optional[ExecutionContext] = None,
47
47
  arg_filter: Optional[ArgFilter] = None,
48
- ):
48
+ ) -> None:
49
49
  self._arg_filter = arg_filter
50
50
  self._tracer = trace.get_tracer("strawberry")
51
51
  if execution_context:
@@ -22,7 +22,7 @@ class SentryTracingExtension(SchemaExtension):
22
22
  self,
23
23
  *,
24
24
  execution_context: Optional[ExecutionContext] = None,
25
- ):
25
+ ) -> None:
26
26
  warnings.warn(
27
27
  "The Sentry tracing extension is deprecated, please update to sentry>=1.32.0",
28
28
  DeprecationWarning,
@@ -33,7 +33,7 @@ class SentryTracingExtension(SchemaExtension):
33
33
  self.execution_context = execution_context
34
34
 
35
35
  @cached_property
36
- def _resource_name(self):
36
+ def _resource_name(self) -> str:
37
37
  assert self.execution_context.query
38
38
 
39
39
  query_hash = self.hash_query(self.execution_context.query)
@@ -30,7 +30,7 @@ class ValidationCache(SchemaExtension):
30
30
 
31
31
  """
32
32
 
33
- def __init__(self, maxsize: Optional[int] = None):
33
+ def __init__(self, maxsize: Optional[int] = None) -> None:
34
34
  self.cached_validate_document = lru_cache(maxsize=maxsize)(validate_document)
35
35
 
36
36
  def on_validate(self) -> Iterator[None]:
@@ -8,9 +8,12 @@ from typing import (
8
8
  Any,
9
9
  Awaitable,
10
10
  Callable,
11
+ Dict,
12
+ List,
11
13
  Mapping,
12
14
  Optional,
13
15
  Sequence,
16
+ Type,
14
17
  Union,
15
18
  cast,
16
19
  )
@@ -20,12 +23,16 @@ from starlette.background import BackgroundTasks # noqa: TCH002
20
23
  from starlette.requests import HTTPConnection, Request
21
24
  from starlette.responses import (
22
25
  HTMLResponse,
26
+ JSONResponse,
23
27
  PlainTextResponse,
24
28
  Response,
25
29
  )
26
30
  from starlette.websockets import WebSocket
27
31
 
28
- from fastapi import APIRouter, Depends
32
+ from fastapi import APIRouter, Depends, params
33
+ from fastapi.datastructures import Default
34
+ from fastapi.routing import APIRoute
35
+ from fastapi.utils import generate_unique_id
29
36
  from strawberry.exceptions import InvalidCustomContext
30
37
  from strawberry.fastapi.context import BaseContext, CustomContext
31
38
  from strawberry.fastapi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler
@@ -42,7 +49,10 @@ from strawberry.http.typevars import (
42
49
  from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
43
50
 
44
51
  if TYPE_CHECKING:
45
- from starlette.types import ASGIApp
52
+ from enum import Enum
53
+
54
+ from starlette.routing import BaseRoute
55
+ from starlette.types import ASGIApp, Lifespan
46
56
 
47
57
  from strawberry.fastapi.context import MergedContext
48
58
  from strawberry.http import GraphQLHTTPResponse
@@ -52,7 +62,7 @@ if TYPE_CHECKING:
52
62
 
53
63
 
54
64
  class FastAPIRequestAdapter(AsyncHTTPRequestAdapter):
55
- def __init__(self, request: Request):
65
+ def __init__(self, request: Request) -> None:
56
66
  self.request = request
57
67
 
58
68
  @property
@@ -89,7 +99,7 @@ class GraphQLRouter(
89
99
  request_adapter_class = FastAPIRequestAdapter
90
100
 
91
101
  @staticmethod
92
- async def __get_root_value():
102
+ async def __get_root_value() -> None:
93
103
  return None
94
104
 
95
105
  @staticmethod
@@ -160,14 +170,46 @@ class GraphQLRouter(
160
170
  GRAPHQL_WS_PROTOCOL,
161
171
  ),
162
172
  connection_init_wait_timeout: timedelta = timedelta(minutes=1),
173
+ prefix: str = "",
174
+ tags: Optional[List[Union[str, Enum]]] = None,
175
+ dependencies: Optional[Sequence[params.Depends]] = None,
176
+ default_response_class: Type[Response] = Default(JSONResponse),
177
+ responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
178
+ callbacks: Optional[List[BaseRoute]] = None,
179
+ routes: Optional[List[BaseRoute]] = None,
180
+ redirect_slashes: bool = True,
163
181
  default: Optional[ASGIApp] = None,
182
+ dependency_overrides_provider: Optional[Any] = None,
183
+ route_class: Type[APIRoute] = APIRoute,
164
184
  on_startup: Optional[Sequence[Callable[[], Any]]] = None,
165
185
  on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
166
- ):
186
+ lifespan: Optional[Lifespan[Any]] = None,
187
+ deprecated: Optional[bool] = None,
188
+ include_in_schema: bool = True,
189
+ generate_unique_id_function: Callable[[APIRoute], str] = Default(
190
+ generate_unique_id
191
+ ),
192
+ **kwargs: Any,
193
+ ) -> None:
167
194
  super().__init__(
195
+ prefix=prefix,
196
+ tags=tags,
197
+ dependencies=dependencies,
198
+ default_response_class=default_response_class,
199
+ responses=responses,
200
+ callbacks=callbacks,
201
+ routes=routes,
202
+ redirect_slashes=redirect_slashes,
168
203
  default=default,
204
+ dependency_overrides_provider=dependency_overrides_provider,
205
+ route_class=route_class,
169
206
  on_startup=on_startup,
170
207
  on_shutdown=on_shutdown,
208
+ lifespan=lifespan,
209
+ deprecated=deprecated,
210
+ include_in_schema=include_in_schema,
211
+ generate_unique_id_function=generate_unique_id_function,
212
+ **kwargs,
171
213
  )
172
214
  self.schema = schema
173
215
  self.allow_queries_via_get = allow_queries_via_get
@@ -66,7 +66,7 @@ class Schema(BaseSchema):
66
66
  ] = None,
67
67
  schema_directives: Iterable[object] = (),
68
68
  enable_federation_2: bool = False,
69
- ):
69
+ ) -> None:
70
70
  query = self._get_federation_query_type(query)
71
71
 
72
72
  super().__init__(
@@ -92,7 +92,7 @@ class Link:
92
92
  as_: Optional[str] = UNSET,
93
93
  for_: Optional[LinkPurpose] = UNSET,
94
94
  import_: Optional[List[Optional[LinkImport]]] = UNSET,
95
- ):
95
+ ) -> None:
96
96
  self.url = url
97
97
  self.as_ = as_
98
98
  self.for_ = for_
strawberry/field.py CHANGED
@@ -90,7 +90,7 @@ class StrawberryField(dataclasses.Field):
90
90
  deprecation_reason: Optional[str] = None,
91
91
  directives: Sequence[object] = (),
92
92
  extensions: List[FieldExtension] = (), # type: ignore
93
- ):
93
+ ) -> None:
94
94
  # basic fields are fields with no provided resolver
95
95
  is_basic_field = not base_resolver
96
96
 
@@ -242,7 +242,7 @@ class StrawberryField(dataclasses.Field):
242
242
  return self._arguments
243
243
 
244
244
  @arguments.setter
245
- def arguments(self, value: List[StrawberryArgument]):
245
+ def arguments(self, value: List[StrawberryArgument]) -> None:
246
246
  self._arguments = value
247
247
 
248
248
  @property
strawberry/flask/views.py CHANGED
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
22
22
 
23
23
 
24
24
  class FlaskHTTPRequestAdapter(SyncHTTPRequestAdapter):
25
- def __init__(self, request: Request):
25
+ def __init__(self, request: Request) -> None:
26
26
  self.request = request
27
27
 
28
28
  @property
@@ -64,7 +64,7 @@ class BaseGraphQLView:
64
64
  graphiql: Optional[bool] = None,
65
65
  graphql_ide: Optional[GraphQL_IDE] = "graphiql",
66
66
  allow_queries_via_get: bool = True,
67
- ):
67
+ ) -> None:
68
68
  self.schema = schema
69
69
  self.graphiql = graphiql
70
70
  self.allow_queries_via_get = allow_queries_via_get
@@ -119,7 +119,7 @@ class GraphQLView(
119
119
 
120
120
 
121
121
  class AsyncFlaskHTTPRequestAdapter(AsyncHTTPRequestAdapter):
122
- def __init__(self, request: Request):
122
+ def __init__(self, request: Request) -> None:
123
123
  self.request = request
124
124
 
125
125
  @property
@@ -1,4 +1,4 @@
1
1
  class HTTPException(Exception):
2
- def __init__(self, status_code: int, reason: str):
2
+ def __init__(self, status_code: int, reason: str) -> None:
3
3
  self.status_code = status_code
4
4
  self.reason = reason
strawberry/lazy_type.py CHANGED
@@ -4,7 +4,20 @@ import sys
4
4
  import warnings
5
5
  from dataclasses import dataclass
6
6
  from pathlib import Path
7
- from typing import Any, ForwardRef, Generic, Optional, Tuple, Type, TypeVar, cast
7
+ from typing import (
8
+ TYPE_CHECKING,
9
+ Any,
10
+ ForwardRef,
11
+ Generic,
12
+ Optional,
13
+ Tuple,
14
+ Type,
15
+ TypeVar,
16
+ cast,
17
+ )
18
+
19
+ if TYPE_CHECKING:
20
+ from typing_extensions import Self
8
21
 
9
22
  TypeName = TypeVar("TypeName")
10
23
  Module = TypeVar("Module")
@@ -16,7 +29,7 @@ class LazyType(Generic[TypeName, Module]):
16
29
  module: str
17
30
  package: Optional[str] = None
18
31
 
19
- def __class_getitem__(cls, params: Tuple[str, str]):
32
+ def __class_getitem__(cls, params: Tuple[str, str]) -> "Self":
20
33
  warnings.warn(
21
34
  (
22
35
  "LazyType is deprecated, use "
@@ -63,7 +76,7 @@ class LazyType(Generic[TypeName, Module]):
63
76
  # this empty call method allows LazyTypes to be used in generic types
64
77
  # for example: List[LazyType["A", "module"]]
65
78
 
66
- def __call__(self): # pragma: no cover
79
+ def __call__(self) -> None: # pragma: no cover
67
80
  return None
68
81
 
69
82
 
@@ -325,10 +325,10 @@ class GraphQLController(
325
325
  context_ws: Any,
326
326
  root_value: Any,
327
327
  ) -> None:
328
- async def _get_context():
328
+ async def _get_context() -> Any:
329
329
  return context_ws
330
330
 
331
- async def _get_root_value():
331
+ async def _get_root_value() -> Any:
332
332
  return root_value
333
333
 
334
334
  preferred_protocol = self.pick_preferred_protocol(socket)
@@ -20,7 +20,7 @@ class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler):
20
20
  get_context: Callable,
21
21
  get_root_value: Callable,
22
22
  ws: WebSocket,
23
- ):
23
+ ) -> None:
24
24
  super().__init__(schema, debug, connection_init_wait_timeout)
25
25
  self._get_context = get_context
26
26
  self._get_root_value = get_root_value