strawberry-graphql 0.166.0__py3-none-any.whl → 0.167.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. strawberry/aiohttp/test/client.py +10 -2
  2. strawberry/annotation.py +1 -1
  3. strawberry/asgi/handlers/http_handler.py +3 -2
  4. strawberry/asgi/test/client.py +2 -2
  5. strawberry/chalice/views.py +1 -0
  6. strawberry/channels/context.py +1 -1
  7. strawberry/channels/handlers/base.py +1 -1
  8. strawberry/channels/handlers/graphql_transport_ws_handler.py +1 -1
  9. strawberry/channels/handlers/graphql_ws_handler.py +1 -1
  10. strawberry/channels/handlers/http_handler.py +5 -5
  11. strawberry/channels/handlers/ws_handler.py +4 -4
  12. strawberry/channels/testing.py +1 -1
  13. strawberry/cli/__init__.py +1 -1
  14. strawberry/cli/commands/codegen.py +3 -3
  15. strawberry/cli/commands/export_schema.py +1 -1
  16. strawberry/cli/commands/server.py +1 -1
  17. strawberry/custom_scalar.py +1 -1
  18. strawberry/dataloader.py +9 -9
  19. strawberry/django/context.py +2 -2
  20. strawberry/django/test/client.py +2 -2
  21. strawberry/django/views.py +11 -5
  22. strawberry/exceptions/handler.py +5 -5
  23. strawberry/exceptions/missing_arguments_annotations.py +1 -1
  24. strawberry/experimental/pydantic/conversion.py +7 -5
  25. strawberry/experimental/pydantic/error_type.py +16 -4
  26. strawberry/experimental/pydantic/fields.py +1 -1
  27. strawberry/experimental/pydantic/object_type.py +2 -1
  28. strawberry/experimental/pydantic/utils.py +1 -1
  29. strawberry/ext/dataclasses/dataclasses.py +2 -1
  30. strawberry/ext/mypy_plugin.py +5 -3
  31. strawberry/extensions/context.py +3 -3
  32. strawberry/extensions/tracing/apollo.py +7 -7
  33. strawberry/extensions/tracing/datadog.py +6 -6
  34. strawberry/extensions/tracing/opentelemetry.py +9 -7
  35. strawberry/fastapi/router.py +1 -1
  36. strawberry/federation/scalar.py +1 -1
  37. strawberry/federation/schema.py +8 -1
  38. strawberry/federation/schema_directive.py +2 -2
  39. strawberry/printer/printer.py +1 -1
  40. strawberry/sanic/context.py +1 -0
  41. strawberry/sanic/views.py +3 -1
  42. strawberry/schema/base.py +1 -1
  43. strawberry/schema/name_converter.py +1 -1
  44. strawberry/schema/schema.py +27 -11
  45. strawberry/schema/schema_converter.py +1 -1
  46. strawberry/schema_directive.py +2 -2
  47. strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +2 -2
  48. strawberry/test/client.py +1 -1
  49. strawberry/tools/merge_types.py +3 -1
  50. strawberry/types/fields/resolver.py +1 -1
  51. strawberry/types/info.py +1 -0
  52. strawberry/types/type_resolver.py +1 -1
  53. strawberry/utils/await_maybe.py +3 -3
  54. strawberry/utils/dataclasses.py +1 -1
  55. strawberry/utils/debug.py +1 -1
  56. strawberry/utils/inspect.py +2 -2
  57. strawberry/utils/operation.py +1 -1
  58. strawberry/utils/typing.py +1 -1
  59. {strawberry_graphql-0.166.0.dist-info → strawberry_graphql-0.167.1.dist-info}/METADATA +1 -1
  60. {strawberry_graphql-0.166.0.dist-info → strawberry_graphql-0.167.1.dist-info}/RECORD +63 -63
  61. {strawberry_graphql-0.166.0.dist-info → strawberry_graphql-0.167.1.dist-info}/LICENSE +0 -0
  62. {strawberry_graphql-0.166.0.dist-info → strawberry_graphql-0.167.1.dist-info}/WHEEL +0 -0
  63. {strawberry_graphql-0.166.0.dist-info → strawberry_graphql-0.167.1.dist-info}/entry_points.txt +0 -0
