strawberry-graphql 0.190.0.dev1687447182__py3-none-any.whl → 0.192.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. strawberry/annotation.py +24 -3
  2. strawberry/arguments.py +6 -2
  3. strawberry/channels/testing.py +22 -13
  4. strawberry/cli/__init__.py +4 -4
  5. strawberry/cli/commands/upgrade/__init__.py +75 -0
  6. strawberry/cli/commands/upgrade/_fake_progress.py +21 -0
  7. strawberry/cli/commands/upgrade/_run_codemod.py +74 -0
  8. strawberry/codemods/__init__.py +0 -0
  9. strawberry/codemods/annotated_unions.py +185 -0
  10. strawberry/exceptions/invalid_union_type.py +23 -3
  11. strawberry/exceptions/utils/source_finder.py +147 -11
  12. strawberry/extensions/field_extension.py +2 -5
  13. strawberry/fastapi/router.py +5 -4
  14. strawberry/federation/union.py +4 -5
  15. strawberry/field.py +116 -75
  16. strawberry/http/__init__.py +1 -3
  17. strawberry/permission.py +3 -166
  18. strawberry/relay/fields.py +2 -0
  19. strawberry/relay/types.py +14 -4
  20. strawberry/schema/schema.py +1 -1
  21. strawberry/schema/schema_converter.py +106 -38
  22. strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +20 -8
  23. strawberry/subscriptions/protocols/graphql_transport_ws/types.py +1 -1
  24. strawberry/subscriptions/protocols/graphql_ws/handlers.py +4 -7
  25. strawberry/type.py +2 -2
  26. strawberry/types/type_resolver.py +7 -29
  27. strawberry/types/types.py +6 -0
  28. strawberry/union.py +46 -17
  29. strawberry/utils/typing.py +21 -0
  30. {strawberry_graphql-0.190.0.dev1687447182.dist-info → strawberry_graphql-0.192.1.dist-info}/METADATA +1 -1
  31. {strawberry_graphql-0.190.0.dev1687447182.dist-info → strawberry_graphql-0.192.1.dist-info}/RECORD +34 -30
  32. strawberry/exceptions/permission_fail_silently_requires_optional.py +0 -52
  33. {strawberry_graphql-0.190.0.dev1687447182.dist-info → strawberry_graphql-0.192.1.dist-info}/LICENSE +0 -0
  34. {strawberry_graphql-0.190.0.dev1687447182.dist-info → strawberry_graphql-0.192.1.dist-info}/WHEEL +0 -0
  35. {strawberry_graphql-0.190.0.dev1687447182.dist-info → strawberry_graphql-0.192.1.dist-info}/entry_points.txt +0 -0
strawberry/annotation.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import logging
3
4
  import sys
4
5
  import typing
5
6
  from collections import abc
@@ -85,13 +86,15 @@ class StrawberryAnnotation:
85
86
  if isinstance(annotation, str):
86
87
  annotation = ForwardRef(annotation)
87
88
 
89
+ args = []
90
+
88
91
  evaled_type = eval_type(annotation, self.namespace, None)
89
92
 
90
93
  if is_private(evaled_type):
91
94
  return evaled_type
92
95
 
93
96
  if get_origin(evaled_type) is Annotated:
94
- evaled_type = get_args(evaled_type)[0]
97
+ evaled_type, *args = get_args(evaled_type)
95
98
 
96
99
  if self._is_async_type(evaled_type):
97
100
  evaled_type = self._strip_async_type(evaled_type)
@@ -116,7 +119,7 @@ class StrawberryAnnotation:
116
119
  elif self._is_optional(evaled_type):
117
120
  return self.create_optional(evaled_type)
118
121
  elif self._is_union(evaled_type):
119
- return self.create_union(evaled_type)
122
+ return self.create_union(evaled_type, args)
120
123
  elif is_type_var(evaled_type) or evaled_type is Self:
121
124
  return self.create_type_var(cast(TypeVar, evaled_type))
122
125
 
@@ -173,7 +176,7 @@ class StrawberryAnnotation:
173
176
  def create_type_var(self, evaled_type: TypeVar) -> StrawberryTypeVar:
174
177
  return StrawberryTypeVar(evaled_type)
175
178
 
176
- def create_union(self, evaled_type: Type) -> StrawberryUnion:
179
+ def create_union(self, evaled_type: Type[Any], args: list[Any]) -> StrawberryUnion:
177
180
  # Prevent import cycles
178
181
  from strawberry.union import StrawberryUnion
179
182
 
@@ -182,9 +185,27 @@ class StrawberryAnnotation:
182
185
  return evaled_type
183
186
 
184
187
  types = evaled_type.__args__
188
+
185
189
  union = StrawberryUnion(
186
190
  type_annotations=tuple(StrawberryAnnotation(type_) for type_ in types),
187
191
  )
192
+
193
+ union_args = [arg for arg in args if isinstance(arg, StrawberryUnion)]
194
+ if len(union_args) > 1:
195
+ logging.warning(
196
+ "Duplicate union definition detected. "
197
+ "Only the first definition will be considered"
198
+ )
199
+
200
+ if union_args:
201
+ arg = union_args[0]
202
+ union.graphql_name = arg.graphql_name
203
+ union.description = arg.description
204
+ union.directives = arg.directives
205
+
206
+ union._source_file = arg._source_file
207
+ union._source_line = arg._source_line
208
+
188
209
  return union
189
210
 
190
211
  @classmethod
