strawberry-graphql 0.256.1__py3-none-any.whl → 0.257.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. strawberry/__init__.py +2 -0
  2. strawberry/aiohttp/test/client.py +1 -3
  3. strawberry/aiohttp/views.py +3 -3
  4. strawberry/annotation.py +5 -7
  5. strawberry/asgi/__init__.py +4 -4
  6. strawberry/channels/handlers/ws_handler.py +3 -3
  7. strawberry/channels/testing.py +3 -1
  8. strawberry/cli/__init__.py +7 -5
  9. strawberry/cli/commands/codegen.py +12 -14
  10. strawberry/cli/commands/server.py +2 -2
  11. strawberry/cli/commands/upgrade/_run_codemod.py +2 -3
  12. strawberry/cli/utils/__init__.py +1 -1
  13. strawberry/codegen/plugins/python.py +4 -5
  14. strawberry/codegen/plugins/typescript.py +3 -2
  15. strawberry/codegen/query_codegen.py +12 -11
  16. strawberry/codemods/annotated_unions.py +1 -1
  17. strawberry/codemods/update_imports.py +2 -4
  18. strawberry/dataloader.py +2 -2
  19. strawberry/django/__init__.py +2 -2
  20. strawberry/django/views.py +2 -3
  21. strawberry/exceptions/__init__.py +4 -2
  22. strawberry/exceptions/exception.py +1 -1
  23. strawberry/exceptions/permission_fail_silently_requires_optional.py +2 -1
  24. strawberry/exceptions/utils/source_finder.py +1 -1
  25. strawberry/experimental/pydantic/_compat.py +4 -4
  26. strawberry/experimental/pydantic/conversion.py +4 -5
  27. strawberry/experimental/pydantic/fields.py +1 -2
  28. strawberry/experimental/pydantic/object_type.py +6 -2
  29. strawberry/experimental/pydantic/utils.py +3 -9
  30. strawberry/ext/mypy_plugin.py +7 -14
  31. strawberry/extensions/context.py +15 -19
  32. strawberry/extensions/field_extension.py +53 -54
  33. strawberry/extensions/query_depth_limiter.py +27 -33
  34. strawberry/extensions/tracing/datadog.py +1 -1
  35. strawberry/extensions/tracing/opentelemetry.py +9 -14
  36. strawberry/fastapi/router.py +2 -3
  37. strawberry/federation/schema.py +3 -3
  38. strawberry/flask/views.py +3 -2
  39. strawberry/http/async_base_view.py +2 -4
  40. strawberry/http/ides.py +1 -3
  41. strawberry/http/sync_base_view.py +1 -2
  42. strawberry/litestar/controller.py +6 -5
  43. strawberry/permission.py +1 -1
  44. strawberry/quart/views.py +2 -2
  45. strawberry/relay/fields.py +28 -3
  46. strawberry/relay/types.py +1 -1
  47. strawberry/schema/base.py +0 -2
  48. strawberry/schema/execute.py +11 -11
  49. strawberry/schema/name_converter.py +4 -5
  50. strawberry/schema/schema.py +6 -4
  51. strawberry/schema/schema_converter.py +24 -17
  52. strawberry/schema/subscribe.py +4 -4
  53. strawberry/schema/types/base_scalars.py +4 -2
  54. strawberry/schema/types/scalar.py +1 -1
  55. strawberry/schema_codegen/__init__.py +5 -6
  56. strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +2 -2
  57. strawberry/subscriptions/protocols/graphql_ws/handlers.py +0 -3
  58. strawberry/test/client.py +1 -2
  59. strawberry/types/arguments.py +2 -2
  60. strawberry/types/auto.py +3 -3
  61. strawberry/types/base.py +12 -16
  62. strawberry/types/cast.py +35 -0
  63. strawberry/types/field.py +11 -7
  64. strawberry/types/fields/resolver.py +12 -19
  65. strawberry/types/union.py +1 -1
  66. strawberry/types/unset.py +1 -2
  67. strawberry/utils/debug.py +1 -1
  68. strawberry/utils/deprecations.py +1 -1
  69. strawberry/utils/graphql_lexer.py +6 -4
  70. strawberry/utils/typing.py +1 -2
  71. {strawberry_graphql-0.256.1.dist-info → strawberry_graphql-0.257.0.dist-info}/METADATA +2 -2
  72. {strawberry_graphql-0.256.1.dist-info → strawberry_graphql-0.257.0.dist-info}/RECORD +75 -74
  73. {strawberry_graphql-0.256.1.dist-info → strawberry_graphql-0.257.0.dist-info}/WHEEL +1 -1
  74. {strawberry_graphql-0.256.1.dist-info → strawberry_graphql-0.257.0.dist-info}/LICENSE +0 -0
  75. {strawberry_graphql-0.256.1.dist-info → strawberry_graphql-0.257.0.dist-info}/entry_points.txt +0 -0