@@ -1,4 +1,11 @@
1
- from typing import Dict, Mapping, Optional
1
+ from __future__ import annotations
2
+
3
+ from typing import (
4
+ Any,
5
+ Dict,
6
+ Mapping,
7
+ Optional,
8
+ )
2
9
 
3
10
  from strawberry.test.client import BaseGraphQLTestClient, Response
4
11
 
@@ -33,10 +40,11 @@ class GraphQLTestClient(BaseGraphQLTestClient):
33
40
  body: Dict[str, object],
34
41
  headers: Optional[Dict[str, object]] = None,
35
42
  files: Optional[Dict[str, object]] = None,
36
- ):
43
+ ) -> Any:
37
44
  response = await self._client.post(
38
45
  self.url,
39
46
  json=body if not files else None,
40
47
  data=body if files else None,
41
48
  )
49
+
42
50
  return response
strawberry/annotation.py CHANGED
@@ -115,7 +115,7 @@ class StrawberryAnnotation:
115
115
  # ... raise NotImplementedError(f"Unknown type {evaled_type}")
116
116
  return evaled_type
117
117
 
118
- def set_namespace_from_field(self, field: StrawberryField):
118
+ def set_namespace_from_field(self, field: StrawberryField) -> None:
119
119
  module = sys.modules[field.origin.__module__]
120
120
  self.namespace = module.__dict__
121
121
 
@@ -19,6 +19,7 @@ if TYPE_CHECKING:
19
19
  from starlette.types import Receive, Scope, Send
20
20
 
21
21
  from strawberry.schema import BaseSchema
22
+ from strawberry.types.execution import ExecutionResult
22
23
 
23
24
 
24
25
  class HTTPHandler:
@@ -42,7 +43,7 @@ class HTTPHandler:
42
43
  self.process_result = process_result
43
44
  self.encode_json = encode_json
44
45
 
45
- async def handle(self, scope: Scope, receive: Receive, send: Send):
46
+ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
46
47
  request = Request(scope=scope, receive=receive)
47
48
  root_value = await self.get_root_value(request)
48
49
 
@@ -199,7 +200,7 @@ class HTTPHandler:
199
200
  operation_name: Optional[str] = None,
200
201
  root_value: Any = None,
201
202
  allowed_operation_types: Optional[Iterable[OperationType]] = None,
202
- ):
203
+ ) -> ExecutionResult:
203
204
  if self.debug:
204
205
  pretty_print_graphql_operation(operation_name, query, variables)
205
206
 
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import json
4
- from typing import TYPE_CHECKING, Dict, Mapping, Optional
4
+ from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional
5
5
 
6
6
  from strawberry.test import BaseGraphQLTestClient
7
7
 
@@ -37,7 +37,7 @@ class GraphQLTestClient(BaseGraphQLTestClient):
37
37
  body: Dict[str, object],
38
38
  headers: Optional[Dict[str, object]] = None,
39
39
  files: Optional[Dict[str, object]] = None,
40
- ):
40
+ ) -> Any:
41
41
  return self._client.post(
42
42
  self.url,
43
43
  json=body if not files else None,
@@ -37,6 +37,7 @@ class GraphQLView:
37
37
  "The `render_graphiql` argument is deprecated. "
38
38
  "Use `graphiql` instead.",
39
39
  DeprecationWarning,
40
+ stacklevel=2,
40
41
  )
41
42
  else:
42
43
  self.graphiql = graphiql
@@ -15,7 +15,7 @@ class StrawberryChannelsContext:
15
15
  connection_params: Optional[Dict[str, Any]] = None
16
16
 
17
17
  @property
18
- def ws(self):
18
+ def ws(self) -> "ChannelsConsumer":
19
19
  return self.request
20
20
 
21
21
  def __getitem__(self, item: str) -> Any:
@@ -94,7 +94,7 @@ class ChannelsConsumer(AsyncConsumer):
94
94
  request=request or self, connection_params=connection_params
95
95
  )
96
96
 
97
- async def dispatch(self, message: ChannelsMessage):
97
+ async def dispatch(self, message: ChannelsMessage) -> None:
98
98
  # AsyncConsumer will try to get a function for message["type"] to handle
