strawberry-graphql: strawberry_graphql-0.257.0.dev1735244504-py3-none-any.whl → strawberry_graphql-0.258.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. strawberry/__init__.py +2 -0
  2. strawberry/aiohttp/test/client.py +1 -3
  3. strawberry/aiohttp/views.py +3 -3
  4. strawberry/annotation.py +5 -7
  5. strawberry/asgi/__init__.py +4 -4
  6. strawberry/channels/handlers/ws_handler.py +3 -3
  7. strawberry/channels/testing.py +3 -1
  8. strawberry/cli/__init__.py +7 -5
  9. strawberry/cli/commands/codegen.py +12 -14
  10. strawberry/cli/commands/server.py +2 -2
  11. strawberry/cli/commands/upgrade/_run_codemod.py +2 -3
  12. strawberry/cli/utils/__init__.py +1 -1
  13. strawberry/codegen/plugins/python.py +4 -5
  14. strawberry/codegen/plugins/typescript.py +3 -2
  15. strawberry/codegen/query_codegen.py +12 -11
  16. strawberry/codemods/annotated_unions.py +1 -1
  17. strawberry/codemods/update_imports.py +2 -4
  18. strawberry/dataloader.py +2 -2
  19. strawberry/django/__init__.py +2 -2
  20. strawberry/django/views.py +2 -3
  21. strawberry/exceptions/__init__.py +4 -2
  22. strawberry/exceptions/exception.py +1 -1
  23. strawberry/exceptions/permission_fail_silently_requires_optional.py +2 -1
  24. strawberry/exceptions/utils/source_finder.py +1 -1
  25. strawberry/experimental/pydantic/_compat.py +4 -4
  26. strawberry/experimental/pydantic/conversion.py +4 -5
  27. strawberry/experimental/pydantic/fields.py +1 -2
  28. strawberry/experimental/pydantic/object_type.py +6 -2
  29. strawberry/experimental/pydantic/utils.py +3 -9
  30. strawberry/ext/mypy_plugin.py +7 -14
  31. strawberry/extensions/context.py +15 -19
  32. strawberry/extensions/field_extension.py +53 -54
  33. strawberry/extensions/query_depth_limiter.py +27 -33
  34. strawberry/extensions/tracing/datadog.py +1 -1
  35. strawberry/extensions/tracing/opentelemetry.py +9 -14
  36. strawberry/fastapi/router.py +2 -3
  37. strawberry/federation/schema.py +3 -3
  38. strawberry/flask/views.py +3 -2
  39. strawberry/http/async_base_view.py +2 -4
  40. strawberry/http/ides.py +1 -3
  41. strawberry/http/sync_base_view.py +1 -2
  42. strawberry/litestar/controller.py +6 -5
  43. strawberry/permission.py +1 -1
  44. strawberry/quart/views.py +2 -2
  45. strawberry/relay/fields.py +38 -4
  46. strawberry/relay/types.py +6 -1
  47. strawberry/relay/utils.py +6 -1
  48. strawberry/schema/base.py +0 -2
  49. strawberry/schema/execute.py +11 -11
  50. strawberry/schema/name_converter.py +4 -5
  51. strawberry/schema/schema.py +6 -4
  52. strawberry/schema/schema_converter.py +24 -19
  53. strawberry/schema/subscribe.py +4 -4
  54. strawberry/schema/types/base_scalars.py +4 -2
  55. strawberry/schema/types/scalar.py +1 -1
  56. strawberry/schema_codegen/__init__.py +5 -6
  57. strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +2 -2
  58. strawberry/subscriptions/protocols/graphql_ws/handlers.py +0 -3
  59. strawberry/test/client.py +1 -2
  60. strawberry/types/arguments.py +2 -2
  61. strawberry/types/auto.py +3 -3
  62. strawberry/types/base.py +12 -16
  63. strawberry/types/cast.py +35 -0
  64. strawberry/types/field.py +11 -7
  65. strawberry/types/fields/resolver.py +12 -19
  66. strawberry/types/union.py +1 -1
  67. strawberry/types/unset.py +1 -2
  68. strawberry/utils/debug.py +1 -1
  69. strawberry/utils/deprecations.py +1 -1
  70. strawberry/utils/graphql_lexer.py +6 -4
  71. strawberry/utils/typing.py +1 -2
  72. {strawberry_graphql-0.257.0.dev1735244504.dist-info → strawberry_graphql-0.258.0.dist-info}/METADATA +3 -3
  73. {strawberry_graphql-0.257.0.dev1735244504.dist-info → strawberry_graphql-0.258.0.dist-info}/RECORD +76 -75
  74. {strawberry_graphql-0.257.0.dev1735244504.dist-info → strawberry_graphql-0.258.0.dist-info}/WHEEL +1 -1
  75. {strawberry_graphql-0.257.0.dev1735244504.dist-info → strawberry_graphql-0.258.0.dist-info}/LICENSE +0 -0
  76. {strawberry_graphql-0.257.0.dev1735244504.dist-info → strawberry_graphql-0.258.0.dist-info}/entry_points.txt +0 -0