strawberry/ext/mypy_plugin.py CHANGED
@@ -111,9 +111,7 @@ def lazy_type_analyze_callback(ctx: AnalyzeTypeContext) -> Type:
111
111
  return AnyType(TypeOfAny.special_form)
112
112
 
113
113
  type_name = ctx.type.args[0]
114
- type_ = ctx.api.analyze_type(type_name)
115
-
116
- return type_
114
+ return ctx.api.analyze_type(type_name)
117
115
 
118
116
 
119
117
  def _get_named_type(name: str, api: SemanticAnalyzerPluginInterface) -> Any:
@@ -147,14 +145,12 @@ def _get_type_for_expr(expr: Expression, api: SemanticAnalyzerPluginInterface) -
147
145
  if isinstance(expr, MemberExpr):
148
146
  if expr.fullname:
149
147
  return _get_named_type(expr.fullname, api)
150
- else:
151
- raise InvalidNodeTypeException(expr)
148
+ raise InvalidNodeTypeException(expr)
152
149
 
153
150
  if isinstance(expr, CallExpr):
154
151
  if expr.analyzed:
155
152
  return _get_type_for_expr(expr.analyzed, api)
156
- else:
157
- raise InvalidNodeTypeException(expr)
153
+ raise InvalidNodeTypeException(expr)
158
154
 
159
155
  if isinstance(expr, CastExpr):
160
156
  return expr.type
@@ -178,8 +174,6 @@ def create_type_hook(ctx: DynamicClassDefContext) -> None:
178
174
  SymbolTableNode(GDEF, type_alias, plugin_generated=True),
179
175
  )
180
176
 
181
- return
182
-
183
177
 
184
178
  def union_hook(ctx: DynamicClassDefContext) -> None:
185
179
  try:
@@ -343,13 +337,12 @@ def add_static_method_to_class(
343
337
  cls.defs.body.remove(sym.node)
344
338
 
345
339
  # For compat with mypy < 0.93
346
- if MypyVersion.VERSION < Decimal("0.93"):
340
+ if Decimal("0.93") > MypyVersion.VERSION:
347
341
  function_type = api.named_type("__builtins__.function")
342
+ elif isinstance(api, SemanticAnalyzerPluginInterface):
343
+ function_type = api.named_type("builtins.function")
348
344
  else:
349
- if isinstance(api, SemanticAnalyzerPluginInterface):
350
- function_type = api.named_type("builtins.function")
351
- else:
352
- function_type = api.named_generic_type("builtins.function", [])
345
+ function_type = api.named_generic_type("builtins.function", [])
353
346
 
354
347
  arg_types, arg_names, arg_kinds = [], [], []
355
348
  for arg in args:
strawberry/extensions/context.py CHANGED
@@ -77,7 +77,7 @@ class ExtensionContextManagerBase:
77
77
  f"{extension} defines both legacy and new style extension hooks for "
78
78
  "{self.HOOK_NAME}"
79
79
  )
80
- elif is_legacy:
80
+ if is_legacy:
81
81
  warnings.warn(self.DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=3)
82
82
  return self.from_legacy(extension, on_start, on_end)
83
83
 
@@ -128,19 +128,17 @@ class ExtensionContextManagerBase:
128
128
 
129
129
  return WrappedHook(extension=extension, hook=iterator, is_async=True)
130
130
 
131
- else:
131
+ @contextlib.contextmanager
132
+ def iterator_async() -> Iterator[None]:
133
+ if on_start:
134
+ on_start()
132
135
 