99
99
  # for both http/websocket types and also for layers communication.
100
100
  # In case the type isn't one of those, pass it to the listen queue so
@@ -54,7 +54,7 @@ class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler):
54
54
  async def handle_request(self) -> Any:
55
55
  await self._ws.accept(subprotocol=GRAPHQL_TRANSPORT_WS_PROTOCOL)
56
56
 
57
- async def handle_disconnect(self, code):
57
+ async def handle_disconnect(self, code) -> None:
58
58
  for operation_id in list(self.subscriptions.keys()):
59
59
  await self.cleanup_operation(operation_id)
60
60
 
@@ -46,7 +46,7 @@ class GraphQLWSHandler(BaseGraphQLWSHandler):
46
46
  async def handle_request(self) -> Any:
47
47
  await self._ws.accept(subprotocol=GRAPHQL_WS_PROTOCOL)
48
48
 
49
- async def handle_disconnect(self, code):
49
+ async def handle_disconnect(self, code) -> None:
50
50
  if self.keep_alive_task:
51
51
  self.keep_alive_task.cancel()
52
52
  with suppress(BaseException):
@@ -81,7 +81,7 @@ class GraphQLHTTPConsumer(ChannelsConsumer, AsyncHttpConsumer):
81
81
  self.subscriptions_enabled = subscriptions_enabled
82
82
  super().__init__(**kwargs)
83
83
 
84
- async def handle(self, body: bytes):
84
+ async def handle(self, body: bytes) -> None:
85
85
  try:
86
86
  if self.scope["method"] == "GET":
87
87
  result = await self.get(body)
@@ -157,7 +157,7 @@ class GraphQLHTTPConsumer(ChannelsConsumer, AsyncHttpConsumer):
157
157
  async def parse_multipart_body(self, body: bytes) -> GraphQLRequestData:
158
158
  raise ExecutionError("Unable to parse the multipart body")
159
159
 
160
- async def execute(self, request_data: GraphQLRequestData):
160
+ async def execute(self, request_data: GraphQLRequestData) -> GraphQLHTTPResponse:
161
161
  context = await self.get_context()
162
162
  root_value = await self.get_root_value()
163
163
 
@@ -179,11 +179,11 @@ class GraphQLHTTPConsumer(ChannelsConsumer, AsyncHttpConsumer):
179
179
  async def process_result(self, result: ExecutionResult) -> GraphQLHTTPResponse:
180
180
  return process_result(result)
181
181
 
182
- async def render_graphiql(self, body):
182
+ async def render_graphiql(self, body) -> Result:
183
183
  html = get_graphiql_html(self.subscriptions_enabled)
184
184
  return Result(response=html.encode(), content_type="text/html")
185
185
 
186
- def should_render_graphiql(self):
186
+ def should_render_graphiql(self) -> bool:
187
187
  accept_list = self.headers.get("accept", "").split(",")
