strawberry-graphql 0.256.0__py3-none-any.whl → 0.257.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. strawberry/__init__.py +2 -0
  2. strawberry/aiohttp/test/client.py +1 -3
  3. strawberry/aiohttp/views.py +3 -3
  4. strawberry/annotation.py +5 -7
  5. strawberry/asgi/__init__.py +4 -4
  6. strawberry/channels/handlers/ws_handler.py +3 -3
  7. strawberry/channels/testing.py +3 -1
  8. strawberry/cli/__init__.py +7 -5
  9. strawberry/cli/commands/codegen.py +12 -14
  10. strawberry/cli/commands/server.py +2 -2
  11. strawberry/cli/commands/upgrade/_run_codemod.py +2 -3
  12. strawberry/cli/utils/__init__.py +1 -1
  13. strawberry/codegen/plugins/python.py +4 -5
  14. strawberry/codegen/plugins/typescript.py +3 -2
  15. strawberry/codegen/query_codegen.py +12 -11
  16. strawberry/codemods/annotated_unions.py +1 -1
  17. strawberry/codemods/update_imports.py +2 -4
  18. strawberry/dataloader.py +2 -2
  19. strawberry/django/__init__.py +2 -2
  20. strawberry/django/views.py +2 -3
  21. strawberry/exceptions/__init__.py +4 -2
  22. strawberry/exceptions/exception.py +1 -1
  23. strawberry/exceptions/permission_fail_silently_requires_optional.py +2 -1
  24. strawberry/exceptions/utils/source_finder.py +1 -1
  25. strawberry/experimental/pydantic/_compat.py +4 -4
  26. strawberry/experimental/pydantic/conversion.py +4 -5
  27. strawberry/experimental/pydantic/fields.py +1 -2
  28. strawberry/experimental/pydantic/object_type.py +6 -2
  29. strawberry/experimental/pydantic/utils.py +3 -9
  30. strawberry/ext/mypy_plugin.py +7 -14
  31. strawberry/extensions/context.py +15 -19
  32. strawberry/extensions/field_extension.py +53 -54
  33. strawberry/extensions/pyinstrument.py +1 -1
  34. strawberry/extensions/query_depth_limiter.py +27 -33
  35. strawberry/extensions/tracing/datadog.py +1 -1
  36. strawberry/extensions/tracing/opentelemetry.py +9 -14
  37. strawberry/fastapi/router.py +2 -3
  38. strawberry/federation/schema.py +3 -3
  39. strawberry/flask/views.py +3 -2
  40. strawberry/http/async_base_view.py +2 -4
  41. strawberry/http/ides.py +1 -3
  42. strawberry/http/sync_base_view.py +1 -2
  43. strawberry/litestar/controller.py +6 -5
  44. strawberry/permission.py +1 -1
  45. strawberry/quart/views.py +2 -2
  46. strawberry/relay/fields.py +28 -3
  47. strawberry/relay/types.py +1 -1
  48. strawberry/schema/base.py +0 -2
  49. strawberry/schema/execute.py +11 -11
  50. strawberry/schema/name_converter.py +4 -5
  51. strawberry/schema/schema.py +6 -4
  52. strawberry/schema/schema_converter.py +24 -17
  53. strawberry/schema/subscribe.py +4 -4
  54. strawberry/schema/types/base_scalars.py +4 -2
  55. strawberry/schema/types/scalar.py +1 -1
  56. strawberry/schema_codegen/__init__.py +5 -6
  57. strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +2 -2
  58. strawberry/subscriptions/protocols/graphql_ws/handlers.py +0 -3
  59. strawberry/test/client.py +1 -2
  60. strawberry/types/arguments.py +2 -2
  61. strawberry/types/auto.py +3 -3
  62. strawberry/types/base.py +12 -16
  63. strawberry/types/cast.py +35 -0
  64. strawberry/types/field.py +11 -7
  65. strawberry/types/fields/resolver.py +12 -19
  66. strawberry/types/union.py +1 -1
  67. strawberry/types/unset.py +1 -2
  68. strawberry/utils/debug.py +1 -1
  69. strawberry/utils/deprecations.py +1 -1
  70. strawberry/utils/graphql_lexer.py +6 -4
  71. strawberry/utils/typing.py +1 -2
  72. {strawberry_graphql-0.256.0.dist-info → strawberry_graphql-0.257.0.dist-info}/METADATA +2 -2
  73. {strawberry_graphql-0.256.0.dist-info → strawberry_graphql-0.257.0.dist-info}/RECORD +76 -75
  74. {strawberry_graphql-0.256.0.dist-info → strawberry_graphql-0.257.0.dist-info}/WHEEL +1 -1
  75. {strawberry_graphql-0.256.0.dist-info → strawberry_graphql-0.257.0.dist-info}/LICENSE +0 -0
  76. {strawberry_graphql-0.256.0.dist-info → strawberry_graphql-0.257.0.dist-info}/entry_points.txt +0 -0
