strawberry-graphql 0.227.0.dev1713475585__py3-none-any.whl → 0.227.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46):
  1. strawberry/channels/handlers/base.py +7 -14
  2. strawberry/codegen/query_codegen.py +2 -4
  3. strawberry/custom_scalar.py +2 -4
  4. strawberry/dataloader.py +2 -4
  5. strawberry/directive.py +1 -2
  6. strawberry/django/views.py +1 -1
  7. strawberry/enum.py +2 -4
  8. strawberry/experimental/pydantic/conversion_types.py +5 -10
  9. strawberry/experimental/pydantic/error_type.py +1 -1
  10. strawberry/experimental/pydantic/object_type.py +1 -1
  11. strawberry/ext/mypy_plugin.py +24 -1
  12. strawberry/federation/enum.py +2 -4
  13. strawberry/federation/field.py +3 -6
  14. strawberry/federation/object_type.py +8 -25
  15. strawberry/federation/scalar.py +2 -4
  16. strawberry/federation/schema_directive.py +1 -1
  17. strawberry/field.py +3 -6
  18. strawberry/http/async_base_view.py +12 -24
  19. strawberry/http/base.py +3 -6
  20. strawberry/http/sync_base_view.py +13 -26
  21. strawberry/litestar/controller.py +2 -2
  22. strawberry/object_type.py +32 -23
  23. strawberry/parent.py +1 -2
  24. strawberry/printer/printer.py +3 -6
  25. strawberry/private.py +1 -2
  26. strawberry/relay/fields.py +2 -4
  27. strawberry/relay/types.py +12 -24
  28. strawberry/schema/execute.py +0 -5
  29. strawberry/schema/schema.py +0 -11
  30. strawberry/schema/schema_converter.py +1 -35
  31. strawberry/schema_codegen/__init__.py +96 -87
  32. strawberry/schema_directive.py +1 -1
  33. strawberry/starlite/controller.py +2 -2
  34. strawberry/type.py +4 -12
  35. strawberry/types/type_resolver.py +9 -2
  36. strawberry/types/types.py +1 -13
  37. strawberry/utils/aio.py +1 -1
  38. strawberry/utils/typing.py +2 -4
  39. {strawberry_graphql-0.227.0.dev1713475585.dist-info → strawberry_graphql-0.227.1.dist-info}/METADATA +2 -1
  40. {strawberry_graphql-0.227.0.dev1713475585.dist-info → strawberry_graphql-0.227.1.dist-info}/RECORD +43 -46
  41. strawberry/schema/validation_rules/__init__.py +0 -0
  42. strawberry/schema/validation_rules/one_of.py +0 -80
  43. strawberry/schema_directives.py +0 -9
  44. {strawberry_graphql-0.227.0.dev1713475585.dist-info → strawberry_graphql-0.227.1.dist-info}/LICENSE +0 -0
  45. {strawberry_graphql-0.227.0.dev1713475585.dist-info → strawberry_graphql-0.227.1.dist-info}/WHEEL +0 -0
  46. {strawberry_graphql-0.227.0.dev1713475585.dist-info → strawberry_graphql-0.227.1.dist-info}/entry_points.txt +0 -0
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, Tuple, Union
7
7
  from typing_extensions import Protocol, TypeAlias
8
8
 
9
9
  import libcst as cst
10
+ from graphlib import TopologicalSorter
10
11
  from graphql import (
11
12
  EnumTypeDefinitionNode,
12
13
  EnumValueDefinitionNode,
@@ -177,19 +178,6 @@ def _get_argument(name: str, value: ArgumentValue) -> cst.Arg:
177
178
  )
178
179
 
179
180
 