188
188
  return self.graphiql and any(
189
189
  accepted in accept_list for accepted in ["text/html", "*/*"]
@@ -216,7 +216,7 @@ class SyncGraphQLHTTPConsumer(GraphQLHTTPConsumer):
216
216
  # handlers in a threadpool. Check SyncConsumer's documentation for more info:
217
217
  # https://github.com/django/channels/blob/main/channels/consumer.py#L104
218
218
  @database_sync_to_async
219
- def execute(self, request_data: GraphQLRequestData):
219
+ def execute(self, request_data: GraphQLRequestData) -> GraphQLHTTPResponse:
220
220
  context = self.get_context(self)
221
221
  root_value = self.get_root_value(self)
222
222
 
@@ -69,7 +69,7 @@ class GraphQLWSConsumer(ChannelsWSConsumer):
69
69
  sorted_intersection = sorted(intersection, key=accepted_subprotocols.index)
70
70
  return next(iter(sorted_intersection), None)
71
71
 
72
- async def connect(self):
72
+ async def connect(self) -> None:
73
73
  preferred_protocol = self.pick_preferred_protocol(self.scope["subprotocols"])
74
74
 
75
75
  if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
@@ -98,15 +98,15 @@ class GraphQLWSConsumer(ChannelsWSConsumer):
98
98
  await self._handler.handle()
99
99
  return None
100
100
 
101
- async def receive(self, *args, **kwargs):
101
+ async def receive(self, *args, **kwargs) -> None:
102
102
  # Overriding this so that we can pass the errors to handle_invalid_message
103
103
  try:
104
104
  await super().receive(*args, **kwargs)
105
105
  except ValueError as e:
106
106
  await self._handler.handle_invalid_message(str(e))
107
107
 
108
- async def receive_json(self, content, **kwargs):
108
+ async def receive_json(self, content, **kwargs) -> None:
109
109
  await self._handler.handle_message(content)
110
110
 
111
- async def disconnect(self, code):
111
+ async def disconnect(self, code) -> None:
112
112
  await self._handler.handle_disconnect(code)
@@ -77,7 +77,7 @@ class GraphQLWebsocketCommunicator(WebsocketCommunicator):
77
77
  async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
78
78
  await self.disconnect()
79
79
 
80
- async def gql_init(self):
80
+ async def gql_init(self) -> None:
81
81
  res = await self.connect()
82
82
  if self.protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
83
83
  assert res == (True, GRAPHQL_TRANSPORT_WS_PROTOCOL)
@@ -6,7 +6,7 @@ from .commands.server import server as cmd_server
6
6
 
7
7
 
8
8
  @click.group()
9
- def run(): # pragma: no cover
9
+ def run() -> None: # pragma: no cover
10
10
  pass
11
11
 
12
12
 
@@ -87,7 +87,7 @@ class ConsolePlugin(QueryCodegenPlugin):
87
87
  self.output_dir = output_dir
88
88
  self.plugins = plugins
89
89
 
90
- def on_start(self):
90
+ def on_start(self) -> None:
91
91
  click.echo(
92
92
  click.style(
93
93
  "The codegen is experimental. Please submit any bug at "
@@ -107,7 +107,7 @@ class ConsolePlugin(QueryCodegenPlugin):
107
107
  )
108
108
  )
109
109
 
110
- def on_end(self, result: CodegenResult):
110
+ def on_end(self, result: CodegenResult) -> None:
111
111
  self.output_dir.mkdir(parents=True, exist_ok=True)
112
112
  result.write(self.output_dir)
113
113
 
@@ -148,7 +148,7 @@ def codegen(
148
148
  output_dir: Path,
149
149
  selected_plugins: List[str],
150
150
  cli_plugin: Optional[str] = None,
151
- ):
151
+ ) -> None:
152
152
  schema_symbol = load_schema(schema, app_dir)
153
153
 
154
154
  console_plugin = _load_plugin(cli_plugin) if cli_plugin else ConsolePlugin
@@ -17,7 +17,7 @@ from strawberry.printer import print_schema
17
17
  "Works the same as `--app-dir` in uvicorn."
18
18
  ),
19
19
  )
20
- def export_schema(schema: str, app_dir: str):
20
+ def export_schema(schema: str, app_dir: str) -> None:
21
21
  schema_symbol = load_schema(schema, app_dir)
22
22
 
23
23
  print(print_schema(schema_symbol)) # noqa: T201
@@ -38,7 +38,7 @@ from strawberry.cli.utils import load_schema
38
38
  show_default=True,
39
39
  help="Log GraphQL operations",
40
40
  )
41
- def server(schema, host, port, log_level, app_dir, log_operations):
41
+ def server(schema, host, port, log_level, app_dir, log_operations) -> None:
42
42
  sys.path.insert(0, app_dir)
43
43
 
44
44
  try:
@@ -32,7 +32,7 @@ else:
32
32
  _T = TypeVar("_T", bound=type)
33
33
 
34
34
 
35
- def identity(x):
35
+ def identity(x: _T) -> _T:
36
36
  return x
37
37
 
38
38
 
strawberry/dataloader.py CHANGED
@@ -44,7 +44,7 @@ class Batch(Generic[K, T]):
44
44
  tasks: List[LoaderTask] = dataclasses.field(default_factory=list)
45
45
  dispatched: bool = False
46
46
 