strawberry/relay/types.py CHANGED
@@ -758,7 +758,7 @@ class ListConnection(Connection[NodeType]):
758
758
  )
759
759
 
760
760
  @classmethod
761
- def resolve_connection(
761
+ def resolve_connection( # noqa: PLR0915
762
762
  cls,
763
763
  nodes: NodeIterableType[NodeType],
764
764
  *,
strawberry/schema/base.py CHANGED
@@ -1,7 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from abc import abstractmethod
4
- from functools import lru_cache
5
4
  from typing import TYPE_CHECKING, Any, Optional, Union
6
5
  from typing_extensions import Protocol
7
6
 
@@ -88,7 +87,6 @@ class BaseSchema(Protocol):
88
87
  raise NotImplementedError
89
88
 
90
89
  @abstractmethod
91
- @lru_cache
92
90
  def get_directive_by_name(self, graphql_name: str) -> Optional[StrawberryDirective]:
93
91
  raise NotImplementedError
94
92
 
@@ -85,7 +85,7 @@ async def _parse_and_validate_async(
85
85
  context: ExecutionContext, extensions_runner: SchemaExtensionsRunner
86
86
  ) -> Optional[PreExecutionError]:
87
87
  if not context.query:
88
- raise MissingQueryError()
88
+ raise MissingQueryError
89
89
 
90
90
  async with extensions_runner.parsing():
91
91
  try:
@@ -96,7 +96,7 @@ async def _parse_and_validate_async(
96
96
  context.errors = [error]
97
97
  return PreExecutionError(data=None, errors=[error])
98
98
 
99
- except Exception as error:
99
+ except Exception as error: # noqa: BLE001
100
100
  error = GraphQLError(str(error), original_error=error)
101
101
  context.errors = [error]
102
102
  return PreExecutionError(data=None, errors=[error])
@@ -189,9 +189,9 @@ async def execute(
189
189
  # only return a sanitised version to the client.
190
190
  process_errors(result.errors, execution_context)
191
191
 
192
- except (MissingQueryError, InvalidOperationTypeError) as e:
193
- raise e
194
- except Exception as exc:
192
+ except (MissingQueryError, InvalidOperationTypeError):
193
+ raise
194
+ except Exception as exc: # noqa: BLE001
195
195
  return await _handle_execution_result(
196
196
  execution_context,
197
197
  PreExecutionError(data=None, errors=[_coerce_error(exc)]),
@@ -219,7 +219,7 @@ def execute_sync(
219
219
  # Note: In graphql-core the schema would be validated here but in
220
220
  # Strawberry we are validating it at initialisation time instead
221
221
  if not execution_context.query:
222
- raise MissingQueryError()
222
+ raise MissingQueryError # noqa: TRY301
223
223
 
224
224
  with extensions_runner.parsing():
225
225
  try:
@@ -238,7 +238,7 @@ def execute_sync(
238
238
  )
239
239
 
240
240
  if execution_context.operation_type not in allowed_operation_types:
241
- raise InvalidOperationTypeError(execution_context.operation_type)
241
+ raise InvalidOperationTypeError(execution_context.operation_type) # noqa: TRY301
242
242
 
243
243
  with extensions_runner.validation():
244
244
  _run_validation(execution_context)
@@ -266,7 +266,7 @@ def execute_sync(
266
266
  if isawaitable(result):
267
267
  result = cast(Awaitable[GraphQLExecutionResult], result) # type: ignore[redundant-cast]
268
268
  ensure_future(result).cancel()
269
- raise RuntimeError(
269
+ raise RuntimeError( # noqa: TRY301
270
270
  "GraphQL execution failed to complete synchronously."
271
271
  )
272
272
 
@@ -282,9 +282,9 @@ def execute_sync(
282
282
  # extension). That way we can log the original errors but
283
283
  # only return a sanitised version to the client.
284
284
  process_errors(result.errors, execution_context)
285
- except (MissingQueryError, InvalidOperationTypeError) as e:
286
- raise e
287
- except Exception as exc:
285
+ except (MissingQueryError, InvalidOperationTypeError):
286
+ raise
287
+ except Exception as exc: # noqa: BLE001
288
288
  errors = [_coerce_error(exc)]
289
289
  execution_context.errors = errors
290
290
  process_errors(errors, execution_context)
@@ -47,18 +47,17 @@ class NameConverter:
47
47
  return self.from_directive(type_)
48
48
  if isinstance(type_, EnumDefinition): # TODO: Replace with StrawberryEnum
49
49
  return self.from_enum(type_)
50
- elif isinstance(type_, StrawberryObjectDefinition):
50
+ if isinstance(type_, StrawberryObjectDefinition):
51
51
  if type_.is_input:
52
52
  return self.from_input_object(type_)
53
53
  if type_.is_interface:
54
54
  return self.from_interface(type_)
55
55
  return self.from_object(type_)
56
- elif isinstance(type_, StrawberryUnion):
56
+ if isinstance(type_, StrawberryUnion):
57
57
  return self.from_union(type_)
58
- elif isinstance(type_, ScalarDefinition): # TODO: Replace with StrawberryScalar
58
+ if isinstance(type_, ScalarDefinition): # TODO: Replace with StrawberryScalar
59
59
  return self.from_scalar(type_)
60
- else:
61
- return str(type_)
60
+ return str(type_)
62
61
 
63
62
  def from_argument(self, argument: StrawberryArgument) -> str:
64
63
  return self.get_graphql_name(argument)
@@ -30,6 +30,7 @@ from strawberry.extensions.directives import (
30
30
  DirectivesExtensionSync,
31
31
  )
32
32
  from strawberry.extensions.runner import SchemaExtensionsRunner
33
+ from strawberry.printer import print_schema
33
34
  from strawberry.schema.schema_converter import GraphQLCoreConverter
34
35
  from strawberry.schema.types.scalar import DEFAULT_SCALAR_REGISTRY
35
36
  from strawberry.types import ExecutionContext
@@ -40,7 +41,6 @@ from strawberry.types.base import (
40
41
  )
41
42
  from strawberry.types.graphql import OperationType
42
43
 
43
- from ..printer import print_schema
44
44
  from . import compat
45
45
  from .base import BaseSchema
46
46
  from .config import StrawberryConfig
@@ -177,9 +177,11 @@ class Schema(BaseSchema):
177
177
  self.schema_converter.from_schema_directive(type_)
178
178
  )
179
179
  else:
180
- if has_object_definition(type_):
181
- if type_.__strawberry_definition__.is_graphql_generic:
182
- type_ = StrawberryAnnotation(type_).resolve() # noqa: PLW2901
180
+ if (
181
+ has_object_definition(type_)
182
+ and type_.__strawberry_definition__.is_graphql_generic
183
+ ):
184
+ type_ = StrawberryAnnotation(type_).resolve() # noqa: PLW2901
183
185
  graphql_type = self.schema_converter.from_maybe_optional(type_)
184
186
  if isinstance(graphql_type, GraphQLNonNull):
185
187
  graphql_type = graphql_type.of_type
@@ -47,6 +47,7 @@ from strawberry.exceptions import (
47
47
  ScalarAlreadyRegisteredError,
48
48
  UnresolvedFieldTypeError,
49
49
  )
50
+ from strawberry.extensions.field_extension import build_field_extension_resolvers
50
51
  from strawberry.schema.types.scalar import _make_scalar_type
51
52
  from strawberry.types.arguments import StrawberryArgument, convert_arguments
52
53
  from strawberry.types.base import (
@@ -57,6 +58,7 @@ from strawberry.types.base import (
57
58
  get_object_definition,
58
59
  has_object_definition,
59
60
  )
61
+ from strawberry.types.cast import get_strawberry_type_cast
60
62
  from strawberry.types.enum import EnumDefinition
61
63
  from strawberry.types.field import UNRESOLVED
62
64
  from strawberry.types.lazy_type import LazyType
@@ -66,7 +68,6 @@ from strawberry.types.union import StrawberryUnion
66
68
  from strawberry.types.unset import UNSET
67
69
  from strawberry.utils.await_maybe import await_maybe
68
70
 
69
- from ..extensions.field_extension import build_field_extension_resolvers
70
71
  from . import compat
71
72
  from .types.concrete_type import ConcreteType
72
73
 
@@ -91,7 +92,9 @@ if TYPE_CHECKING:
91
92
 
92
93
 
93
94
  FieldType = TypeVar(
94
- "FieldType", bound=Union[GraphQLField, GraphQLInputField], covariant=True
95
+ "FieldType",
96
+ bound=Union[GraphQLField, GraphQLInputField],
97
+ covariant=True,
95
98
  )
96
99
 
97
100
 
@@ -617,6 +620,9 @@ class GraphQLCoreConverter:
617
620
  )
618
621
 
619
622
  def is_type_of(obj: Any, _info: GraphQLResolveInfo) -> bool:
623
+ if (type_cast := get_strawberry_type_cast(obj)) is not None:
624
+ return type_cast in possible_types
625
+
620
626
  if object_type.concrete_of and (
621
627
  has_object_definition(obj)
622
628
  and obj.__strawberry_definition__.origin
@@ -756,9 +762,8 @@ class GraphQLCoreConverter:
756
762
  if field.is_async:
757
763
  _async_resolver._is_default = not field.base_resolver # type: ignore
758
764
  return _async_resolver
759
- else:
760
- _resolver._is_default = not field.base_resolver # type: ignore
761
- return _resolver
765
+ _resolver._is_default = not field.base_resolver # type: ignore
766
+ return _resolver
762
767
 
763
768
  def from_scalar(self, scalar: type) -> GraphQLScalarType:
764
769
  scalar_definition: ScalarDefinition
@@ -808,10 +813,9 @@ class GraphQLCoreConverter:
808
813
  NoneType = type(None)
809
814
  if type_ is None or type_ is NoneType:
810
815
  return self.from_type(type_)
811
- elif isinstance(type_, StrawberryOptional):
816
+ if isinstance(type_, StrawberryOptional):
812
817
  return self.from_type(type_.of_type)
813
- else:
814
- return GraphQLNonNull(self.from_type(type_))
818
+ return GraphQLNonNull(self.from_type(type_))
815
819
 
816
820
  def from_type(self, type_: Union[StrawberryType, type]) -> GraphQLNullableType:
817
821
  if compat.is_graphql_generic(type_):
@@ -819,27 +823,27 @@ class GraphQLCoreConverter:
819
823
 
820
824
  if isinstance(type_, EnumDefinition): # TODO: Replace with StrawberryEnum
821
825
  return self.from_enum(type_)
822
- elif compat.is_input_type(type_): # TODO: Replace with StrawberryInputObject
826
+ if compat.is_input_type(type_): # TODO: Replace with StrawberryInputObject
823
827
  return self.from_input_object(type_)
824
- elif isinstance(type_, StrawberryList):
828
+ if isinstance(type_, StrawberryList):
825
829
  return self.from_list(type_)
826
- elif compat.is_interface_type(type_): # TODO: Replace with StrawberryInterface
830
+ if compat.is_interface_type(type_): # TODO: Replace with StrawberryInterface
827
831
  type_definition: StrawberryObjectDefinition = (
828
832
  type_.__strawberry_definition__ # type: ignore
829
833
  )
830
834
  return self.from_interface(type_definition)
831
- elif has_object_definition(type_):
835
+ if has_object_definition(type_):
832
836
  return self.from_object(type_.__strawberry_definition__)
833
- elif compat.is_enum(type_): # TODO: Replace with StrawberryEnum
837
+ if compat.is_enum(type_): # TODO: Replace with StrawberryEnum
834
838
  enum_definition: EnumDefinition = type_._enum_definition # type: ignore
835
839
  return self.from_enum(enum_definition)
836
- elif isinstance(type_, StrawberryObjectDefinition):
840
+ if isinstance(type_, StrawberryObjectDefinition):
837
841
  return self.from_object(type_)
838
- elif isinstance(type_, StrawberryUnion):
842
+ if isinstance(type_, StrawberryUnion):
839
843
  return self.from_union(type_)
840
- elif isinstance(type_, LazyType):
844
+ if isinstance(type_, LazyType):
841
845
  return self.from_type(type_.resolve_type())
842
- elif compat.is_scalar(
846
+ if compat.is_scalar(
843
847
  type_, self.scalar_registry
844
848
  ): # TODO: Replace with StrawberryScalar
845
849
  return self.from_scalar(type_)
@@ -898,6 +902,9 @@ class GraphQLCoreConverter:
898
902
  if object_type.interfaces:
899
903
 
900
904
  def is_type_of(obj: Any, _info: GraphQLResolveInfo) -> bool:
905
+ if (type_cast := get_strawberry_type_cast(obj)) is not None:
906
+ return type_cast is object_type.origin
907
+
901
908
  if object_type.concrete_of and (
902
909
  has_object_definition(obj)
903
910
  and obj.__strawberry_definition__.origin
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
27
27
  from graphql.execution.middleware import MiddlewareManager
28
28
  from graphql.type.schema import GraphQLSchema
29
29
 
30
- from ..extensions.runner import SchemaExtensionsRunner
30
+ from strawberry.extensions.runner import SchemaExtensionsRunner
31
31
 
32
32
  SubscriptionResult: TypeAlias = Union[
33
33
  PreExecutionError, AsyncGenerator[ExecutionResult, None]
@@ -80,7 +80,7 @@ async def _subscribe(
80
80
  )
81
81
  # graphql-core 3.2 doesn't handle some of the pre-execution errors.
82
82
  # see `test_subscription_immediate_error`
83
- except Exception as exc:
83
+ except Exception as exc: # noqa: BLE001
84
84
  aiter_or_result = OriginalExecutionResult(
85
85
  data=None, errors=[_coerce_error(exc)]
86
86
  )
@@ -103,7 +103,7 @@ async def _subscribe(
103
103
  process_errors,
104
104
  )
105
105
  # graphql-core doesn't handle exceptions raised while executing.
106
- except Exception as exc:
106
+ except Exception as exc: # noqa: BLE001
107
107
  yield await _handle_execution_result(
108
108
  execution_context,
109
109
  OriginalExecutionResult(data=None, errors=[_coerce_error(exc)]),
@@ -111,7 +111,7 @@ async def _subscribe(
111
111
  process_errors,
112
112
  )
113
113
  # catch exceptions raised in `on_execute` hook.
114
- except Exception as exc:
114
+ except Exception as exc: # noqa: BLE001
115
115
  origin_result = OriginalExecutionResult(
116
116
  data=None, errors=[_coerce_error(exc)]
117
117
  )
@@ -15,7 +15,9 @@ def wrap_parser(parser: Callable, type_: str) -> Callable:
15
15
  try:
16
16
  return parser(value)
17
17
  except ValueError as e:
18
- raise GraphQLError(f'Value cannot represent a {type_}: "{value}". {e}')
18
+ raise GraphQLError( # noqa: B904
19
+ f'Value cannot represent a {type_}: "{value}". {e}'
20
+ )
19
21
 
20
22
  return inner
21
23
 
@@ -24,7 +26,7 @@ def parse_decimal(value: object) -> decimal.Decimal:
24
26
  try:
25
27
  return decimal.Decimal(str(value))
26
28
  except decimal.DecimalException:
27
- raise GraphQLError(f'Value cannot represent a Decimal: "{value}".')
29
+ raise GraphQLError(f'Value cannot represent a Decimal: "{value}".') # noqa: B904
28
30
 
29
31
 
30
32
  isoformat = methodcaller("isoformat")
@@ -62,7 +62,7 @@ DEFAULT_SCALAR_REGISTRY: dict[object, ScalarDefinition] = {
62
62
  datetime.datetime: _get_scalar_definition(base_scalars.DateTime),
63
63
  datetime.time: _get_scalar_definition(base_scalars.Time),
64
64
  decimal.Decimal: _get_scalar_definition(base_scalars.Decimal),
65
- # We can't wrap GLobalID with @scalar because it has custom attributes/methods
65
+ # We can't wrap GlobalID with @scalar because it has custom attributes/methods
66
66
  GlobalID: _get_scalar_definition(
67
67
  scalar(
68
68
  GlobalID,
@@ -115,7 +115,7 @@ def _get_field_type(
115
115
 
116
116
  if isinstance(field_type, NonNullTypeNode):
117
117
  return _get_field_type(field_type.type, was_non_nullable=True)
118
- elif isinstance(field_type, ListTypeNode):
118
+ if isinstance(field_type, ListTypeNode):
119
119
  expr = cst.Subscript(
120
120
  value=cst.Name("list"),
121
121
  slice=[
@@ -262,14 +262,13 @@ ArgumentValue: TypeAlias = Union[str, bool, list["ArgumentValue"]]
262
262
  def _get_argument_value(argument_value: ConstValueNode) -> ArgumentValue:
263
263
  if isinstance(argument_value, StringValueNode):
264
264
  return argument_value.value
265
- elif isinstance(argument_value, EnumValueDefinitionNode):
265
+ if isinstance(argument_value, EnumValueDefinitionNode):
266
266
  return argument_value.name.value
267
- elif isinstance(argument_value, ListValueNode):
267
+ if isinstance(argument_value, ListValueNode):
268
268
  return [_get_argument_value(arg) for arg in argument_value.values]
269
- elif isinstance(argument_value, BooleanValueNode):
269
+ if isinstance(argument_value, BooleanValueNode):
270
270
  return argument_value.value
271
- else:
272
- raise NotImplementedError(f"Unknown argument value {argument_value}")
271
+ raise NotImplementedError(f"Unknown argument value {argument_value}")
273
272
 
274
273
 
275
274
  def _get_directives(
@@ -122,7 +122,7 @@ class BaseGraphQLTransportWSHandler(Generic[Context, RootValue]):
122
122
  self.connection_timed_out = True
123
123
  reason = "Connection initialisation timeout"
124
124
  await self.websocket.close(code=4408, reason=reason)
125
- except Exception as error:
125
+ except Exception as error: # noqa: BLE001
126
126
  await self.handle_task_exception(error) # pragma: no cover
127
127
  finally:
128
128
  # do not clear self.connection_init_timeout_task
@@ -298,7 +298,7 @@ class BaseGraphQLTransportWSHandler(Generic[Context, RootValue]):
298
298
  {"id": operation.id, "type": "complete"}
299
299
  )
300
300
 
301
- except BaseException as e: # pragma: no cover
301
+ except BaseException: # pragma: no cover
302
302
  self.operations.pop(operation.id, None)
303
303
  raise
304
304
  finally:
@@ -34,9 +34,6 @@ if TYPE_CHECKING:
34
34
 
35
35
 
36
36
  class BaseGraphQLWSHandler(Generic[Context, RootValue]):
37
- context: Context
38
- root_value: RootValue
39
-
40
37
  def __init__(
41
38
  self,
42
39
  view: AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue],
strawberry/test/client.py CHANGED
@@ -188,8 +188,7 @@ class BaseGraphQLTestClient(ABC):
188
188
  # Variables can be mixed files and other data, we don't want to map non-files
189
189
  # vars so we need to remove them, we can't remove them before
190
190
  # because they can be part of a list of files or folder
191
- map_without_vars = {k: v for k, v in map.items() if k in files}
192
- return map_without_vars
191
+ return {k: v for k, v in map.items() if k in files}
193
192
 
194
193
  def _decode(self, response: Any, type: Literal["multipart", "json"]) -> Any:
195
194
  if type == "multipart":
@@ -23,8 +23,8 @@ from strawberry.types.base import (
23
23
  )
24
24
  from strawberry.types.enum import EnumDefinition
25
25
  from strawberry.types.lazy_type import LazyType, StrawberryLazyReference
26
- from strawberry.types.unset import UNSET as _deprecated_UNSET
27
- from strawberry.types.unset import _deprecated_is_unset # noqa # type: ignore
26
+ from strawberry.types.unset import UNSET as _deprecated_UNSET # noqa: N811
27
+ from strawberry.types.unset import _deprecated_is_unset # noqa: F401
28
28
 
29
29
  if TYPE_CHECKING:
30
30
  from strawberry.schema.config import StrawberryConfig
strawberry/types/auto.py CHANGED
@@ -23,8 +23,8 @@ class StrawberryAutoMeta(type):
23
23
 
24
24
  """
25
25
 
26
- def __init__(self, *args: str, **kwargs: Any) -> None:
27
- self._instance: Optional[StrawberryAuto] = None
26
+ def __init__(cls, *args: str, **kwargs: Any) -> None:
27
+ cls._instance: Optional[StrawberryAuto] = None
28
28
  super().__init__(*args, **kwargs)
29
29
 
30
30
  def __call__(cls, *args: str, **kwargs: Any) -> Any:
@@ -34,7 +34,7 @@ class StrawberryAutoMeta(type):
34
34
  return cls._instance
35
35
 
36
36
  def __instancecheck__(
37
- self,
37
+ cls,
38
38
  instance: Union[StrawberryAuto, StrawberryAnnotation, StrawberryType, type],
39
39
  ) -> bool:
40
40
  if isinstance(instance, StrawberryAnnotation):
strawberry/types/base.py CHANGED
@@ -54,12 +54,12 @@ class StrawberryType(ABC):
54
54
  str, Union[StrawberryType, type[WithStrawberryObjectDefinition]]
55
55
  ],
56
56
  ) -> Union[StrawberryType, type[WithStrawberryObjectDefinition]]:
57
- raise NotImplementedError()
57
+ raise NotImplementedError
58
58
 
59
59
  @property
60
60
  @abstractmethod
61
61
  def is_graphql_generic(self) -> bool:
62
- raise NotImplementedError()
62
+ raise NotImplementedError
63
63
 
64
64
  def has_generic(self, type_var: TypeVar) -> bool:
65
65
  return False
@@ -70,17 +70,15 @@ class StrawberryType(ABC):
70
70
  if isinstance(other, StrawberryType):
71
71
  return self is other
72
72
 
73
- elif isinstance(other, StrawberryAnnotation):
73
+ if isinstance(other, StrawberryAnnotation):
74
74
  return self == other.resolve()
75
75
 
76
- else:
77
- # This could be simplified if StrawberryAnnotation.resolve() always returned
78
- # a StrawberryType
79
- resolved = StrawberryAnnotation(other).resolve()
80
- if isinstance(resolved, StrawberryType):
81
- return self == resolved
82
- else:
83
- return NotImplemented
76
+ # This could be simplified if StrawberryAnnotation.resolve() always returned
77
+ # a StrawberryType
78
+ resolved = StrawberryAnnotation(other).resolve()
79
+ if isinstance(resolved, StrawberryType):
80
+ return self == resolved
81
+ return NotImplemented
84
82
 
85
83
  def __hash__(self) -> int:
86
84
  # TODO: Is this a bad idea? __eq__ objects are supposed to have the same hash
@@ -100,8 +98,7 @@ class StrawberryContainer(StrawberryType):
100
98
  if isinstance(other, StrawberryType):
101
99
  if isinstance(other, StrawberryContainer):
102
100
  return self.of_type == other.of_type
103
- else:
104
- return False
101
+ return False
105
102
 
106
103
  return super().__eq__(other)
107
104
 
@@ -112,11 +109,10 @@ class StrawberryContainer(StrawberryType):
112
109
 
113
110
  return list(parameters) if parameters else []
114
111
 
115
- elif isinstance(self.of_type, StrawberryType):
112
+ if isinstance(self.of_type, StrawberryType):
116
113
  return self.of_type.type_params
117
114
 
118
- else:
119
- return []
115
+ return []
120
116
 
121
117
  def copy_with(
122
118
  self,
@@ -0,0 +1,35 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, TypeVar, overload
4
+
5
+ _T = TypeVar("_T", bound=object)
6
+
7
+ TYPE_CAST_ATTRIBUTE = "__as_strawberry_type__"
8
+
9
+
10
+ @overload
11
+ def cast(type_: type, obj: None) -> None: ...
12
+
13
+
14
+ @overload
15
+ def cast(type_: type, obj: _T) -> _T: ...
16
+
17
+
18
+ def cast(type_: type, obj: _T | None) -> _T | None:
19
+ """Cast an object to given type.
20
+
21
+ This is used to mark an object as a cast object, so that the type can be
22
+ picked up when resolving unions/interfaces in case of ambiguity, which can
23
+ happen when returning an alike object instead of an instance of the type
24
+ (e.g. returning a Django, Pydantic or SQLAlchemy object)
25
+ """
26
+ if obj is None:
27
+ return None
28
+
29
+ setattr(obj, TYPE_CAST_ATTRIBUTE, type_)
30
+ return obj
31
+
32
+
33
+ def get_strawberry_type_cast(obj: Any) -> type | None:
34
+ """Get the type of a cast object."""
35
+ return getattr(obj, TYPE_CAST_ATTRIBUTE, None)
strawberry/types/field.py CHANGED
@@ -191,17 +191,21 @@ class StrawberryField(dataclasses.Field):
191
191
  for argument in resolver.arguments:
192
192
  if isinstance(argument.type_annotation.annotation, str):
193
193
  continue
194
- elif isinstance(argument.type, StrawberryUnion):
194
+
195
+ if isinstance(argument.type, StrawberryUnion):
196
+ raise InvalidArgumentTypeError(
197
+ resolver,
198
+ argument,
199
+ )
200
+
201
+ if (
202
+ has_object_definition(argument.type)
203
+ and argument.type.__strawberry_definition__.is_interface
204
+ ):
195
205
  raise InvalidArgumentTypeError(
196
206
  resolver,
197
207
  argument,
198
208
  )
199
- elif has_object_definition(argument.type):
200
- if argument.type.__strawberry_definition__.is_interface:
201
- raise InvalidArgumentTypeError(
202
- resolver,
203
- argument,
204
- )
205
209
 
206
210
  self.base_resolver = resolver
207
211
 
@@ -91,8 +91,7 @@ class ReservedNameBoundParameter(NamedTuple):
91
91
  if parameters: # Add compatibility for resolvers with no arguments
92
92
  first_parameter = parameters[0]
93
93
  return first_parameter if first_parameter.name == self.name else None
94
- else:
95
- return None
94
+ return None
96
95
 
97
96
 
98
97
  class ReservedType(NamedTuple):
@@ -145,21 +144,19 @@ class ReservedType(NamedTuple):
145
144
  )
146
145
  warnings.warn(warning, stacklevel=3)
147
146
  return reserved_name
148
- else:
149
- return None
147
+ return None
150
148
 
151
149
  def is_reserved_type(self, other: builtins.type) -> bool:
152
150
  origin = cast(type, get_origin(other)) or other
153
151
  if origin is Annotated:
154
152
  # Handle annotated arguments such as Private[str] and DirectiveValue[str]
155
153
  return type_has_annotation(other, self.type)
156
- else:
157
- # Handle both concrete and generic types (i.e Info, and Info)
158
- return (
159
- issubclass(origin, self.type)
160
- if isinstance(origin, type)
161
- else origin is self.type
162
- )
154
+ # Handle both concrete and generic types (i.e Info, and Info)
155
+ return (
156
+ issubclass(origin, self.type)
157
+ if isinstance(origin, type)
158
+ else origin is self.type
159
+ )
163
160
 
164
161
 
165
162
  SELF_PARAMSPEC = ReservedNameBoundParameter("self")
@@ -309,24 +306,20 @@ class StrawberryResolver(Generic[T]):
309
306
  reserved_names = {p.name for p in reserved_parameters.values() if p is not None}
310
307
 
311
308
  annotations = self._unbound_wrapped_func.__annotations__
312
- annotations = {
309
+ return {
313
310
  name: annotation
314
311
  for name, annotation in annotations.items()
315
312
  if name not in reserved_names
316
313
  }
317
314
 
318
- return annotations
319
-
320
315
  @cached_property
321
316
  def type_annotation(self) -> Optional[StrawberryAnnotation]:
322
317
  return_annotation = self.signature.return_annotation
323
318
  if return_annotation is inspect.Signature.empty:
324
319
  return None
325
- else:
326
- type_annotation = StrawberryAnnotation(
327
- annotation=return_annotation, namespace=self._namespace
328
- )
329
- return type_annotation
320
+ return StrawberryAnnotation(
321
+ annotation=return_annotation, namespace=self._namespace
322
+ )
330
323
 
331
324
  @property
332
325
  def type(self) -> Optional[Union[StrawberryType, type]]:
strawberry/types/union.py CHANGED
@@ -54,7 +54,7 @@ class StrawberryUnion(StrawberryType):
54
54
  def __init__(
55
55
  self,
56
56
  name: Optional[str] = None,
57
- type_annotations: tuple[StrawberryAnnotation, ...] = tuple(),
57
+ type_annotations: tuple[StrawberryAnnotation, ...] = (),
58
58
  description: Optional[str] = None,
59
59
  directives: Iterable[object] = (),
60
60
  ) -> None:
strawberry/types/unset.py CHANGED
@@ -14,8 +14,7 @@ class UnsetType:
14
14
  ret = super().__new__(cls)
15
15
  cls.__instance = ret
16
16
  return ret
17
- else:
18
- return cls.__instance
17
+ return cls.__instance
19
18
 
20
19
  def __str__(self) -> str:
21
20
  return ""
strawberry/utils/debug.py CHANGED
@@ -30,7 +30,7 @@ def pretty_print_graphql_operation(
30
30
  if operation_name == "IntrospectionQuery":
31
31
  return
32
32
 
33
- now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
33
+ now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") # noqa: DTZ005
34
34
 
35
35
  print(f"[{now}]: {operation_name or 'No operation name'}") # noqa: T201
36
36
  print(highlight(query, GraphQLLexer(), Terminal256Formatter())) # noqa: T201