180
- # TODO: this might be removed now
181
- def _get_argument_list(name: str, values: list[ArgumentValue]) -> cst.Arg:
182
- value = cst.List(
183
- elements=[cst.Element(value=_sanitize_argument(value)) for value in values],
184
- )
185
-
186
- return cst.Arg(
187
- value=value,
188
- keyword=cst.Name(name),
189
- equal=cst.AssignEqual(cst.SimpleWhitespace(""), cst.SimpleWhitespace("")),
190
- )
191
-
192
-
193
181
  def _get_field_value(
194
182
  field: FieldDefinitionNode | InputValueDefinitionNode,
195
183
  alias: str | None,
@@ -441,11 +429,11 @@ def _get_class_definition(
441
429
  | InputObjectTypeDefinitionNode,
442
430
  is_apollo_federation: bool,
443
431
  imports: set[Import],
444
- ) -> cst.ClassDef:
432
+ ) -> Definition:
445
433
  decorator = _get_strawberry_decorator(definition, is_apollo_federation, imports)
446
434
 
447
- bases = (
448
- [cst.Arg(cst.Name(interface.name.value)) for interface in definition.interfaces]
435
+ interfaces = (
436
+ [interface.name.value for interface in definition.interfaces]
449
437
  if isinstance(
450
438
  definition, (ObjectTypeDefinitionNode, InterfaceTypeDefinitionNode)
451
439
  )
@@ -453,21 +441,24 @@ def _get_class_definition(
453
441
  else []
454
442
  )
455
443
 
456
- return cst.ClassDef(
444
+ class_definition = cst.ClassDef(
457
445
  name=cst.Name(definition.name.value),
458
- bases=bases,
459
446
  body=cst.IndentedBlock(
460
447
  body=[
461
448
  _get_field(field, is_apollo_federation, imports)
462
449
  for field in definition.fields
463
450
  ]
464
451
  ),
452
+ bases=[cst.Arg(cst.Name(interface)) for interface in interfaces],
465
453
  decorators=[decorator],
466
454
  )
467
455
 
456
+ return Definition(class_definition, interfaces, definition.name.value)
457
+
468
458
 
469
459
  def _get_enum_value(enum_value: EnumValueDefinitionNode) -> cst.SimpleStatementLine:
470
460
  name = enum_value.name.value
461
+
471
462
  return cst.SimpleStatementLine(
472
463
  body=[
473
464
  cst.Assign(
@@ -478,7 +469,7 @@ def _get_enum_value(enum_value: EnumValueDefinitionNode) -> cst.SimpleStatementL
478
469
  )
479
470
 
480
471
 
481
- def _get_enum_definition(definition: EnumTypeDefinitionNode) -> cst.ClassDef:
472
+ def _get_enum_definition(definition: EnumTypeDefinitionNode) -> Definition:
482
473
  decorator = cst.Decorator(
483
474
  decorator=cst.Attribute(
484
475
  value=cst.Name("strawberry"),
@@ -486,7 +477,7 @@ def _get_enum_definition(definition: EnumTypeDefinitionNode) -> cst.ClassDef:
486
477
  ),
487
478
  )
488
479
 
489
- return cst.ClassDef(
480
+ class_definition = cst.ClassDef(
490
481
  name=cst.Name(definition.name.value),
491
482
  bases=[cst.Arg(cst.Name("Enum"))],
492
483
  body=cst.IndentedBlock(
@@ -495,6 +486,12 @@ def _get_enum_definition(definition: EnumTypeDefinitionNode) -> cst.ClassDef:
495
486
  decorators=[decorator],
496
487
  )
497
488
 
489
+ return Definition(
490
+ class_definition,
491
+ [],
492
+ definition.name.value,
493
+ )
494
+
498
495
 
499
496
  def _get_schema_definition(
500
497
  root_query_name: str | None,
@@ -562,38 +559,54 @@ def _get_schema_definition(
562
559
  )
563
560
 
564
561
 
565
- def _get_union_definition(definition: UnionTypeDefinitionNode) -> cst.Assign:
562
+ @dataclasses.dataclass(frozen=True)
563
+ class Definition:
564
+ code: cst.CSTNode
565
+ dependencies: list[str]
566
+ name: str
567
+
568
+
569
+ def _get_union_definition(definition: UnionTypeDefinitionNode) -> Definition:
566
570
  name = definition.name.value
567
571
 
568
572
  types = cst.parse_expression(
569
573
  " | ".join([type_.name.value for type_ in definition.types])
570
574
  )
571
575
 
572
- return cst.Assign(
573
- targets=[cst.AssignTarget(cst.Name(name))],
574
- value=cst.Subscript(
575
- value=cst.Name("Annotated"),
576
- slice=[
577
- cst.SubscriptElement(slice=cst.Index(types)),
578
- cst.SubscriptElement(
579
- slice=cst.Index(
580
- cst.Call(
581
- cst.Attribute(
582
- value=cst.Name("strawberry"),
583
- attr=cst.Name("union"),
584
- ),
585
- args=[_get_argument("name", name)],
586
- )
587
- )
576
+ simple_statement = cst.SimpleStatementLine(
577
+ body=[
578
+ cst.Assign(
579
+ targets=[cst.AssignTarget(cst.Name(name))],
580
+ value=cst.Subscript(
581
+ value=cst.Name("Annotated"),
582
+ slice=[
583
+ cst.SubscriptElement(slice=cst.Index(types)),
584
+ cst.SubscriptElement(
585
+ slice=cst.Index(
586
+ cst.Call(
587
+ cst.Attribute(
588
+ value=cst.Name("strawberry"),
589
+ attr=cst.Name("union"),
590
+ ),
591
+ args=[_get_argument("name", name)],
592
+ )
593
+ )
594
+ ),
595
+ ],
588
596
  ),
589
- ],
590
- ),
597
+ )
598
+ ]
599
+ )
600
+ return Definition(
601
+ simple_statement,
602
+ [],
603
+ definition.name.value,
591
604
  )
592
605
 
593
606
 
594
607
  def _get_scalar_definition(
595
608
  definition: ScalarTypeDefinitionNode, imports: set[Import]
596
- ) -> cst.SimpleStatementLine | None:
609
+ ) -> Definition | None:
597
610
  name = definition.name.value
598
611
 
599
612
  if name == "Date":
@@ -652,7 +665,7 @@ def _get_scalar_definition(
652
665
  ),
653
666
  ]
654
667
 
655
- return cst.SimpleStatementLine(
668
+ statement_definition = cst.SimpleStatementLine(
656
669
  body=[
657
670
  cst.Assign(
658
671
  targets=[cst.AssignTarget(cst.Name(name))],
@@ -677,12 +690,13 @@ def _get_scalar_definition(
677
690
  )
678
691
  ]
679
692
  )
693
+ return Definition(statement_definition, [], name=definition.name.value)
680
694
 
681
695
 
682
696
  def codegen(schema: str) -> str:
683
697
  document = parse(schema)
684
698
 
685
- definitions: list[cst.CSTNode] = []
699
+ definitions: dict[str, Definition] = {}
686
700
 
687
701
  root_query_name: str | None = None
688
702
  root_mutation_name: str | None = None
@@ -692,17 +706,17 @@ def codegen(schema: str) -> str:
692
706
  Import(module=None, imports=("strawberry",)),
693
707
  }
694
708
 
695
- object_types: dict[str, cst.ClassDef] = {}
696
-
697
709
  # when we encounter a extend schema @link ..., we check if is an apollo federation schema
698
710
  # and we use this variable to keep track of it, but at the moment the assumption is that
699
711
  # the schema extension is always done at the top, this might not be the case all the
700
712
  # time
701
713
  is_apollo_federation = False
702
714
 
703
- for definition in document.definitions:
715
+ for graphql_definition in document.definitions:
716
+ definition: Definition | None = None
717
+
704
718
  if isinstance(
705
- definition,
719
+ graphql_definition,
706
720
  (
707
721
  ObjectTypeDefinitionNode,
708
722
  InterfaceTypeDefinitionNode,
@@ -710,23 +724,17 @@ def codegen(schema: str) -> str:
710
724
  ObjectTypeExtensionNode,
711
725
  ),
712
726
  ):
713
- class_definition = _get_class_definition(
714
- definition, is_apollo_federation, imports
727
+ definition = _get_class_definition(
728
+ graphql_definition, is_apollo_federation, imports
715
729
  )
716
730
 
717
- object_types[definition.name.value] = class_definition
718
-
719
- definitions.append(cst.EmptyLine())
720
- definitions.append(class_definition)
721
-
722
- elif isinstance(definition, EnumTypeDefinitionNode):
731
+ elif isinstance(graphql_definition, EnumTypeDefinitionNode):
723
732
  imports.add(Import(module="enum", imports=("Enum",)))
724
733
 
725
- definitions.append(cst.EmptyLine())
726
- definitions.append(_get_enum_definition(definition))
734
+ definition = _get_enum_definition(graphql_definition)
727
735
 
728
- elif isinstance(definition, SchemaDefinitionNode):
729
- for operation_type_definition in definition.operation_types:
736
+ elif isinstance(graphql_definition, SchemaDefinitionNode):
737
+ for operation_type_definition in graphql_definition.operation_types:
730
738
  if operation_type_definition.operation == OperationType.QUERY:
731
739
  root_query_name = operation_type_definition.type.name.value
732
740
  elif operation_type_definition.operation == OperationType.MUTATION:
@@ -737,36 +745,33 @@ def codegen(schema: str) -> str:
737
745
  raise NotImplementedError(
738
746
  f"Unknown operation {operation_type_definition.operation}"
739
747
  )
740
- elif isinstance(definition, UnionTypeDefinitionNode):
748
+ elif isinstance(graphql_definition, UnionTypeDefinitionNode):
741
749
  imports.add(Import(module="typing", imports=("Annotated",)))
742
750
 
743
- definitions.append(cst.EmptyLine())
744
- definitions.append(_get_union_definition(definition))
745
- definitions.append(cst.EmptyLine())
746
- elif isinstance(definition, ScalarTypeDefinitionNode):
747
- scalar_definition = _get_scalar_definition(definition, imports)
748
-
749
- if scalar_definition is not None:
750
- definitions.append(cst.EmptyLine())
751
- definitions.append(scalar_definition)
752
- definitions.append(cst.EmptyLine())
753
- elif isinstance(definition, SchemaExtensionNode):
751
+ definition = _get_union_definition(graphql_definition)
752
+ elif isinstance(graphql_definition, ScalarTypeDefinitionNode):
753
+ definition = _get_scalar_definition(graphql_definition, imports)
754
+
755
+ elif isinstance(graphql_definition, SchemaExtensionNode):
754
756
  is_apollo_federation = any(
755
757
  _is_federation_link_directive(directive)
756
- for directive in definition.directives
758
+ for directive in graphql_definition.directives
757
759
  )
758
760
  else:
759
761
  raise NotImplementedError(f"Unknown definition {definition}")
760
762
 
763
+ if definition is not None:
764
+ definitions[definition.name] = definition
765
+
761
766
  if root_query_name is None:
762
- root_query_name = "Query" if "Query" in object_types else None
767
+ root_query_name = "Query" if "Query" in definitions else None
763
768
 
764
769
  if root_mutation_name is None:
765
- root_mutation_name = "Mutation" if "Mutation" in object_types else None
770
+ root_mutation_name = "Mutation" if "Mutation" in definitions else None
766
771
 
767
772
  if root_subscription_name is None:
768
773
  root_subscription_name = (
769
- "Subscription" if "Subscription" in object_types else None
774
+ "Subscription" if "Subscription" in definitions else None
770
775
  )
771
776
 
772
777
  schema_definition = _get_schema_definition(
@@ -777,19 +782,23 @@ def codegen(schema: str) -> str:
777
782
  )
778
783
 
779
784
  if schema_definition:
780
- definitions.append(cst.EmptyLine())
781
- definitions.append(schema_definition)
785
+ definitions["Schema"] = Definition(schema_definition, [], "schema")
782
786
 
783
- module = cst.Module(
784
- body=[
785
- *[
786
- cst.SimpleStatementLine(body=[import_.to_cst()])
787
- for import_ in sorted(
788
- imports, key=lambda i: (i.module or "", i.imports)
789
- )
790
- ],
791
- *definitions, # type: ignore
792
- ]
793
- )
787
+ body: list[cst.CSTNode] = [
788
+ cst.SimpleStatementLine(body=[import_.to_cst()])
789
+ for import_ in sorted(imports, key=lambda i: (i.module or "", i.imports))
790
+ ]
791
+
792
+ # DAG to sort definitions based on dependencies
793
+ graph = {name: definition.dependencies for name, definition in definitions.items()}
794
+ ts = TopologicalSorter(graph)
795
+
796
+ for definition_name in tuple(ts.static_order()):
797
+ definition = definitions[definition_name]
798
+
799
+ body.append(cst.EmptyLine())
800
+ body.append(definition.code)
801
+
802
+ module = cst.Module(body=body) # type: ignore
794
803
 
795
804
  return module.code
@@ -54,7 +54,7 @@ def schema_directive(
54
54
  ) -> Callable[..., T]:
55
55
  def _wrap(cls: T) -> T:
56
56
  cls = _wrap_dataclass(cls)
57
- fields = _get_fields(cls)
57
+ fields = _get_fields(cls, {})
58
58
 
59
59
  cls.__strawberry_directive__ = StrawberrySchemaDirective(
60
60
  python_name=cls.__name__,
@@ -222,9 +222,9 @@ def make_graphql_controller(
222
222
  "response": Provide(response_getter),
223
223
  }
224
224
  graphql_ws_handler_class: Type[GraphQLWSHandler] = GraphQLWSHandler
225
- graphql_transport_ws_handler_class: Type[
225
+ graphql_transport_ws_handler_class: Type[GraphQLTransportWSHandler] = (
226
226
  GraphQLTransportWSHandler
227
- ] = GraphQLTransportWSHandler
227
+ )
228
228
 
229
229
  _keep_alive: bool = keep_alive
230
230
  _keep_alive_interval: float = keep_alive_interval
strawberry/type.py CHANGED
@@ -37,10 +37,6 @@ class StrawberryType(ABC):
37
37
  def type_params(self) -> List[TypeVar]:
38
38
  return []
39
39
 
40
- @property
41
- def is_one_of(self) -> bool:
42
- return False
43
-
44
40
  @abstractmethod
45
41
  def copy_with(
46
42
  self,
@@ -147,12 +143,10 @@ class StrawberryContainer(StrawberryType):
147
143
  return False
148
144
 
149
145
 
150
- class StrawberryList(StrawberryContainer):
151
- ...
146
+ class StrawberryList(StrawberryContainer): ...
152
147
 
153
148
 
154
- class StrawberryOptional(StrawberryContainer):
155
- ...
149
+ class StrawberryOptional(StrawberryContainer): ...
156
150
 
157
151
 
158
152
  class StrawberryTypeVar(StrawberryType):
@@ -212,8 +206,7 @@ def get_object_definition(
212
206
  obj: Any,
213
207
  *,
214
208
  strict: Literal[True],
215
- ) -> StrawberryObjectDefinition:
216
- ...
209
+ ) -> StrawberryObjectDefinition: ...
217
210
 
218
211
 
219
212
  @overload
@@ -221,8 +214,7 @@ def get_object_definition(
221
214
  obj: Any,
222
215
  *,
223
216
  strict: bool = False,
224
- ) -> Optional[StrawberryObjectDefinition]:
225
- ...
217
+ ) -> Optional[StrawberryObjectDefinition]: ...
226
218
 
227
219
 
228
220
  def get_object_definition(
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import dataclasses
4
4
  import sys
5
- from typing import Dict, List, Type
5
+ from typing import Any, Dict, List, Type
6
6
 
7
7
  from strawberry.annotation import StrawberryAnnotation
8
8
  from strawberry.exceptions import (
@@ -16,7 +16,9 @@ from strawberry.type import has_object_definition
16
16
  from strawberry.unset import UNSET
17
17
 
18
18
 
19
- def _get_fields(cls: Type) -> List[StrawberryField]:
19
+ def _get_fields(
20
+ cls: Type[Any], original_type_annotations: Dict[str, Type[Any]]
21
+ ) -> List[StrawberryField]:
20
22
  """Get all the strawberry fields off a strawberry.type cls
21
23
 
22
24
  This function returns a list of StrawberryFields (one for each field item), while
@@ -49,6 +51,7 @@ def _get_fields(cls: Type) -> List[StrawberryField]:
49
51
  passing a named function (i.e. not an anonymous lambda) to strawberry.field
50
52
  (typically as a decorator).
51
53
  """
54
+
52
55
  fields: Dict[str, StrawberryField] = {}
53
56
 
54
57
  # before trying to find any fields, let's first add the fields defined in
@@ -152,6 +155,10 @@ def _get_fields(cls: Type) -> List[StrawberryField]:
152
155
  assert_message = "Field must have a name by the time the schema is generated"
153
156
  assert field_name is not None, assert_message
154
157
 
158
+ if field.name in original_type_annotations:
159
+ field.type = original_type_annotations[field.name]
160
+ field.type_annotation = StrawberryAnnotation(annotation=field.type)
161
+
155
162
  # TODO: Raise exception if field_name already in fields
156
163
  fields[field_name] = field
157
164
 
strawberry/types/types.py CHANGED
@@ -197,24 +197,12 @@ class StrawberryObjectDefinition(StrawberryType):
197
197
  # All field mappings succeeded. This is a match
198
198
  return True
199
199
 
200
- @property
201
- def is_one_of(self) -> bool:
202
- from strawberry.schema_directives import OneOf
203
-
204
- if not self.is_input or not self.directives:
205
- return False
206
-
207
- return any(
208
- directive for directive in self.directives if isinstance(directive, OneOf)
209
- )
210
-
211
200
 
212
201
  # TODO: remove when deprecating _type_definition
213
202
  if TYPE_CHECKING:
214
203
 
215
204
  @deprecated("Use StrawberryObjectDefinition instead")
216
- class TypeDefinition(StrawberryObjectDefinition):
217
- ...
205
+ class TypeDefinition(StrawberryObjectDefinition): ...
218
206
 
219
207
  else:
220
208
  TypeDefinition = StrawberryObjectDefinition
strawberry/utils/aio.py CHANGED
@@ -24,7 +24,7 @@ async def aenumerate(
24
24
  i = 0
25
25
  async for element in iterable:
26
26
  yield i, element
27
- i += 1 # noqa: SIM113
27
+ i += 1
28
28
 
29
29
 
30
30
  async def aislice(
@@ -204,13 +204,11 @@ def get_parameters(annotation: Type) -> Union[Tuple[object], Tuple[()]]:
204
204
 
205
205
 
206
206
  @overload
207
- def _ast_replace_union_operation(expr: ast.expr) -> ast.expr:
208
- ...
207
+ def _ast_replace_union_operation(expr: ast.expr) -> ast.expr: ...
209
208
 
210
209
 
211
210
  @overload
212
- def _ast_replace_union_operation(expr: ast.Expr) -> ast.Expr:
213
- ...
211
+ def _ast_replace_union_operation(expr: ast.Expr) -> ast.Expr: ...
214
212
 
215
213
 
216
214
  def _ast_replace_union_operation(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: strawberry-graphql
3
- Version: 0.227.0.dev1713475585
3
+ Version: 0.227.1
4
4
  Summary: A library for creating GraphQL APIs
5
5
  Home-page: https://strawberry.rocks/
6
6
  License: MIT
@@ -44,6 +44,7 @@ Requires-Dist: chalice (>=1.22,<2.0) ; extra == "chalice"
44
44
  Requires-Dist: channels (>=3.0.5) ; extra == "channels"
45
45
  Requires-Dist: fastapi (>=0.65.2) ; extra == "fastapi"
46
46
  Requires-Dist: flask (>=1.1) ; extra == "flask"
47
+ Requires-Dist: graphlib_backport ; (python_version < "3.9") and (extra == "cli")
47
48
  Requires-Dist: graphql-core (>=3.2.0,<3.3.0)
48
49
  Requires-Dist: libcst (>=0.4.7) ; extra == "debug" or extra == "debug-server" or extra == "cli"
49
50
  Requires-Dist: litestar (>=2) ; (python_version >= "3.8") and (extra == "litestar")