47
- def add_task(self, key: Any, future: Future):
47
+ def add_task(self, key: Any, future: Future) -> None:
48
48
  task = LoaderTask[K, T](key, future)
49
49
  self.tasks.append(task)
50
50
 
@@ -71,7 +71,7 @@ class AbstractCache(Generic[K, T], ABC):
71
71
 
72
72
 
73
73
  class DefaultCache(AbstractCache[K, T]):
74
- def __init__(self, cache_key_fn: Optional[Callable[[K], Hashable]] = None):
74
+ def __init__(self, cache_key_fn: Optional[Callable[[K], Hashable]] = None) -> None:
75
75
  self.cache_key_fn: Callable[[K], Hashable] = (
76
76
  cache_key_fn if cache_key_fn is not None else lambda x: x
77
77
  )
@@ -86,7 +86,7 @@ class DefaultCache(AbstractCache[K, T]):
86
86
  def delete(self, key: K) -> None:
87
87
  del self.cache_map[self.cache_key_fn(key)]
88
88
 
89
- def clear(self):
89
+ def clear(self) -> None:
90
90
  self.cache_map.clear()
91
91
 
92
92
 
@@ -169,23 +169,23 @@ class DataLoader(Generic[K, T]):
169
169
  def load_many(self, keys: Iterable[K]) -> Awaitable[List[T]]:
170
170
  return gather(*map(self.load, keys))
171
171
 
172
- def clear(self, key: K):
172
+ def clear(self, key: K) -> None:
173
173
  if self.cache:
174
174
  self.cache_map.delete(key)
175
175
 
176
- def clear_many(self, keys: Iterable[K]):
176
+ def clear_many(self, keys: Iterable[K]) -> None:
177
177
  if self.cache:
178
178
  for key in keys:
179
179
  self.cache_map.delete(key)
180
180
 
181
- def clear_all(self):
181
+ def clear_all(self) -> None:
182
182
  if self.cache:
183
183
  self.cache_map.clear()
184
184
 
185
- def prime(self, key: K, value: T, force: bool = False):
185
+ def prime(self, key: K, value: T, force: bool = False) -> None:
186
186
  self.prime_many({key: value}, force)
187
187
 
188
- def prime_many(self, data: Mapping[K, T], force: bool = False):
188
+ def prime_many(self, data: Mapping[K, T], force: bool = False) -> None:
189
189
  # Populate the cache with the specified values
190
190
  if self.cache:
191
191
  for key, value in data.items():
@@ -231,7 +231,7 @@ def get_current_batch(loader: DataLoader) -> Batch:
231
231
  return loader.batch
232
232
 
233
233
 
234
- def dispatch(loader: DataLoader, batch: Batch):
234
+ def dispatch(loader: DataLoader, batch: Batch) -> None:
235
235
  loader.loop.call_soon(create_task, dispatch_batch(loader, batch))
236
236
 
237
237
 
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import TYPE_CHECKING
4
+ from typing import TYPE_CHECKING, Any
5
5
 
6
6
  if TYPE_CHECKING:
7
7
  from django.http import HttpRequest, HttpResponse
@@ -17,6 +17,6 @@ class StrawberryDjangoContext:
17
17
  # using info.context["request"]
18
18
  return super().__getattribute__(key)
19
19
 
20
- def get(self, key: str):
20
+ def get(self, key: str) -> Any:
21
21
  """Enable .get notation for accessing the request"""
22
22
  return super().__getattribute__(key)
@@ -1,4 +1,4 @@
1
- from typing import Dict, Optional
1
+ from typing import Any, Dict, Optional
2
2
 
3
3
  from strawberry.test import BaseGraphQLTestClient
4
4
 
@@ -9,7 +9,7 @@ class GraphQLTestClient(BaseGraphQLTestClient):
9
9
  body: Dict[str, object],
10
10
  headers: Optional[Dict[str, object]] = None,
11
11
  files: Optional[Dict[str, object]] = None,
12
- ):
12
+ ) -> Any:
13
13
  if files:
14
14
  return self._client.post(
15
15
  self.url, data=body, format="multipart", headers=headers
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import asyncio
4
4
  import json
5
5
  import warnings
6
- from typing import TYPE_CHECKING, Any, Dict, Optional, Type
6
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union
7
7
 
8
8
  from django.core.exceptions import BadRequest, SuspiciousOperation
9
9
  from django.core.serializers.json import DjangoJSONEncoder
@@ -85,6 +85,7 @@ class BaseView(View):
85
85
  warnings.warn(
86
86
  "json_dumps_params is deprecated, override encode_json instead",
87
87
  DeprecationWarning,
88
+ stacklevel=2,
88
89
  )
89
90
 
90
91
  self.json_encoder = DjangoJSONEncoder
@@ -95,6 +96,7 @@ class BaseView(View):
95
96
  warnings.warn(
96
97
  "json_encoder is deprecated, override encode_json instead",
97
98
  DeprecationWarning,
99
+ stacklevel=2,
98
100
  )
99
101
 
100
102
  def parse_body(self, request: HttpRequest) -> Dict[str, Any]:
@@ -139,7 +141,7 @@ class BaseView(View):
139
141
 
140
142
  return parse_request_data(data)
141
143
 
142
- def _render_graphiql(self, request: HttpRequest, context=None):
144
+ def _render_graphiql(self, request: HttpRequest, context=None) -> TemplateResponse:
143
145
  if not self.graphiql:
144
146
  raise Http404()
145
147
 
@@ -204,7 +206,9 @@ class GraphQLView(BaseView):
204
206
  return process_result(result)
205
207
 
206
208
  @method_decorator(csrf_exempt)
207
- def dispatch(self, request, *args, **kwargs):
209
+ def dispatch(
210
+ self, request, *args, **kwargs
211
+ ) -> Union[HttpResponseNotAllowed, TemplateResponse, HttpResponse]:
208
212
  if not self.is_request_allowed(request):
209
213
  return HttpResponseNotAllowed(
210
214
  ["GET", "POST"], "GraphQL only supports GET and POST requests."
@@ -250,7 +254,7 @@ class GraphQLView(BaseView):
250
254
 
251
255
  class AsyncGraphQLView(BaseView):
252
256
  @classonlymethod
253
- def as_view(cls, **initkwargs):
257
+ def as_view(cls, **initkwargs) -> Callable[..., HttpResponse]:
254
258
  # This code tells django that this view is async, see docs here:
255
259
  # https://docs.djangoproject.com/en/3.1/topics/async/#async-views
256
260
 
@@ -259,7 +263,9 @@ class AsyncGraphQLView(BaseView):
259
263
  return view
260
264
 
261
265
  @method_decorator(csrf_exempt)
262
- async def dispatch(self, request, *args, **kwargs):
266
+ async def dispatch(
267
+ self, request, *args, **kwargs
268
+ ) -> Union[HttpResponseNotAllowed, TemplateResponse, HttpResponse]:
263
269
  if not self.is_request_allowed(request):
264
270
  return HttpResponseNotAllowed(
265
271
  ["GET", "POST"], "GraphQL only supports GET and POST requests."
@@ -17,7 +17,7 @@ ExceptionHandler = Callable[
17
17
  ]
18
18
 
19
19
 
20
- def should_use_rich_exceptions():
20
+ def should_use_rich_exceptions() -> bool:
21
21
  errors_disabled = os.environ.get("STRAWBERRY_DISABLE_RICH_ERRORS", "")
22
22
 
23
23
  return errors_disabled.lower() not in ["true", "1", "yes"]
@@ -53,7 +53,7 @@ def strawberry_exception_handler(
53
53
  exception_type: Type[BaseException],
54
54
  exception: BaseException,
55
55
  traceback: Optional[TracebackType],
56
- ):
56
+ ) -> None:
57
57
  _get_handler(exception_type)(exception_type, exception, traceback)
58
58
 
59
59
 
@@ -64,7 +64,7 @@ def strawberry_threading_exception_handler(
64
64
  Optional[TracebackType],
65
65
  Optional[threading.Thread],
66
66
  ]
67
- ):
67
+ ) -> None:
68
68
  (exception_type, exception, traceback, _) = args
69
69
 
70
70
  if exception is None:
@@ -81,14 +81,14 @@ def strawberry_threading_exception_handler(
81
81
  _get_handler(exception_type)(exception_type, exception, traceback)
82
82
 
83
83
 
84
- def reset_exception_handler():
84
+ def reset_exception_handler() -> None:
85
85
  sys.excepthook = sys.__excepthook__
86
86
 
87
87
  if sys.version_info >= (3, 8):
88
88
  threading.excepthook = original_threading_exception_hook
89
89
 
90
90
 
91
- def setup_exception_handler():
91
+ def setup_exception_handler() -> None:
92
92
  if should_use_rich_exceptions():
93
93
  sys.excepthook = strawberry_exception_handler
94
94
 
@@ -39,7 +39,7 @@ class MissingArgumentsAnnotationsError(StrawberryException):
39
39
  self.annotation_message = f"{first}argument missing annotation"
40
40
 
41
41
  @property
42
- def missing_arguments_str(self):
42
+ def missing_arguments_str(self) -> str:
43
43
  arguments = self.missing_arguments
44
44
 
45
45
  if len(arguments) == 1:
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import copy
4
4
  import dataclasses
5
- from typing import TYPE_CHECKING, Union, cast
5
+ from typing import TYPE_CHECKING, Any, Union, cast
6
6
 
7
7
  from strawberry.enum import EnumDefinition
8
8
  from strawberry.type import StrawberryList, StrawberryOptional
@@ -64,12 +64,14 @@ def _convert_from_pydantic_to_strawberry_type(
64
64
  return data
65
65
 
66
66
 
67
- def convert_pydantic_model_to_strawberry_class(cls, *, model_instance=None, extra=None):
67
+ def convert_pydantic_model_to_strawberry_class(
68
+ cls, *, model_instance=None, extra=None
69
+ ) -> Any:
68
70
  extra = extra or {}
69
71
  kwargs = {}
70
72
 
71
- for field in cls._type_definition.fields:
72
- field = cast("StrawberryField", field)
73
+ for field_ in cls._type_definition.fields:
74
+ field = cast("StrawberryField", field_)
73
75
  python_name = field.python_name
74
76
 
75
77
  data_from_extra = extra.get(python_name, None)
@@ -87,7 +89,7 @@ def convert_pydantic_model_to_strawberry_class(cls, *, model_instance=None, extr
87
89
  return cls(**kwargs)
88
90
 
89
91
 
90
- def convert_strawberry_class_to_pydantic_model(obj):
92
+ def convert_strawberry_class_to_pydantic_model(obj) -> Any:
91
93
  if hasattr(obj, "to_pydantic"):
92
94
  return obj.to_pydantic()
93
95
  elif dataclasses.is_dataclass(obj):
@@ -2,7 +2,18 @@ from __future__ import annotations
2
2
 
3
3
  import dataclasses
4
4
  import warnings
5
- from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Type, cast
5
+ from typing import (
6
+ TYPE_CHECKING,
7
+ Any,
8
+ Callable,
9
+ List,
10
+ Optional,
11
+ Sequence,
12
+ Tuple,
13
+ Type,
14
+ Union,
15
+ cast,
16
+ )
6
17
 
7
18
  from pydantic import BaseModel
8
19
  from pydantic.utils import lenient_issubclass
@@ -23,13 +34,13 @@ if TYPE_CHECKING:
23
34
  from pydantic.fields import ModelField
24
35
 
25
36
 
26
- def get_type_for_field(field: ModelField):
37
+ def get_type_for_field(field: ModelField) -> Union[Any, Type[None], Type[List]]:
27
38
  type_ = field.outer_type_
28
39
  type_ = normalize_type(type_)
29
40
  return field_type_to_type(type_)
30
41
 
31
42
 
32
- def field_type_to_type(type_):
43
+ def field_type_to_type(type_) -> Union[Any, List[Any], None]:
33
44
  error_class: Any = str
34
45
  strawberry_type: Any = error_class
35
46
 
@@ -59,7 +70,7 @@ def error_type(
59
70
  description: Optional[str] = None,
60
71
  directives: Optional[Sequence[object]] = (),
61
72
  all_fields: bool = False,
62
- ):
73
+ ) -> Callable[..., Type]:
63
74
  def wrap(cls):
64
75
  model_fields = model.__fields__
65
76
  fields_set = set(fields) if fields else set()
@@ -68,6 +79,7 @@ def error_type(
68
79
  warnings.warn(
69
80
  "`fields` is deprecated, use `auto` type annotations instead",
70
81
  DeprecationWarning,
82
+ stacklevel=2,
71
83
  )
72
84
 
73
85
  existing_fields = getattr(cls, "__annotations__", {})
@@ -95,7 +95,7 @@ def get_basic_type(type_) -> Type[Any]:
95
95
  return type_
96
96
 
97
97
 
98
- def replace_pydantic_types(type_: Any, is_input: bool):
98
+ def replace_pydantic_types(type_: Any, is_input: bool) -> Any:
99
99
  if lenient_issubclass(type_, BaseModel):
100
100
  attr = "_strawberry_input_type" if is_input else "_strawberry_type"
101
101
  if hasattr(type_, attr):
@@ -40,7 +40,7 @@ if TYPE_CHECKING:
40
40
  from pydantic.fields import ModelField
41
41
 
42
42
 
43
- def get_type_for_field(field: ModelField, is_input: bool):
43
+ def get_type_for_field(field: ModelField, is_input: bool): # noqa: ANN201
44
44
  outer_type = field.outer_type_
45
45
  replaced_type = replace_types_recursively(outer_type, is_input)
46
46
 
@@ -134,6 +134,7 @@ def type(
134
134
  warnings.warn(
135
135
  "`fields` is deprecated, use `auto` type annotations instead",
136
136
  DeprecationWarning,
137
+ stacklevel=2,
137
138
  )
138
139
 
139
140
  existing_fields = getattr(cls, "__annotations__", {})
@@ -46,7 +46,7 @@ def normalize_type(type_) -> Any:
46
46
  return type_
47
47
 
48
48
 
49
- def get_strawberry_type_from_model(type_: Any):
49
+ def get_strawberry_type_from_model(type_: Any) -> Any:
50
50
  if hasattr(type_, "_strawberry_type"):
51
51
  return type_._strawberry_type
52
52
  else:
@@ -11,9 +11,10 @@ from dataclasses import ( # type: ignore
11
11
  _field_init,
12
12
  _init_param,
13
13
  )
14
+ from typing import Any
14
15
 
15
16
 
16
- def dataclass_init_fn(fields, frozen, has_post_init, self_name, globals_):
17
+ def dataclass_init_fn(fields, frozen, has_post_init, self_name, globals_) -> Any:
17
18
  """
18
19
  We create a custom __init__ function for the dataclasses that back
19
20
  Strawberry object types to only accept keyword arguments. This allows us to
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import re
4
+ import typing
4
5
  import warnings
5
6
  from decimal import Decimal
6
7
  from typing import (
@@ -905,7 +906,7 @@ class StrawberryPlugin(Plugin):
905
906
 
906
907
  return None
907
908
 
908
- def get_type_analyze_hook(self, fullname: str):
909
+ def get_type_analyze_hook(self, fullname: str) -> Union[Callable[..., Type], None]:
909
910
  if self._is_strawberry_lazy_type(fullname):
910
911
  return lazy_type_analyze_callback
911
912
 
@@ -1030,14 +1031,15 @@ class StrawberryPlugin(Plugin):
1030
1031
  )
1031
1032
 
1032
1033
 
1033
- def plugin(version: str):
1034
+ def plugin(version: str) -> typing.Type[StrawberryPlugin]:
1034
1035
  match = VERSION_RE.match(version)
1035
1036
  if match:
1036
1037
  MypyVersion.VERSION = Decimal(".".join(match.groups()))
1037
1038
  else:
1038
1039
  MypyVersion.VERSION = FALLBACK_VERSION
1039
1040
  warnings.warn(
1040
- f"Mypy version {version} could not be parsed. Reverting to v0.800"
1041
+ f"Mypy version {version} could not be parsed. Reverting to v0.800",
1042
+ stacklevel=1,
1041
1043
  )
1042
1044
 
1043
1045
  return StrawberryPlugin