133
- @contextlib.contextmanager
134
- def iterator_async() -> Iterator[None]:
135
- if on_start:
136
- on_start()
137
-
138
- yield
136
+ yield
139
137
 
140
- if on_end:
141
- on_end()
138
+ if on_end:
139
+ on_end()
142
140
 
143
- return WrappedHook(extension=extension, hook=iterator_async, is_async=False)
141
+ return WrappedHook(extension=extension, hook=iterator_async, is_async=False)
144
142
 
145
143
  @staticmethod
146
144
  def from_callable(
@@ -155,14 +153,13 @@ class ExtensionContextManagerBase:
155
153
  yield
156
154
 
157
155
  return WrappedHook(extension=extension, hook=iterator, is_async=True)
158
- else:
159
156
 
160
- @contextlib.contextmanager
161
- def iterator() -> Iterator[None]:
162
- func(extension)
163
- yield
157
+ @contextlib.contextmanager # type: ignore[no-redef]
158
+ def iterator() -> Iterator[None]:
159
+ func(extension)
160
+ yield
164
161
 
165
- return WrappedHook(extension=extension, hook=iterator, is_async=False)
162
+ return WrappedHook(extension=extension, hook=iterator, is_async=False)
166
163
 
167
164
  def __enter__(self) -> None:
168
165
  self.exit_stack = contextlib.ExitStack()
@@ -175,8 +172,7 @@ class ExtensionContextManagerBase:
175
172
  f"SchemaExtension hook {hook.extension}.{self.HOOK_NAME} "
176
173
  "failed to complete synchronously."
177
174
  )
178
- else:
179
- self.exit_stack.enter_context(hook.hook()) # type: ignore
175
+ self.exit_stack.enter_context(hook.hook()) # type: ignore
180
176
 