strawberry/arguments.py CHANGED
@@ -177,13 +177,17 @@ def convert_argument(
177
177
  if has_object_definition(type_):
178
178
  kwargs = {}
179
179
 
180
- for field in type_.__strawberry_definition__.fields:
180
+ type_definition = type_.__strawberry_definition__
181
+ for field in type_definition.fields:
181
182
  value = cast(Mapping, value)
182
183
  graphql_name = config.name_converter.from_field(field)
183
184
 
184
185
  if graphql_name in value:
185
186
  kwargs[field.python_name] = convert_argument(
186
- value[graphql_name], field.type, scalar_registry, config
187
+ value[graphql_name],
188
+ field.resolve_type(type_definition=type_definition),
189
+ scalar_registry,
190
+ config,
187
191
  )
188
192
 
189
193
  type_ = cast(type, type_)
strawberry/channels/testing.py CHANGED
@@ -14,7 +14,7 @@ from typing import (
14
14
  Union,
15
15
  )
16
16
 
17
- from graphql import GraphQLError
17
+ from graphql import GraphQLError, GraphQLFormattedError
18
18
 
19
19
  from channels.testing.websocket import WebsocketCommunicator
20
20
  from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
@@ -100,6 +100,9 @@ class GraphQLWebsocketCommunicator(WebsocketCommunicator):
100
100
  response = await self.receive_json_from()
101
101
  assert response["type"] == GQL_CONNECTION_ACK
102
102
 
103
+ # Actual `ExecutionResult`` objects are not available client-side, since they
104
+ # get transformed into `FormattedExecutionResult` on the wire, but we attempt
105
+ # to do a limited representation of them here, to make testing simpler.
103
106
  async def subscribe(
104
107
  self, query: str, variables: Optional[Dict] = None
105
108
  ) -> Union[ExecutionResult, AsyncIterator[ExecutionResult]]:
@@ -125,22 +128,28 @@ class GraphQLWebsocketCommunicator(WebsocketCommunicator):
125
128
  message_type = response["type"]
126
129
  if message_type == NextMessage.type:
127
130
  payload = NextMessage(**response).payload
128
- ret = ExecutionResult(None, None)
129
- for field in dataclasses.fields(ExecutionResult):
130
- setattr(ret, field.name, payload.get(field.name, None))
131
- yield ret
131
+ ret = ExecutionResult(payload["data"], None)
132
+ if "errors" in payload:
133
+ ret.errors = self.process_errors(payload["errors"])
134
+ ret.extensions = payload.get("extensions", None)
135
+ yield ret
132
136
  elif message_type == ErrorMessage.type:
133
137
  error_payload = ErrorMessage(**response).payload
134
138
  yield ExecutionResult(
135
- data=None,
136
- errors=[
137
- GraphQLError(
138
- message=message["message"],
139
- extensions=message.get("extensions", None),
140
- )
141
- for message in error_payload
142
- ],
139
+ data=None, errors=self.process_errors(error_payload)
143
140
  )
144
141
  return # an error message is the last message for a subscription
145
142
  else:
146
143
  return
144
+
145
+ def process_errors(self, errors: List[GraphQLFormattedError]) -> List[GraphQLError]:
146
+ """Reconst a GraphQLError from a FormattedGraphQLError"""
147
+ result = []
148
+ for f_error in errors:
149
+ error = GraphQLError(
150
+ message=f_error["message"],
151
+ extensions=f_error.get("extensions", None),
152
+ )
153
+ error.path = f_error.get("path", None)
154
+ result.append(error)
155
+ return result
strawberry/cli/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
- from .commands.codegen import codegen # noqa
2
- from .commands.export_schema import export_schema # noqa
3
- from .commands.server import server # noqa
4
-
1
+ from .commands.codegen import codegen as codegen # noqa
2
+ from .commands.export_schema import export_schema as export_schema # noqa
3
+ from .commands.server import server as server # noqa
4
+ from .commands.upgrade import upgrade as upgrade # noqa
5
5
 
6
6
  from .app import app
7
7
 
strawberry/cli/commands/upgrade/__init__.py ADDED
@@ -0,0 +1,75 @@
1
+ from __future__ import annotations
2
+
3
+ import glob
4
+ import pathlib # noqa: TCH003
5
+ import sys
6
+ from typing import List
7
+
8
+ import rich
9
+ import typer
10
+ from libcst.codemod import CodemodContext
11
+
12
+ from strawberry.cli.app import app
13
+ from strawberry.codemods.annotated_unions import ConvertUnionToAnnotatedUnion
14
+
15
+ from ._run_codemod import run_codemod
16
+
17
+ codemods = {
18
+ "annotated-union": ConvertUnionToAnnotatedUnion,
19
+ }
20
+
21
+
22
+ # TODO: add support for running all of them
23
+ @app.command(help="Upgrades a Strawberry project to the latest version")
24
+ def upgrade(
25
+ codemod: str = typer.Argument(
26
+ ...,
27
+ autocompletion=lambda: list(codemods.keys()),
28
+ help="Name of the upgrade to run",
29
+ ),
30
+ paths: List[pathlib.Path] = typer.Argument(file_okay=True, dir_okay=True),
31
+ python_target: str = typer.Option(
32
+ ".".join(str(x) for x in sys.version_info[:2]),
33
+ "--python-target",
34
+ help="Python version to target",
35
+ ),
36
+ use_typing_extensions: bool = typer.Option(
37
+ False,
38
+ "--use-typing-extensions",
39
+ help="Use typing_extensions instead of typing for newer features",
40
+ ),
41
+ ) -> None:
42
+ if codemod not in codemods:
43
+ rich.print(f'[red]Upgrade named "{codemod}" does not exist')
44
+
45
+ raise typer.Exit(2)
46
+
47
+ python_target_version = tuple(int(x) for x in python_target.split("."))
48
+
49
+ transformer = ConvertUnionToAnnotatedUnion(
50
+ CodemodContext(),
51
+ use_pipe_syntax=python_target_version >= (3, 10),
52
+ use_typing_extensions=use_typing_extensions,
53
+ )
54
+
55
+ files: list[str] = []
56
+
57
+ for path in paths:
58
+ if path.is_dir():
59
+ glob_path = str(path / "**/*.py")
60
+ files.extend(glob.glob(glob_path, recursive=True))
61
+ else:
62
+ files.append(str(path))
63
+
64
+ files = list(set(files))
65
+
66
+ results = list(run_codemod(transformer, files))
67
+ changed = [result for result in results if result.changed]
68
+
69
+ rich.print()
70
+ rich.print("[green]Upgrade completed successfully, here's a summary:")
71
+ rich.print(f" - {len(changed)} files changed")
72
+ rich.print(f" - {len(results) - len(changed)} files skipped")
73
+
74
+ if changed:
75
+ raise typer.Exit(1)
strawberry/cli/commands/upgrade/_fake_progress.py ADDED
@@ -0,0 +1,21 @@
1
+ from typing import Any
2
+
3
+ from rich.progress import TaskID
4
+
5
+
6
+ class FakeProgress:
7
+ """A fake progress bar that does nothing.
8
+
9
+ This is used when the user has only one file to process."""
10
+
11
+ def advance(self, task_id: TaskID) -> None:
12
+ pass
13
+
14
+ def add_task(self, *args: Any, **kwargs: Any) -> TaskID:
15
+ return TaskID(0)
16
+
17
+ def __enter__(self) -> "FakeProgress":
18
+ return self
19
+
20
+ def __exit__(self, *args: Any, **kwargs: Any) -> None:
21
+ pass
strawberry/cli/commands/upgrade/_run_codemod.py ADDED
@@ -0,0 +1,74 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import os
5
+ from multiprocessing import Pool, cpu_count
6
+ from typing import TYPE_CHECKING, Any, Dict, Generator, Sequence, Type, Union
7
+
8
+ from libcst.codemod._cli import ExecutionConfig, ExecutionResult, _execute_transform
9
+ from libcst.codemod._dummy_pool import DummyPool
10
+ from rich.progress import Progress
11
+
12
+ from ._fake_progress import FakeProgress
13
+
14
+ if TYPE_CHECKING:
15
+ from libcst.codemod import Codemod
16
+
17
+ ProgressType = Union[Type[Progress], Type[FakeProgress]]
18
+ PoolType = Union[Type[Pool], Type[DummyPool]] # type: ignore
19
+
20
+
21
+ def _execute_transform_wrap(
22
+ job: Dict[str, Any],
23
+ ) -> ExecutionResult:
24
+ # TODO: maybe capture warnings?
25
+ with open(os.devnull, "w") as null: # noqa: PTH123
26
+ with contextlib.redirect_stderr(null):
27
+ return _execute_transform(**job)
28
+
29
+
30
+ def _get_progress_and_pool(
31
+ total_files: int, jobs: int
32
+ ) -> tuple[PoolType, ProgressType]:
33
+ poll_impl: PoolType = Pool # type: ignore
34
+ progress_impl: ProgressType = Progress
35
+
36
+ if total_files == 1 or jobs == 1:
37
+ poll_impl = DummyPool
38
+
39
+ if total_files == 1:
40
+ progress_impl = FakeProgress
41
+
42
+ return poll_impl, progress_impl
43
+
44
+
45
+ def run_codemod(
46
+ codemod: Codemod,
47
+ files: Sequence[str],
48
+ ) -> Generator[ExecutionResult, None, None]:
49
+ chunk_size = 4
50
+ total = len(files)
51
+ jobs = min(cpu_count(), (total + chunk_size - 1) // chunk_size)
52
+
53
+ config = ExecutionConfig()
54
+
55
+ pool_impl, progress_impl = _get_progress_and_pool(total, jobs)
56
+
57
+ tasks = [
58
+ {
59
+ "transformer": codemod,
60
+ "filename": filename,
61
+ "config": config,
62
+ }
63
+ for filename in files
64
+ ]
65
+
66
+ with pool_impl(processes=jobs) as p, progress_impl() as progress: # type: ignore
67
+ task_id = progress.add_task("[cyan]Updating...", total=len(tasks))
68
+
69
+ for result in p.imap_unordered(
70
+ _execute_transform_wrap, tasks, chunksize=chunk_size
71
+ ):
72
+ progress.advance(task_id)
73
+
74
+ yield result
strawberry/codemods/__init__.py ADDED
File without changes
strawberry/codemods/annotated_unions.py ADDED
@@ -0,0 +1,185 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Sequence
4
+
5
+ import libcst as cst
6
+ import libcst.matchers as m
7
+ from libcst._nodes.expression import BaseExpression, Call # noqa: TCH002
8
+ from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
9
+ from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
10
+
11
+
12
+ def _find_named_argument(args: Sequence[cst.Arg], name: str) -> cst.Arg | None:
13
+ return next(
14
+ (arg for arg in args if arg.keyword and arg.keyword.value == name),
15
+ None,
16
+ )
17
+
18
+
19
+ def _find_positional_argument(
20
+ args: Sequence[cst.Arg], search_index: int
21
+ ) -> cst.Arg | None:
22
+ for index, arg in enumerate(args):
23
+ if index > search_index:
24
+ return None
25
+
26
+ if index == search_index and arg.keyword is None:
27
+ return arg
28
+
29
+ return None
30
+
31
+
32
+ class ConvertUnionToAnnotatedUnion(VisitorBasedCodemodCommand):
33
+ DESCRIPTION: str = (
34
+ "Converts strawberry.union(..., types=(...)) to "
35
+ "Annotated[Union[...], strawberry.union(...)]"
36
+ )
37
+
38
+ def __init__(
39
+ self,
40
+ context: CodemodContext,
41
+ use_pipe_syntax: bool = True,
42
+ use_typing_extensions: bool = False,
43
+ ) -> None:
44
+ self._is_using_named_import = False
45
+ self.use_pipe_syntax = use_pipe_syntax
46
+ self.use_typing_extensions = use_typing_extensions
47
+
48
+ super().__init__(context)
49
+
50
+ def visit_Module(self, node: cst.Module) -> Optional[bool]:
51
+ self._is_using_named_import = False
52
+
53
+ return super().visit_Module(node)
54
+
55
+ @m.visit(
56
+ m.ImportFrom(
57
+ m.Name("strawberry"),
58
+ [
59
+ m.ZeroOrMore(),
60
+ m.ImportAlias(m.Name("union")),
61
+ m.ZeroOrMore(),
62
+ ],
63
+ )
64
+ )
65
+ def visit_import_from(self, original_node: cst.ImportFrom) -> None:
66
+ self._is_using_named_import = True
67
+
68
+ @m.leave(
69
+ m.Call(
70
+ func=m.Attribute(value=m.Name("strawberry"), attr=m.Name("union"))
71
+ | m.Name("union")
72
+ )
73
+ )
74
+ def leave_union_call(
75
+ self, original_node: Call, updated_node: Call
76
+ ) -> BaseExpression:
77
+ if not self._is_using_named_import and isinstance(original_node.func, cst.Name):
78
+ return original_node
79
+
80
+ types = _find_named_argument(original_node.args, "types")
81
+ union_name = _find_named_argument(original_node.args, "name")
82
+
83
+ if types is None:
84
+ types = _find_positional_argument(original_node.args, 1)
85
+
86
+ # this is probably a strawberry.union(name="...") so we skip the conversion
87
+ # as it is going to be used in the new way already 😊
88
+
89
+ if types is None:
90
+ return original_node
91
+
92
+ AddImportsVisitor.add_needed_import(
93
+ self.context,
94
+ "typing_extensions" if self.use_typing_extensions else "typing",
95
+ "Annotated",
96
+ )
97
+
98
+ RemoveImportsVisitor.remove_unused_import(self.context, "strawberry", "union")
99
+
100
+ if union_name is None:
101
+ union_name = _find_positional_argument(original_node.args, 0)
102
+
103
+ assert union_name
104
+ assert isinstance(types.value, (cst.Tuple, cst.List))
105
+
106
+ types = types.value.elements # type: ignore
107
+ union_name = union_name.value # type: ignore
108
+
109
+ description = _find_named_argument(original_node.args, "description")
110
+ directives = _find_named_argument(original_node.args, "directives")
111
+
112
+ if self.use_pipe_syntax:
113
+ union_node = self._create_union_node_with_pipe_syntax(types) # type: ignore
114
+ else:
115
+ AddImportsVisitor.add_needed_import(self.context, "typing", "Union")
116
+
117
+ union_node = cst.Subscript(
118
+ value=cst.Name(value="Union"),
119
+ slice=[
120
+ cst.SubscriptElement(slice=cst.Index(value=t.value)) for t in types # type: ignore # noqa: E501
121
+ ],
122
+ )
123
+
124
+ union_call_args = [
125
+ cst.Arg(
126
+ value=union_name, # type: ignore
127
+ keyword=cst.Name(value="name"),
128
+ equal=cst.AssignEqual(
129
+ whitespace_before=cst.SimpleWhitespace(""),
130
+ whitespace_after=cst.SimpleWhitespace(""),
131
+ ),
132
+ )
133
+ ]
134
+
135
+ additional_args = {"description": description, "directives": directives}
136
+
137
+ union_call_args.extend(
138
+ cst.Arg(
139
+ value=arg.value,
140
+ keyword=cst.Name(name),
141
+ equal=cst.AssignEqual(
142
+ whitespace_before=cst.SimpleWhitespace(""),
143
+ whitespace_after=cst.SimpleWhitespace(""),
144
+ ),
145
+ )
146
+ for name, arg in additional_args.items()
147
+ if arg is not None
148
+ )
149
+
150
+ union_call_node = cst.Call(
151
+ func=cst.Attribute(
152
+ value=cst.Name(value="strawberry"),
153
+ attr=cst.Name(value="union"),
154
+ ),
155
+ args=union_call_args,
156
+ )
157
+
158
+ return cst.Subscript(
159
+ value=cst.Name(value="Annotated"),
160
+ slice=[
161
+ cst.SubscriptElement(
162
+ slice=cst.Index(
163
+ value=union_node,
164
+ ),
165
+ ),
166
+ cst.SubscriptElement(
167
+ slice=cst.Index(
168
+ value=union_call_node,
169
+ ),
170
+ ),
171
+ ],
172
+ )
173
+
174
+ @classmethod
175
+ def _create_union_node_with_pipe_syntax(
176
+ cls, types: Sequence[cst.BaseElement]
177
+ ) -> cst.BaseExpression:
178
+ type_names = [t.value for t in types]
179
+
180
+ if not all(isinstance(t, cst.Name) for t in type_names):
181
+ raise ValueError("Only names are supported for now")
182
+
183
+ expression = " | ".join(name.value for name in type_names) # type: ignore
184
+
185
+ return cst.parse_expression(expression)
strawberry/exceptions/invalid_union_type.py CHANGED
@@ -20,11 +20,18 @@ class InvalidUnionTypeError(StrawberryException):
20
20
 
21
21
  invalid_type: object
22
22
 
23
- def __init__(self, union_name: str, invalid_type: object) -> None:
23
+ def __init__(
24
+ self,
25
+ union_name: str,
26
+ invalid_type: object,
27
+ union_definition: Optional[StrawberryUnion] = None,
28
+ ) -> None:
24
29
  from strawberry.custom_scalar import ScalarWrapper
30
+ from strawberry.type import StrawberryList
25
31
 
26
32
  self.union_name = union_name
27
33
  self.invalid_type = invalid_type
34
+ self.union_definition = union_definition
28
35
 
29
36
  # assuming that the exception happens two stack frames above the current one.
30
37
  # one is our code checking for invalid types, the other is the caller
@@ -32,6 +39,8 @@ class InvalidUnionTypeError(StrawberryException):
32
39
 
33
40
  if isinstance(invalid_type, ScalarWrapper):
34
41
  type_name = invalid_type.wrap.__name__
42
+ elif isinstance(invalid_type, StrawberryList):
43
+ type_name = "list[...]"
35
44
  else:
36
45
  try:
37
46
  type_name = invalid_type.__name__ # type: ignore
@@ -50,10 +59,21 @@ class InvalidUnionTypeError(StrawberryException):
50
59
 
51
60
  @cached_property
52
61
  def exception_source(self) -> Optional[ExceptionSource]:
53
- path = Path(self.frame.filename)
54
-
55
62
  source_finder = SourceFinder()
56
63
 
64
+ if self.union_definition:
65
+ source = source_finder.find_annotated_union(
66
+ self.union_definition, self.invalid_type
67
+ )
68
+
69
+ if source:
70
+ return source
71
+
72
+ if not self.frame:
73
+ return None
74
+
75
+ path = Path(self.frame.filename)
76
+
57
77
  return source_finder.find_union_call(path, self.union_name, self.invalid_type)
58
78
 
59
79