strawberry/__init__.py CHANGED
@@ -13,6 +13,7 @@ from .schema import Schema
13
13
  from .schema_directive import schema_directive
14
14
  from .types.arguments import argument
15
15
  from .types.auto import auto
16
+ from .types.cast import cast
16
17
  from .types.enum import enum, enum_value
17
18
  from .types.field import field
18
19
  from .types.info import Info
@@ -36,6 +37,7 @@ __all__ = [
36
37
  "argument",
37
38
  "asdict",
38
39
  "auto",
40
+ "cast",
39
41
  "directive",
40
42
  "directive_field",
41
43
  "enum",
@@ -57,13 +57,11 @@ class GraphQLTestClient(BaseGraphQLTestClient):
57
57
  headers: Optional[dict[str, object]] = None,
58
58
  files: Optional[dict[str, object]] = None,
59
59
  ) -> Any:
60
- response = await self._client.post(
60
+ return await self._client.post(
61
61
  self.url,
62
62
  json=body if not files else None,
63
63
  data=body if files else None,
64
64
  )
65
65
 
66
- return response
67
-
68
66
 
69
67
  __all__ = ["GraphQLTestClient"]
@@ -99,12 +99,12 @@ class AioHTTPWebSocketAdapter(AsyncWebSocketAdapter):
99
99
  if ws_message.type == http.WSMsgType.TEXT:
100
100
  try:
101
101
  yield self.view.decode_json(ws_message.data)
102
- except JSONDecodeError:
102
+ except JSONDecodeError as e:
103
103
  if not ignore_parsing_errors:
104
- raise NonJsonMessageReceived()
104
+ raise NonJsonMessageReceived from e
105
105
 
106
106
  elif ws_message.type == http.WSMsgType.BINARY:
107
- raise NonTextMessageReceived()
107
+ raise NonTextMessageReceived
108
108
 
109
109
  async def send_json(self, message: Mapping[str, object]) -> None:
110
110
  try:
strawberry/annotation.py CHANGED
@@ -107,9 +107,7 @@ class StrawberryAnnotation:
107
107
  if isinstance(annotation, str):
108
108
  annotation = ForwardRef(annotation)
109
109
 
110
- evaled_type = eval_type(annotation, self.namespace, None)
111
-
112
- return evaled_type
110
+ return eval_type(annotation, self.namespace, None)
113
111
 
114
112
  def _get_type_with_args(
115
113
  self, evaled_type: type[Any]
@@ -155,13 +153,13 @@ class StrawberryAnnotation:
155
153
  # a StrawberryType
156
154
  if self._is_enum(evaled_type):
157
155
  return self.create_enum(evaled_type)
158
- elif self._is_optional(evaled_type, args):
156
+ if self._is_optional(evaled_type, args):
159
157
  return self.create_optional(evaled_type)
160
- elif self._is_union(evaled_type, args):
158
+ if self._is_union(evaled_type, args):
161
159
  return self.create_union(evaled_type, args)
162
- elif is_type_var(evaled_type) or evaled_type is Self:
160
+ if is_type_var(evaled_type) or evaled_type is Self:
163
161
  return self.create_type_var(cast(TypeVar, evaled_type))
164
- elif self._is_strawberry_type(evaled_type):
162
+ if self._is_strawberry_type(evaled_type):
165
163
  # Simply return objects that are already StrawberryTypes
166
164
  return evaled_type
167
165
 
@@ -97,11 +97,11 @@ class ASGIWebSocketAdapter(AsyncWebSocketAdapter):
97
97
  try:
98
98
  text = await self.ws.receive_text()
99
99
  yield self.view.decode_json(text)
100
- except JSONDecodeError: # noqa: PERF203
100
+ except JSONDecodeError as e: # noqa: PERF203
101
101
  if not ignore_parsing_errors:
102
- raise NonJsonMessageReceived()
103
- except KeyError:
104
- raise NonTextMessageReceived()
102
+ raise NonJsonMessageReceived from e
103
+ except KeyError as e:
104
+ raise NonTextMessageReceived from e
105
105
  except WebSocketDisconnect: # pragma: no cover
106
106
  pass
107
107
 
@@ -45,13 +45,13 @@ class ChannelsWebSocketAdapter(AsyncWebSocketAdapter):
45
45
  break
46
46
 
47
47
  if message["message"] is None:
48
- raise NonTextMessageReceived()
48
+ raise NonTextMessageReceived
49
49
 
50
50
  try:
51
51
  yield self.view.decode_json(message["message"])
52
- except json.JSONDecodeError:
52
+ except json.JSONDecodeError as e:
53
53
  if not ignore_parsing_errors:
54
- raise NonJsonMessageReceived()
54
+ raise NonJsonMessageReceived from e
55
55
 
56
56
  async def send_json(self, message: Mapping[str, object]) -> None:
57
57
  serialized_message = self.view.encode_json(message)
@@ -55,7 +55,7 @@ class GraphQLWebsocketCommunicator(WebsocketCommunicator):
55
55
  path: str,
56
56
  headers: Optional[list[tuple[bytes, bytes]]] = None,
57
57
  protocol: str = GRAPHQL_TRANSPORT_WS_PROTOCOL,
58
- connection_params: dict = {},
58
+ connection_params: dict | None = None,
59
59
  **kwargs: Any,
60
60
  ) -> None:
61
61
  """Create a new communicator.
@@ -69,6 +69,8 @@ class GraphQLWebsocketCommunicator(WebsocketCommunicator):
69
69
  subprotocols: an ordered list of preferred subprotocols to be sent to the server.
70
70
  **kwargs: additional arguments to be passed to the `WebsocketCommunicator` constructor.
71
71
  """
72
+ if connection_params is None:
73
+ connection_params = {}
72
74
  self.protocol = protocol
73
75
  subprotocols = kwargs.get("subprotocols", [])
74
76
  subprotocols.append(protocol)
@@ -1,10 +1,12 @@
1
1
  try:
2
2
  from .app import app
3
- from .commands.codegen import codegen as codegen # noqa
4
- from .commands.export_schema import export_schema as export_schema # noqa
5
- from .commands.schema_codegen import schema_codegen as schema_codegen # noqa
6
- from .commands.server import server as server # noqa
7
- from .commands.upgrade import upgrade as upgrade # noqa
3
+ from .commands.codegen import codegen as codegen # noqa: PLC0414
4
+ from .commands.export_schema import export_schema as export_schema # noqa: PLC0414
5
+ from .commands.schema_codegen import (
6
+ schema_codegen as schema_codegen, # noqa: PLC0414
7
+ )
8
+ from .commands.server import server as server # noqa: PLC0414
9
+ from .commands.upgrade import upgrade as upgrade # noqa: PLC0414
8
10
 
9
11
  def run() -> None:
10
12
  app()
@@ -39,23 +39,21 @@ def _import_plugin(plugin: str) -> Optional[type[QueryCodegenPlugin]]:
39
39
 
40
40
  assert _is_codegen_plugin(obj)
41
41
  return obj
42
- else:
42
+
43
+ symbols = {
44
+ key: value for key, value in module.__dict__.items() if not key.startswith("__")
45
+ }
46
+
47
+ if "__all__" in module.__dict__:
43
48
  symbols = {
44
- key: value
45
- for key, value in module.__dict__.items()
46
- if not key.startswith("__")
49
+ name: symbol
50
+ for name, symbol in symbols.items()
51
+ if name in module.__dict__["__all__"]
47
52
  }
48
53
 
49
- if "__all__" in module.__dict__:
50
- symbols = {
51
- name: symbol
52
- for name, symbol in symbols.items()
53
- if name in module.__dict__["__all__"]
54
- }
55
-
56
- for obj in symbols.values():
57
- if _is_codegen_plugin(obj):
58
- return obj
54
+ for obj in symbols.values():
55
+ if _is_codegen_plugin(obj):
56
+ return obj
59
57
 
60
58
  return None
61
59
 
@@ -25,7 +25,7 @@ class LogLevel(str, Enum):
25
25
  @app.command(help="Starts debug server")
26
26
  def server(
27
27
  schema: str,
28
- host: str = typer.Option("0.0.0.0", "-h", "--host", show_default=True),
28
+ host: str = typer.Option("0.0.0.0", "-h", "--host", show_default=True), # noqa: S104
29
29
  port: int = typer.Option(8000, "-p", "--port", show_default=True),
30
30
  log_level: LogLevel = typer.Option(
31
31
  "error",
@@ -60,7 +60,7 @@ def server(
60
60
  "install them by running:\n"
61
61
  r"pip install 'strawberry-graphql\[debug-server]'"
62
62
  )
63
- raise typer.Exit(1)
63
+ raise typer.Exit(1) # noqa: B904
64
64
 
65
65
  load_schema(schema, app_dir=app_dir)
66
66
 
@@ -41,9 +41,8 @@ def _execute_transform_wrap(
41
41
  additional_kwargs["scratch"] = {}
42
42
 
43
43
  # TODO: maybe capture warnings?
44
- with open(os.devnull, "w") as null: # noqa: PTH123
45
- with contextlib.redirect_stderr(null):
46
- return _execute_transform(**job, **additional_kwargs)
44
+ with open(os.devnull, "w") as null, contextlib.redirect_stderr(null): # noqa: PTH123
45
+ return _execute_transform(**job, **additional_kwargs)
47
46
 
48
47
 
49
48
  def _get_progress_and_pool(
@@ -16,7 +16,7 @@ def load_schema(schema: str, app_dir: str) -> Schema:
16
16
  message = str(exc)
17
17
 
18
18
  rich.print(f"[red]Error: {message}")
19
- raise typer.Exit(2)
19
+ raise typer.Exit(2) # noqa: B904
20
20
 
21
21
  if not isinstance(schema_symbol, Schema):
22
22
  message = "The `schema` must be an instance of strawberry.Schema"
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import textwrap
4
4
  from collections import defaultdict
5
5
  from dataclasses import dataclass
6
- from typing import TYPE_CHECKING, Optional
6
+ from typing import TYPE_CHECKING, ClassVar, Optional
7
7
 
8
8
  from strawberry.codegen import CodegenFile, QueryCodegenPlugin
9
9
  from strawberry.codegen.types import (
@@ -35,7 +35,7 @@ class PythonType:
35
35
 
36
36
 
37
37
  class PythonPlugin(QueryCodegenPlugin):
38
- SCALARS_TO_PYTHON_TYPES: dict[str, PythonType] = {
38
+ SCALARS_TO_PYTHON_TYPES: ClassVar[dict[str, PythonType]] = {
39
39
  "ID": PythonType("str"),
40
40
  "Int": PythonType("int"),
41
41
  "String": PythonType("str"),
@@ -128,7 +128,7 @@ class PythonPlugin(QueryCodegenPlugin):
128
128
  + ", ".join(self._print_argument_value(v) for v in argval.values)
129
129
  + "]"
130
130
  )
131
- elif isinstance(argval.values, dict):
131
+ if isinstance(argval.values, dict):
132
132
  return (
133
133
  "{"
134
134
  + ", ".join(
@@ -137,8 +137,7 @@ class PythonPlugin(QueryCodegenPlugin):
137
137
  )
138
138
  + "}"
139
139
  )
140
- else:
141
- raise TypeError(f"Unrecognized values type: {argval}")
140
+ raise TypeError(f"Unrecognized values type: {argval}")
142
141
  if isinstance(argval, GraphQLEnumValue):
143
142
  # This is an enum. It needs the namespace alongside the name.
144
143
  if argval.enum_type is None:
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import textwrap
4
- from typing import TYPE_CHECKING
4
+ from typing import TYPE_CHECKING, ClassVar
5
5
 
6
6
  from strawberry.codegen import CodegenFile, QueryCodegenPlugin
7
7
  from strawberry.codegen.types import (
@@ -20,7 +20,7 @@ if TYPE_CHECKING:
20
20
 
21
21
 
22
22
  class TypeScriptPlugin(QueryCodegenPlugin):
23
- SCALARS_TO_TS_TYPE = {
23
+ SCALARS_TO_TS_TYPE: ClassVar[dict[str | type, str]] = {
24
24
  "ID": "string",
25
25
  "Int": "number",
26
26
  "String": "string",
@@ -102,6 +102,7 @@ class TypeScriptPlugin(QueryCodegenPlugin):
102
102
  if type_.name in self.SCALARS_TO_TS_TYPE:
103
103
  return ""
104
104
 
105
+ assert type_.python_type is not None
105
106
  return f"type {type_.name} = {self.SCALARS_TO_TS_TYPE[type_.python_type]}"
106
107
 
107
108
  def _print_union_type(self, type_: GraphQLUnion) -> str:
@@ -202,7 +202,7 @@ def _get_deps(t: GraphQLType) -> Iterable[GraphQLType]:
202
202
  yield from _get_deps(gql_type)
203
203
  else:
204
204
  # Want to make sure that all types are covered.
205
- raise ValueError(f"Unknown GraphQLType: {t}")
205
+ raise ValueError(f"Unknown GraphQLType: {t}") # noqa: TRY004
206
206
 
207
207
 
208
208
  _TYPE_TO_GRAPHQL_TYPE = {
@@ -249,14 +249,15 @@ class QueryCodegenPluginManager:
249
249
  def type_cmp(t1: GraphQLType, t2: GraphQLType) -> int:
250
250
  """Compare the types."""
251
251
  if t1 is t2:
252
- return 0
253
-
254
- if t1 in _get_deps(t2):
255
- return -1
252
+ retval = 0
253
+ elif t1 in _get_deps(t2):
254
+ retval = -1
256
255
  elif t2 in _get_deps(t1):
257
- return 1
256
+ retval = 1
258
257
  else:
259
- return 0
258
+ retval = 0
259
+
260
+ return retval
260
261
 
261
262
  return sorted(types, key=cmp_to_key(type_cmp))
262
263
 
@@ -311,15 +312,15 @@ class QueryCodegen:
311
312
  operations = self._get_operations(ast)
312
313
 
313
314
  if not operations:
314
- raise NoOperationProvidedError()
315
+ raise NoOperationProvidedError
315
316
 
316
317
  if len(operations) > 1:
317
- raise MultipleOperationsProvidedError()
318
+ raise MultipleOperationsProvidedError
318
319
 
319
320
  operation = operations[0]
320
321
 
321
322
  if operation.name is None:
322
- raise NoOperationNameProvidedError()
323
+ raise NoOperationNameProvidedError
323
324
 
324
325
  # Look for any free-floating fragments and create types out of them
325
326
  # These types can then be referenced and included later via the
@@ -550,7 +551,7 @@ class QueryCodegen:
550
551
  if isinstance(field_type, ScalarDefinition):
551
552
  return self._collect_scalar(field_type, None)
552
553
 
553
- elif isinstance(field_type, EnumDefinition):
554
+ if isinstance(field_type, EnumDefinition):
554
555
  return self._collect_enum(field_type)
555
556
 
556
557
  raise ValueError(f"Unsupported type: {field_type}") # pragma: no cover
@@ -50,7 +50,7 @@ class ConvertUnionToAnnotatedUnion(VisitorBasedCodemodCommand):
50
50
 
51
51
  super().__init__(context)
52
52
 
53
- def visit_Module(self, node: cst.Module) -> Optional[bool]:
53
+ def visit_Module(self, node: cst.Module) -> Optional[bool]: # noqa: N802
54
54
  self._is_using_named_import = False
55
55
 
56
56
  return super().visit_Module(node)
@@ -126,11 +126,9 @@ class UpdateImportsCodemod(VisitorBasedCodemodCommand):
126
126
 
127
127
  return updated_node
128
128
 
129
- def leave_ImportFrom(
129
+ def leave_ImportFrom( # noqa: N802
130
130
  self, node: cst.ImportFrom, updated_node: cst.ImportFrom
131
131
  ) -> cst.ImportFrom:
132
132
  updated_node = self._update_imports(updated_node, updated_node)
133
133
  updated_node = self._update_types_types_imports(updated_node, updated_node)
134
- updated_node = self._update_strawberry_type_imports(updated_node, updated_node)
135
-
136
- return updated_node
134
+ return self._update_strawberry_type_imports(updated_node, updated_node)
strawberry/dataloader.py CHANGED
@@ -240,7 +240,7 @@ async def dispatch_batch(loader: DataLoader, batch: Batch) -> None:
240
240
  values = list(values)
241
241
 
242
242
  if len(values) != len(batch):
243
- raise WrongNumberOfResultsReturned(
243
+ raise WrongNumberOfResultsReturned( # noqa: TRY301
244
244
  expected=len(batch), received=len(values)
245
245
  )
246
246
 
@@ -254,7 +254,7 @@ async def dispatch_batch(loader: DataLoader, batch: Batch) -> None:
254
254
  task.future.set_exception(value)
255
255
  else:
256
256
  task.future.set_result(value)
257
- except Exception as e:
257
+ except Exception as e: # noqa: BLE001
258
258
  for task in batch.tasks:
259
259
  task.future.set_exception(e)
260
260
 
@@ -12,9 +12,9 @@ except ModuleNotFoundError:
12
12
  import_symbol = f"{__name__}.{name}"
13
13
  try:
14
14
  return importlib.import_module(import_symbol)
15
- except ModuleNotFoundError:
15
+ except ModuleNotFoundError as e:
16
16
  raise AttributeError(
17
17
  f"Attempted import of {import_symbol} failed. Make sure to install the"
18
18
  "'strawberry-graphql-django' package to use the Strawberry Django "
19
19
  "extension API."
20
- )
20
+ ) from e
@@ -45,8 +45,7 @@ if TYPE_CHECKING:
45
45
 
46
46
  from strawberry.http import GraphQLHTTPResponse
47
47
  from strawberry.http.ides import GraphQL_IDE
48
-
49
- from ..schema import BaseSchema
48
+ from strawberry.schema import BaseSchema
50
49
 
51
50
 
52
51
  # TODO: remove this and unify temporal responses
@@ -266,7 +265,7 @@ class AsyncGraphQLView(
266
265
  request_adapter_class = AsyncDjangoHTTPRequestAdapter
267
266
 
268
267
  @classonlymethod # pyright: ignore[reportIncompatibleMethodOverride]
269
- def as_view(cls, **initkwargs: Any) -> Callable[..., HttpResponse]:
268
+ def as_view(cls, **initkwargs: Any) -> Callable[..., HttpResponse]: # noqa: N805
270
269
  # This code tells django that this view is async, see docs here:
271
270
  # https://docs.djangoproject.com/en/3.1/topics/async/#async-views
272
271
 
@@ -50,7 +50,7 @@ class UnallowedReturnTypeForUnion(Exception):
50
50
  def __init__(
51
51
  self, field_name: str, result_type: str, allowed_types: set[GraphQLObjectType]
52
52
  ) -> None:
53
- formatted_allowed_types = list(sorted(type_.name for type_ in allowed_types))
53
+ formatted_allowed_types = sorted(type_.name for type_ in allowed_types)
54
54
 
55
55
  message = (
56
56
  f'The type "{result_type}" of the field "{field_name}" '
@@ -160,7 +160,9 @@ class StrawberryGraphQLError(GraphQLError):
160
160
  class ConnectionRejectionError(Exception):
161
161
  """Use it when you want to reject a WebSocket connection."""
162
162
 
163
- def __init__(self, payload: dict[str, object] = {}) -> None:
163
+ def __init__(self, payload: dict[str, object] | None = None) -> None:
164
+ if payload is None:
165
+ payload = {}
164
166
  self.payload = payload
165
167
 
166
168
 
@@ -67,7 +67,7 @@ class StrawberryException(Exception, ABC):
67
67
  from rich.panel import Panel
68
68
 
69
69
  if self.exception_source is None:
70
- raise UnableToFindExceptionSource() from self
70
+ raise UnableToFindExceptionSource from self
71
71
 
72
72
  content = (
73
73
  self.__rich_header__,
@@ -7,7 +7,8 @@ from .exception import StrawberryException
7
7
  from .utils.source_finder import SourceFinder
8
8
 
9
9
  if TYPE_CHECKING:
10
- from ..field import StrawberryField
10
+ from strawberry.field import StrawberryField
11
+
11
12
  from .exception_source import ExceptionSource
12
13
 
13
14
 
@@ -8,7 +8,7 @@ from functools import cached_property
8
8
  from pathlib import Path
9
9
  from typing import TYPE_CHECKING, Any, Callable, Optional, cast
10
10
 
11
- from ..exception_source import ExceptionSource
11
+ from strawberry.exceptions.exception_source import ExceptionSource
12
12
 
13
13
  if TYPE_CHECKING:
14
14
  from collections.abc import Sequence
@@ -119,7 +119,7 @@ def get_fields_map_for_v2() -> dict[Any, Any]:
119
119
 
120
120
  class PydanticV2Compat:
121
121
  @property
122
- def PYDANTIC_MISSING_TYPE(self) -> Any:
122
+ def PYDANTIC_MISSING_TYPE(self) -> Any: # noqa: N802
123
123
  from pydantic_core import PydanticUndefined
124
124
 
125
125
  return PydanticUndefined
@@ -155,7 +155,7 @@ class PydanticV2Compat:
155
155
  type_ = self.fields_map[type_]
156
156
 
157
157
  if type_ is None:
158
- raise UnsupportedTypeError()
158
+ raise UnsupportedTypeError
159
159
 
160
160
  if is_new_type(type_):
161
161
  return new_type_supertype(type_)
@@ -168,7 +168,7 @@ class PydanticV2Compat:
168
168
 
169
169
  class PydanticV1Compat:
170
170
  @property
171
- def PYDANTIC_MISSING_TYPE(self) -> Any:
171
+ def PYDANTIC_MISSING_TYPE(self) -> Any: # noqa: N802
172
172
  return dataclasses.MISSING
173
173
 
174
174
  def get_model_fields(self, model: type[BaseModel]) -> dict[str, CompatModelField]:
@@ -231,7 +231,7 @@ class PydanticV1Compat:
231
231
  type_ = self.fields_map[type_]
232
232
 
233
233
  if type_ is None:
234
- raise UnsupportedTypeError()
234
+ raise UnsupportedTypeError
235
235
 
236
236
  if is_new_type(type_):
237
237
  return new_type_supertype(type_)
@@ -101,17 +101,17 @@ def convert_pydantic_model_to_strawberry_class(
101
101
  def convert_strawberry_class_to_pydantic_model(obj: type) -> Any:
102
102
  if hasattr(obj, "to_pydantic"):
103
103
  return obj.to_pydantic()
104
- elif dataclasses.is_dataclass(obj):
104
+ if dataclasses.is_dataclass(obj):
105
105
  result = []
106
106
  for f in dataclasses.fields(obj):
107
107
  value = convert_strawberry_class_to_pydantic_model(getattr(obj, f.name))
108
108
  result.append((f.name, value))
109
109
  return dict(result)
110
- elif isinstance(obj, (list, tuple)):
110
+ if isinstance(obj, (list, tuple)):
111
111
  # Assume we can create an object of this type by passing in a
112
112
  # generator (which is not true for namedtuples, not supported).
113
113
  return type(obj)(convert_strawberry_class_to_pydantic_model(v) for v in obj)
114
- elif isinstance(obj, dict):
114
+ if isinstance(obj, dict):
115
115
  return type(obj)(
116
116
  (
117
117
  convert_strawberry_class_to_pydantic_model(k),
@@ -119,5 +119,4 @@ def convert_strawberry_class_to_pydantic_model(obj: type) -> Any:
119
119
  )
120
120
  for k, v in obj.items()
121
121
  )
122
- else:
123
- return copy.deepcopy(obj)
122
+ return copy.deepcopy(obj)
@@ -32,8 +32,7 @@ def replace_pydantic_types(type_: Any, is_input: bool) -> Any:
32
32
  attr = "_strawberry_input_type" if is_input else "_strawberry_type"
33
33
  if hasattr(type_, attr):
34
34
  return getattr(type_, attr)
35
- else:
36
- raise UnregisteredTypeException(type_)
35
+ raise UnregisteredTypeException(type_)
37
36
  return type_
38
37
 
39
38
 
@@ -29,6 +29,7 @@ from strawberry.experimental.pydantic.utils import (
29
29
  get_private_fields,
30
30
  )
31
31
  from strawberry.types.auto import StrawberryAuto
32
+ from strawberry.types.cast import get_strawberry_type_cast
32
33
  from strawberry.types.field import StrawberryField
33
34
  from strawberry.types.object_type import _process_type, _wrap_dataclass
34
35
  from strawberry.types.type_resolver import _get_fields
@@ -115,7 +116,7 @@ if TYPE_CHECKING:
115
116
  )
116
117
 
117
118
 
118
- def type(
119
+ def type( # noqa: PLR0915
119
120
  model: builtins.type[PydanticModel],
120
121
  *,
121
122
  fields: Optional[list[str]] = None,
@@ -127,7 +128,7 @@ def type(
127
128
  all_fields: bool = False,
128
129
  use_pydantic_alias: bool = True,
129
130
  ) -> Callable[..., builtins.type[StrawberryTypeFromPydantic[PydanticModel]]]:
130
- def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]:
131
+ def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]: # noqa: PLR0915
131
132
  compat = PydanticCompat.from_model(model)
132
133
  model_fields = compat.get_model_fields(model)
133
134
  original_fields_set = set(fields) if fields else set()
@@ -207,6 +208,9 @@ def type(
207
208
  # pydantic objects (not the corresponding strawberry type)
208
209
  @classmethod # type: ignore
209
210
  def is_type_of(cls: builtins.type, obj: Any, _info: GraphQLResolveInfo) -> bool:
211
+ if (type_cast := get_strawberry_type_cast(obj)) is not None:
212
+ return type_cast is cls
213
+
210
214
  return isinstance(obj, (cls, model))
211
215
 
212
216
  namespace = {"is_type_of": is_type_of}
@@ -47,8 +47,7 @@ def normalize_type(type_: type) -> Any:
47
47
  def get_strawberry_type_from_model(type_: Any) -> Any:
48
48
  if hasattr(type_, "_strawberry_type"):
49
49
  return type_._strawberry_type
50
- else:
51
- raise UnregisteredTypeException(type_)
50
+ raise UnregisteredTypeException(type_)
52
51
 
53
52
 
54
53
  def get_private_fields(cls: type) -> list[dataclasses.Field]:
@@ -97,9 +96,7 @@ def get_default_factory_for_field(
97
96
  # if we have a default_factory, we should return it
98
97
 
99
98
  if has_factory:
100
- default_factory = cast("NoArgAnyCallable", default_factory)
101
-
102
- return default_factory
99
+ return cast("NoArgAnyCallable", default_factory)
103
100
 
104
101
  # if we have a default, we should return it
105
102
  if has_default:
@@ -108,8 +105,7 @@ def get_default_factory_for_field(
108
105
  # printing the value.
109
106
  if isinstance(default, BaseModel):
110
107
  return lambda: compat.model_dump(default)
111
- else:
112
- return lambda: smart_deepcopy(default)
108
+ return lambda: smart_deepcopy(default)
113
109
 
114
110
  # if we don't have default or default_factory, but the field is not required,
115
111
  # we should return a factory that returns None
@@ -131,5 +127,3 @@ def ensure_all_auto_fields_in_pydantic(
131
127
  raise AutoFieldsNotInBaseModelError(
132
128
  fields=non_existing_fields, cls_name=cls_name, model=model
133
129
  )
134
- else:
135
- return