181
177
  def __exit__(
182
178
  self,
strawberry/extensions/field_extension.py CHANGED
@@ -99,62 +99,61 @@ def build_field_extension_resolvers(
99
99
  f"Please add a resolve_async method to the extension(s)."
100
100
  )
101
101
  return _get_async_resolvers(field.extensions)
102
- else:
103
- # Try to wrap all sync resolvers in async so that we can use async extensions
104
- # on sync fields. This is not possible the other way around since
105
- # the result of an async resolver would have to be awaited before calling
106
- # the sync extension, making it impossible for the extension to modify
107
- # any arguments.
108
- non_sync_extensions = [
109
- extension for extension in field.extensions if not extension.supports_sync
110
- ]
111
-
112
- if len(non_sync_extensions) == 0:
113
- # Resolve everything sync
114
- return _get_sync_resolvers(field.extensions)
115
-
116
- # We have async-only extensions and need to wrap the resolver
117
- # That means we can't have sync-only extensions after the first async one
118
-
119
- # Check if we have a chain of sync-compatible
120
- # extensions before the async extensions
121
- # -> S-S-S-S-A-A-A-A
122
- found_sync_extensions = 0
123
-
124
- # All sync only extensions must be found before the first async-only one
125
- found_sync_only_extensions = 0
126
- for extension in field.extensions:
127
- # ...A, abort
128
- if extension in non_sync_extensions:
129
- break
130
- # ...S
131
- if extension in non_async_extensions:
132
- found_sync_only_extensions += 1
133
- found_sync_extensions += 1
134
-
135
- # Length of the chain equals length of non async extensions
136
- # All sync extensions run first
137
- if len(non_async_extensions) == found_sync_only_extensions:
138
- # Prepend sync to async extension to field extensions
139
- return list(
140
- itertools.chain(
141
- _get_sync_resolvers(field.extensions[:found_sync_extensions]),
142
- [SyncToAsyncExtension().resolve_async],
143
- _get_async_resolvers(field.extensions[found_sync_extensions:]),
144
- )
145
- )
102
+ # Try to wrap all sync resolvers in async so that we can use async extensions
103
+ # on sync fields. This is not possible the other way around since
104
+ # the result of an async resolver would have to be awaited before calling
105
+ # the sync extension, making it impossible for the extension to modify
106
+ # any arguments.
107
+ non_sync_extensions = [
108
+ extension for extension in field.extensions if not extension.supports_sync
109
+ ]
146
110
 
147
- # Some sync extensions follow the first async-only extension. Error case
148
- async_extension_names = ",".join(
149
- [extension.__class__.__name__ for extension in non_sync_extensions]
150
- )
151
- raise TypeError(
152
- f"Cannot mix async-only extension(s) {async_extension_names} "
153
- f"with sync-only extension(s) {non_async_extension_names} "
154
- f"on Field {field.name}. "
155
- f"If possible try to change the execution order so that all sync-only "
156
- f"extensions are executed first."
111
+ if len(non_sync_extensions) == 0:
112
+ # Resolve everything sync
113
+ return _get_sync_resolvers(field.extensions)
114
+
115
+ # We have async-only extensions and need to wrap the resolver
116
+ # That means we can't have sync-only extensions after the first async one
117
+
118
+ # Check if we have a chain of sync-compatible
119
+ # extensions before the async extensions
120
+ # -> S-S-S-S-A-A-A-A
121
+ found_sync_extensions = 0
122
+
123
+ # All sync only extensions must be found before the first async-only one
124
+ found_sync_only_extensions = 0
125
+ for extension in field.extensions:
126
+ # ...A, abort
127
+ if extension in non_sync_extensions:
128
+ break
129
+ # ...S
130
+ if extension in non_async_extensions:
131
+ found_sync_only_extensions += 1
132
+ found_sync_extensions += 1
133
+
134
+ # Length of the chain equals length of non async extensions
135
+ # All sync extensions run first
136
+ if len(non_async_extensions) == found_sync_only_extensions:
137
+ # Prepend sync to async extension to field extensions
138
+ return list(
139
+ itertools.chain(
140
+ _get_sync_resolvers(field.extensions[:found_sync_extensions]),
141
+ [SyncToAsyncExtension().resolve_async],
142
+ _get_async_resolvers(field.extensions[found_sync_extensions:]),
143
+ )
157
144
  )
158
145
 
146
+ # Some sync extensions follow the first async-only extension. Error case
147
+ async_extension_names = ",".join(
148
+ [extension.__class__.__name__ for extension in non_sync_extensions]
149
+ )
150
+ raise TypeError(
151
+ f"Cannot mix async-only extension(s) {async_extension_names} "
152
+ f"with sync-only extension(s) {non_async_extension_names} "
153
+ f"on Field {field.name}. "
154
+ f"If possible try to change the execution order so that all sync-only "
155
+ f"extensions are executed first."
156
+ )
157
+
159
158
 
160
159
  __all__ = ["FieldExtension"]
strawberry/extensions/query_depth_limiter.py CHANGED
@@ -189,18 +189,17 @@ def resolve_field_value(
189
189
  ) -> FieldArgumentType:
190
190
  if isinstance(value, StringValueNode):
191
191
  return value.value
192
- elif isinstance(value, IntValueNode):
192
+ if isinstance(value, IntValueNode):
193
193
  return int(value.value)
194
- elif isinstance(value, FloatValueNode):
194
+ if isinstance(value, FloatValueNode):
195
195
  return float(value.value)
196
- elif isinstance(value, BooleanValueNode):
196
+ if isinstance(value, BooleanValueNode):
197
197
  return value.value
198
- elif isinstance(value, ListValueNode):
198
+ if isinstance(value, ListValueNode):
199
199
  return [resolve_field_value(v) for v in value.values]
200
- elif isinstance(value, ObjectValueNode):
200
+ if isinstance(value, ObjectValueNode):
201
201
  return {v.name.value: resolve_field_value(v.value) for v in value.fields}
202
- else:
203
- return {}
202
+ return {}
204
203
 
205
204
 
206
205
  def get_field_arguments(
@@ -250,20 +249,18 @@ def determine_depth(
250
249
  return 0
251
250
 
252
251
  return 1 + max(
253
- map(
254
- lambda selection: determine_depth(
255
- node=selection,
256
- fragments=fragments,
257
- depth_so_far=depth_so_far + 1,
258
- max_depth=max_depth,
259
- context=context,
260
- operation_name=operation_name,
261
- should_ignore=should_ignore,
262
- ),
263
- node.selection_set.selections,
252
+ determine_depth(
253
+ node=selection,
254
+ fragments=fragments,
255
+ depth_so_far=depth_so_far + 1,
256
+ max_depth=max_depth,
257
+ context=context,
258
+ operation_name=operation_name,
259
+ should_ignore=should_ignore,
264
260
  )
261
+ for selection in node.selection_set.selections
265
262
  )
266
- elif isinstance(node, FragmentSpreadNode):
263
+ if isinstance(node, FragmentSpreadNode):
267
264
  return determine_depth(
268
265
  node=fragments[node.name.value],
269
266
  fragments=fragments,
@@ -273,25 +270,22 @@ def determine_depth(
273
270
  operation_name=operation_name,
274
271
  should_ignore=should_ignore,
275
272
  )
276
- elif isinstance(
273
+ if isinstance(
277
274
  node, (InlineFragmentNode, FragmentDefinitionNode, OperationDefinitionNode)
278
275
  ):
279
276
  return max(
280
- map(
281
- lambda selection: determine_depth(
282
- node=selection,
283
- fragments=fragments,
284
- depth_so_far=depth_so_far,
285
- max_depth=max_depth,
286
- context=context,
287
- operation_name=operation_name,
288
- should_ignore=should_ignore,
289
- ),
290
- node.selection_set.selections,
277
+ determine_depth(
278
+ node=selection,
279
+ fragments=fragments,
280
+ depth_so_far=depth_so_far,
281
+ max_depth=max_depth,
282
+ context=context,
283
+ operation_name=operation_name,
284
+ should_ignore=should_ignore,
291
285
  )
286
+ for selection in node.selection_set.selections
292
287
  )
293
- else:
294
- raise TypeError(f"Depth crawler cannot handle: {node.kind}") # pragma: no cover
288
+ raise TypeError(f"Depth crawler cannot handle: {node.kind}") # pragma: no cover
295
289
 
296
290
 
297
291
  def is_ignored(node: FieldNode, ignore: Optional[list[IgnoreType]] = None) -> bool:
strawberry/extensions/tracing/datadog.py CHANGED
@@ -67,7 +67,7 @@ class DatadogTracingExtension(SchemaExtension):
67
67
  )
68
68
 
69
69
  def hash_query(self, query: str) -> str:
70
- return hashlib.md5(query.encode("utf-8")).hexdigest()
70
+ return hashlib.md5(query.encode("utf-8")).hexdigest() # noqa: S324
71
71
 
72
72
  def on_operation(self) -> Iterator[None]:
73
73
  self._operation_name = self.execution_context.operation_name
strawberry/extensions/tracing/opentelemetry.py CHANGED
@@ -45,7 +45,7 @@ class OpenTelemetryExtension(SchemaExtension):
45
45
  ) -> None:
46
46
  self._arg_filter = arg_filter
47
47
  self._tracer = trace.get_tracer("strawberry")
48
- self._span_holder = dict()
48
+ self._span_holder = {}
49
49
  if execution_context:
50
50
  self.execution_context = execution_context
51
51
 
@@ -116,18 +116,17 @@ class OpenTelemetryExtension(SchemaExtension):
116
116
  # Put these in decreasing order of use-cases to exit as soon as possible
117
117
  if isinstance(value, (bool, str, bytes, int, float)):
118
118
  return value
119
- elif isinstance(value, (list, tuple, range)):
119
+ if isinstance(value, (list, tuple, range)):
120
120
  return self.convert_list_or_tuple_to_allowed_types(value)
121
- elif isinstance(value, dict):
121
+ if isinstance(value, dict):
122
122
  return self.convert_dict_to_allowed_types(value)
123
- elif isinstance(value, (set, frozenset)):
123
+ if isinstance(value, (set, frozenset)):
124
124
  return self.convert_set_to_allowed_types(value)
125
- elif isinstance(value, complex):
125
+ if isinstance(value, complex):
126
126
  return str(value) # Convert complex numbers to strings
127
- elif isinstance(value, (bytearray, memoryview)):
127
+ if isinstance(value, (bytearray, memoryview)):
128
128
  return bytes(value) # Convert bytearray and memoryview to bytes
129
- else:
130
- return str(value)
129
+ return str(value)
131
130
 
132
131
  def convert_set_to_allowed_types(self, value: Union[set, frozenset]) -> str:
133
132
  return (
@@ -192,9 +191,7 @@ class OpenTelemetryExtensionSync(OpenTelemetryExtension):
192
191
  **kwargs: Any,
193
192
  ) -> Any:
194
193
  if should_skip_tracing(_next, info):
195
- result = _next(root, info, *args, **kwargs)
196
-
197
- return result
194
+ return _next(root, info, *args, **kwargs)
198
195
 
199
196
  with self._tracer.start_as_current_span(
200
197
  f"GraphQL Resolving: {info.field_name}",
@@ -203,9 +200,7 @@ class OpenTelemetryExtensionSync(OpenTelemetryExtension):
203
200
  ),
204
201
  ) as span:
205
202
  self.add_tags(span, info, kwargs)
206
- result = _next(root, info, *args, **kwargs)
207
-
208
- return result
203
+ return _next(root, info, *args, **kwargs)
209
204
 
210
205
 
211
206
  __all__ = ["OpenTelemetryExtension", "OpenTelemetryExtensionSync"]
strawberry/fastapi/router.py CHANGED
@@ -92,10 +92,9 @@ class GraphQLRouter(
92
92
  **default_context,
93
93
  **custom_context,
94
94
  }
95
- elif custom_context is None:
95
+ if custom_context is None:
96
96
  return default_context
97
- else:
98
- raise InvalidCustomContext()
97
+ raise InvalidCustomContext
99
98
 
100
99
  # replace the signature parameters of dependency...
101
100
  # ...with the old parameters minus the first argument as it will be replaced...
strawberry/federation/schema.py CHANGED
@@ -172,7 +172,7 @@ class Schema(BaseSchema):
172
172
 
173
173
  try:
174
174
  result = resolve_reference(**kwargs)
175
- except Exception as e:
175
+ except Exception as e: # noqa: BLE001
176
176
  result = e
177
177
  else:
178
178
  from strawberry.types.arguments import convert_argument
@@ -187,7 +187,7 @@ class Schema(BaseSchema):
187
187
  scalar_registry=scalar_registry,
188
188
  config=config,
189
189
  )
190
- except Exception:
190
+ except Exception: # noqa: BLE001
191
191
  result = TypeError(f"Unable to resolve reference for {type_name}")
192
192
 
193
193
  results.append(result)
@@ -271,7 +271,7 @@ class Schema(BaseSchema):
271
271
  link_directives: list[object] = [
272
272
  Link(
273
273
  url=url,
274
- import_=list(sorted(directives)),
274
+ import_=sorted(directives), # type: ignore[arg-type]
275
275
  )
276
276
  for url, directives in directive_by_url.items()
277
277
  ]
strawberry/flask/views.py CHANGED
@@ -4,6 +4,7 @@ import warnings
4
4
  from typing import (
5
5
  TYPE_CHECKING,
6
6
  Any,
7
+ ClassVar,
7
8
  Optional,
8
9
  Union,
9
10
  cast,
@@ -102,7 +103,7 @@ class GraphQLView(
102
103
  SyncBaseHTTPView[Request, Response, Response, Context, RootValue],
103
104
  View,
104
105
  ):
105
- methods = ["GET", "POST"]
106
+ methods: ClassVar[list[str]] = ["GET", "POST"]
106
107
  allow_queries_via_get: bool = True
107
108
  request_adapter_class = FlaskHTTPRequestAdapter
108
109
 
@@ -165,7 +166,7 @@ class AsyncGraphQLView(
165
166
  ],
166
167
  View,
167
168
  ):
168
- methods = ["GET", "POST"]
169
+ methods: ClassVar[list[str]] = ["GET", "POST"]
169
170
  allow_queries_via_get: bool = True
170
171
  request_adapter_class = AsyncFlaskHTTPRequestAdapter
171
172
 
strawberry/http/async_base_view.py CHANGED
@@ -306,8 +306,7 @@ class AsyncBaseHTTPView(
306
306
  await websocket.close(4406, "Subprotocol not acceptable")
307
307
 
308
308
  return websocket_response
309
- else:
310
- request = cast(Request, request)
309
+ request = cast(Request, request)
311
310
 
312
311
  request_adapter = self.request_adapter_class(request)
313
312
  sub_response = await self.get_sub_response(request)
@@ -325,8 +324,7 @@ class AsyncBaseHTTPView(
325
324
  if self.should_render_graphql_ide(request_adapter):
326
325
  if self.graphql_ide:
327
326
  return await self.render_graphql_ide(request)
328
- else:
329
- raise HTTPException(404, "Not Found")
327
+ raise HTTPException(404, "Not Found")
330
328
 
331
329
  try:
332
330
  result = await self.execute_operation(
strawberry/http/ides.py CHANGED
@@ -17,9 +17,7 @@ def get_graphql_ide_html(
17
17
  else:
18
18
  path = here / "static/graphiql.html"
19
19
 
20
- template = path.read_text(encoding="utf-8")
21
-
22
- return template
20
+ return path.read_text(encoding="utf-8")
23
21
 
24
22
 
25
23
  __all__ = ["GraphQL_IDE", "get_graphql_ide_html"]
strawberry/http/sync_base_view.py CHANGED
@@ -175,8 +175,7 @@ class SyncBaseHTTPView(
175
175
  if self.should_render_graphql_ide(request_adapter):
176
176
  if self.graphql_ide:
177
177
  return self.render_graphql_ide(request)
178
- else:
179
- raise HTTPException(404, "Not Found")
178
+ raise HTTPException(404, "Not Found")
180
179
 
181
180
  sub_response = self.get_sub_response(request)
182
181
  context = (
strawberry/litestar/controller.py CHANGED
@@ -9,6 +9,7 @@ from typing import (
9
9
  TYPE_CHECKING,
10
10
  Any,
11
11
  Callable,
12
+ ClassVar,
12
13
  Optional,
13
14
  TypedDict,
14
15
  Union,
@@ -203,13 +204,13 @@ class LitestarWebSocketAdapter(AsyncWebSocketAdapter):
203
204
 
204
205
  # Litestar internally defaults to an empty string for non-text messages
205
206
  if text == "":
206
- raise NonTextMessageReceived()
207
+ raise NonTextMessageReceived
207
208
 
208
209
  try:
209
210
  yield self.view.decode_json(text)
210
- except json.JSONDecodeError:
211
+ except json.JSONDecodeError as e:
211
212
  if not ignore_parsing_errors:
212
- raise NonJsonMessageReceived()
213
+ raise NonJsonMessageReceived from e
213
214
  except WebSocketDisconnect:
214
215
  pass
215
216
 
@@ -236,7 +237,7 @@ class GraphQLController(
236
237
  ],
237
238
  ):
238
239
  path: str = ""
239
- dependencies: Dependencies = {
240
+ dependencies: ClassVar[Dependencies] = { # type: ignore[misc]
240
241
  "custom_context": Provide(_none_custom_context_getter),
241
242
  "context": Provide(_context_getter_http),
242
243
  "context_ws": Provide(_context_getter_ws),
@@ -445,7 +446,7 @@ def make_graphql_controller(
445
446
 
446
447
  class _GraphQLController(GraphQLController):
447
448
  path: str = routes_path
448
- dependencies: Dependencies = {
449
+ dependencies: ClassVar[Dependencies] = { # type: ignore[misc]
449
450
  "custom_context": Provide(custom_context_getter_),
450
451
  "context": Provide(_context_getter_http),
451
452
  "context_ws": Provide(_context_getter_ws),
strawberry/permission.py CHANGED
@@ -101,7 +101,7 @@ class BasePermission(abc.ABC):
101
101
  if self.error_extensions:
102
102
  # Add our extensions to the error
103
103
  if not error.extensions:
104
- error.extensions = dict()
104
+ error.extensions = {}
105
105
  error.extensions.update(self.error_extensions)
106
106
 
107
107
  raise error
strawberry/quart/views.py CHANGED
@@ -1,6 +1,6 @@
1
1
  import warnings
2
2
  from collections.abc import AsyncGenerator, Mapping
3
- from typing import TYPE_CHECKING, Callable, Optional, cast
3
+ from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast
4
4
  from typing_extensions import TypeGuard
5
5
 
6
6
  from quart import Request, Response, request
@@ -52,7 +52,7 @@ class GraphQLView(
52
52
  ],
53
53
  View,
54
54
  ):
55
- methods = ["GET", "POST"]
55
+ methods: ClassVar[list[str]] = ["GET", "POST"]
56
56
  allow_queries_via_get: bool = True
57
57
  request_adapter_class = QuartHTTPRequestAdapter
58
58
 
strawberry/relay/fields.py CHANGED
@@ -37,6 +37,7 @@ from strawberry.relay.exceptions import (
37
37
  )
38
38
  from strawberry.types.arguments import StrawberryArgument, argument
39
39
  from strawberry.types.base import StrawberryList, StrawberryOptional
40
+ from strawberry.types.cast import cast as strawberry_cast
40
41
  from strawberry.types.field import _RESOLVER_TYPE, StrawberryField, field
41
42
  from strawberry.types.fields.resolver import StrawberryResolver
42
43
  from strawberry.types.lazy_type import LazyType
@@ -88,12 +89,27 @@ class NodeExtension(FieldExtension):
88
89
  info: Info,
89
90
  id: Annotated[GlobalID, argument(description="The ID of the object.")],
90
91
  ) -> Union[Node, None, Awaitable[Union[Node, None]]]:
91
- return id.resolve_type(info).resolve_node(
92
+ node_type = id.resolve_type(info)
93
+ resolved_node = node_type.resolve_node(
92
94
  id.node_id,
93
95
  info=info,
94
96
  required=not is_optional,
95
97
  )
96
98
 
99
+ # We are using `strawberry_cast` here to cast the resolved node to make
100
+ # sure `is_type_of` will not try to find its type again. Very important
101
+ # when returning a non type (e.g. Django/SQLAlchemy/Pydantic model), as
102
+ # we could end up resolving to a different type in case more than one
103
+ # are registered.
104
+ if inspect.isawaitable(resolved_node):
105
+
106
+ async def resolve() -> Any:
107
+ return strawberry_cast(node_type, await resolved_node)
108
+
109
+ return resolve()
110
+
111
+ return cast(Node, strawberry_cast(node_type, resolved_node))
112
+
97
113
  return resolver
98
114
 
99
115
  def get_node_list_resolver(
@@ -139,6 +155,14 @@ class NodeExtension(FieldExtension):
139
155
  if inspect.isasyncgen(nodes)
140
156
  }
141
157
 
158
+ # We are using `strawberry_cast` here to cast the resolved node to make
159
+ # sure `is_type_of` will not try to find its type again. Very important
160
+ # when returning a non type (e.g. Django/SQLAlchemy/Pydantic model), as
161
+ # we could end up resolving to a different type in case more than one
162
+ # are registered
163
+ def cast_nodes(node_t: type[Node], nodes: Iterable[Any]) -> list[Node]:
164
+ return [cast(Node, strawberry_cast(node_t, node)) for node in nodes]
165
+
142
166
  if awaitable_nodes or asyncgen_nodes:
143
167
 
144
168
  async def resolve(resolved: Any = resolved_nodes) -> list[Node]:
@@ -161,7 +185,8 @@ class NodeExtension(FieldExtension):
161
185
 
162
186
  # Resolve any generator to lists
163
187
  resolved = {
164
- node_t: list(nodes) for node_t, nodes in resolved.items()
188
+ node_t: cast_nodes(node_t, nodes)
189
+ for node_t, nodes in resolved.items()
165
190
  }
166
191
  return [
167
192
  resolved[index_map[gid][0]][index_map[gid][1]] for gid in ids
@@ -171,7 +196,7 @@ class NodeExtension(FieldExtension):
171
196
 
172
197
  # Resolve any generator to lists
173
198
  resolved = {
174
- node_t: list(cast(Iterator[Node], nodes))
199
+ node_t: cast_nodes(node_t, cast(Iterable[Node], nodes))
175
200
  for node_t, nodes in resolved_nodes.items()
176
201
  }
177
202
  return [resolved[index_map[gid][0]][index_map[gid][1]] for gid in ids]
strawberry/relay/types.py CHANGED
@@ -758,7 +758,7 @@ class ListConnection(Connection[NodeType]):
758
758
  )
759
759
 
760
760
  @classmethod
761
- def resolve_connection(
761
+ def resolve_connection( # noqa: PLR0915
762
762
  cls,
763
763
  nodes: NodeIterableType[NodeType],
764
764
  *,