strawberry-graphql 0.220.0.dev1709543239__py3-none-any.whl → 0.221.0.dev1710955937__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -69,6 +69,7 @@ class GraphQLWebsocketCommunicator(WebsocketCommunicator):
69
69
  path: str,
70
70
  headers: Optional[List[Tuple[bytes, bytes]]] = None,
71
71
  protocol: str = GRAPHQL_TRANSPORT_WS_PROTOCOL,
72
+ connection_params: dict = {},
72
73
  **kwargs: Any,
73
74
  ):
74
75
  """
@@ -81,6 +82,7 @@ class GraphQLWebsocketCommunicator(WebsocketCommunicator):
81
82
  self.protocol = protocol
82
83
  subprotocols = kwargs.get("subprotocols", [])
83
84
  subprotocols.append(protocol)
85
+ self.connection_params = connection_params
84
86
  super().__init__(application, path, headers, subprotocols=subprotocols)
85
87
 
86
88
  async def __aenter__(self) -> Self:
@@ -99,7 +101,9 @@ class GraphQLWebsocketCommunicator(WebsocketCommunicator):
99
101
  res = await self.connect()
100
102
  if self.protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
101
103
  assert res == (True, GRAPHQL_TRANSPORT_WS_PROTOCOL)
102
- await self.send_json_to(ConnectionInitMessage().as_dict())
104
+ await self.send_json_to(
105
+ ConnectionInitMessage(payload=self.connection_params).as_dict()
106
+ )
103
107
  response = await self.receive_json_from()
104
108
  assert response == ConnectionAckMessage().as_dict()
105
109
  else:
@@ -88,82 +88,77 @@ async def execute(
88
88
  extensions=list(extensions),
89
89
  )
90
90
 
91
- try:
92
- async with extensions_runner.operation():
93
- # Note: In graphql-core the schema would be validated here but in
94
- # Strawberry we are validating it at initialisation time instead
95
- if not execution_context.query:
96
- raise MissingQueryError()
97
-
98
- async with extensions_runner.parsing():
99
- try:
100
- if not execution_context.graphql_document:
101
- execution_context.graphql_document = parse_document(
102
- execution_context.query, **execution_context.parse_options
103
- )
104
-
105
- except GraphQLError as exc:
106
- execution_context.errors = [exc]
107
- process_errors([exc], execution_context)
108
- return ExecutionResult(
109
- data=None,
110
- errors=[exc],
111
- extensions=await extensions_runner.get_extensions_results(),
91
+ async with extensions_runner.operation():
92
+ # Note: In graphql-core the schema would be validated here but in
93
+ # Strawberry we are validating it at initialisation time instead
94
+ if not execution_context.query:
95
+ raise MissingQueryError()
96
+
97
+ async with extensions_runner.parsing():
98
+ try:
99
+ if not execution_context.graphql_document:
100
+ execution_context.graphql_document = parse_document(
101
+ execution_context.query, **execution_context.parse_options
112
102
  )
113
103
 
114
- if execution_context.operation_type not in allowed_operation_types:
115
- raise InvalidOperationTypeError(execution_context.operation_type)
116
-
117
- async with extensions_runner.validation():
118
- _run_validation(execution_context)
119
- if execution_context.errors:
120
- process_errors(execution_context.errors, execution_context)
121
- return ExecutionResult(data=None, errors=execution_context.errors)
122
-
123
- async with extensions_runner.executing():
124
- if not execution_context.result:
125
- result = original_execute(
126
- schema,
127
- execution_context.graphql_document,
128
- root_value=execution_context.root_value,
129
- middleware=extensions_runner.as_middleware_manager(),
130
- variable_values=execution_context.variables,
131
- operation_name=execution_context.operation_name,
132
- context_value=execution_context.context,
133
- execution_context_class=execution_context_class,
134
- )
135
-
136
- if isawaitable(result):
137
- result = await cast(Awaitable["GraphQLExecutionResult"], result)
138
-
139
- result = cast("GraphQLExecutionResult", result)
140
- execution_context.result = result
141
- # Also set errors on the execution_context so that it's easier
142
- # to access in extensions
143
- if result.errors:
144
- execution_context.errors = result.errors
145
-
146
- # Run the `Schema.process_errors` function here before
147
- # extensions have a chance to modify them (see the MaskErrors
148
- # extension). That way we can log the original errors but
149
- # only return a sanitised version to the client.
150
- process_errors(result.errors, execution_context)
151
-
152
- except (MissingQueryError, InvalidOperationTypeError) as e:
153
- raise e
154
- except Exception as exc:
155
- error = (
156
- exc
157
- if isinstance(exc, GraphQLError)
158
- else GraphQLError(str(exc), original_error=exc)
159
- )
160
- execution_context.errors = [error]
161
- process_errors([error], execution_context)
162
- return ExecutionResult(
163
- data=None,
164
- errors=[error],
165
- extensions=await extensions_runner.get_extensions_results(),
166
- )
104
+ except GraphQLError as error:
105
+ execution_context.errors = [error]
106
+ process_errors([error], execution_context)
107
+ return ExecutionResult(
108
+ data=None,
109
+ errors=[error],
110
+ extensions=await extensions_runner.get_extensions_results(),
111
+ )
112
+
113
+ except Exception as error: # pragma: no cover
114
+ error = GraphQLError(str(error), original_error=error)
115
+
116
+ execution_context.errors = [error]
117
+ process_errors([error], execution_context)
118
+
119
+ return ExecutionResult(
120
+ data=None,
121
+ errors=[error],
122
+ extensions=await extensions_runner.get_extensions_results(),
123
+ )
124
+
125
+ if execution_context.operation_type not in allowed_operation_types:
126
+ raise InvalidOperationTypeError(execution_context.operation_type)
127
+
128
+ async with extensions_runner.validation():
129
+ _run_validation(execution_context)
130
+ if execution_context.errors:
131
+ process_errors(execution_context.errors, execution_context)
132
+ return ExecutionResult(data=None, errors=execution_context.errors)
133
+
134
+ async with extensions_runner.executing():
135
+ if not execution_context.result:
136
+ result = original_execute(
137
+ schema,
138
+ execution_context.graphql_document,
139
+ root_value=execution_context.root_value,
140
+ middleware=extensions_runner.as_middleware_manager(),
141
+ variable_values=execution_context.variables,
142
+ operation_name=execution_context.operation_name,
143
+ context_value=execution_context.context,
144
+ execution_context_class=execution_context_class,
145
+ )
146
+
147
+ if isawaitable(result):
148
+ result = await cast(Awaitable["GraphQLExecutionResult"], result)
149
+
150
+ result = cast("GraphQLExecutionResult", result)
151
+ execution_context.result = result
152
+ # Also set errors on the execution_context so that it's easier
153
+ # to access in extensions
154
+ if result.errors:
155
+ execution_context.errors = result.errors
156
+
157
+ # Run the `Schema.process_errors` function here before
158
+ # extensions have a chance to modify them (see the MaskErrors
159
+ # extension). That way we can log the original errors but
160
+ # only return a sanitised version to the client.
161
+ process_errors(result.errors, execution_context)
167
162
 
168
163
  return ExecutionResult(
169
164
  data=execution_context.result.data,
@@ -186,86 +181,80 @@ def execute_sync(
186
181
  extensions=list(extensions),
187
182
  )
188
183
 
189
- try:
190
- with extensions_runner.operation():
191
- # Note: In graphql-core the schema would be validated here but in
192
- # Strawberry we are validating it at initialisation time instead
193
- if not execution_context.query:
194
- raise MissingQueryError()
195
-
196
- with extensions_runner.parsing():
197
- try:
198
- if not execution_context.graphql_document:
199
- execution_context.graphql_document = parse_document(
200
- execution_context.query, **execution_context.parse_options
201
- )
202
-
203
- except GraphQLError as exc:
204
- execution_context.errors = [exc]
205
- process_errors([exc], execution_context)
206
- return ExecutionResult(
207
- data=None,
208
- errors=[exc],
209
- extensions=extensions_runner.get_extensions_results_sync(),
184
+ with extensions_runner.operation():
185
+ # Note: In graphql-core the schema would be validated here but in
186
+ # Strawberry we are validating it at initialisation time instead
187
+ if not execution_context.query:
188
+ raise MissingQueryError()
189
+
190
+ with extensions_runner.parsing():
191
+ try:
192
+ if not execution_context.graphql_document:
193
+ execution_context.graphql_document = parse_document(
194
+ execution_context.query, **execution_context.parse_options
210
195
  )
211
196
 
212
- if execution_context.operation_type not in allowed_operation_types:
213
- raise InvalidOperationTypeError(execution_context.operation_type)
214
-
215
- with extensions_runner.validation():
216
- _run_validation(execution_context)
217
- if execution_context.errors:
218
- process_errors(execution_context.errors, execution_context)
219
- return ExecutionResult(data=None, errors=execution_context.errors)
220
-
221
- with extensions_runner.executing():
222
- if not execution_context.result:
223
- result = original_execute(
224
- schema,
225
- execution_context.graphql_document,
226
- root_value=execution_context.root_value,
227
- middleware=extensions_runner.as_middleware_manager(),
228
- variable_values=execution_context.variables,
229
- operation_name=execution_context.operation_name,
230
- context_value=execution_context.context,
231
- execution_context_class=execution_context_class,
197
+ except GraphQLError as error:
198
+ execution_context.errors = [error]
199
+ process_errors([error], execution_context)
200
+ return ExecutionResult(
201
+ data=None,
202
+ errors=[error],
203
+ extensions=extensions_runner.get_extensions_results_sync(),
204
+ )
205
+
206
+ except Exception as error: # pragma: no cover
207
+ error = GraphQLError(str(error), original_error=error)
208
+
209
+ execution_context.errors = [error]
210
+ process_errors([error], execution_context)
211
+ return ExecutionResult(
212
+ data=None,
213
+ errors=[error],
214
+ extensions=extensions_runner.get_extensions_results_sync(),
215
+ )
216
+
217
+ if execution_context.operation_type not in allowed_operation_types:
218
+ raise InvalidOperationTypeError(execution_context.operation_type)
219
+
220
+ with extensions_runner.validation():
221
+ _run_validation(execution_context)
222
+ if execution_context.errors:
223
+ process_errors(execution_context.errors, execution_context)
224
+ return ExecutionResult(data=None, errors=execution_context.errors)
225
+
226
+ with extensions_runner.executing():
227
+ if not execution_context.result:
228
+ result = original_execute(
229
+ schema,
230
+ execution_context.graphql_document,
231
+ root_value=execution_context.root_value,
232
+ middleware=extensions_runner.as_middleware_manager(),
233
+ variable_values=execution_context.variables,
234
+ operation_name=execution_context.operation_name,
235
+ context_value=execution_context.context,
236
+ execution_context_class=execution_context_class,
237
+ )
238
+
239
+ if isawaitable(result):
240
+ result = cast(Awaitable["GraphQLExecutionResult"], result)
241
+ ensure_future(result).cancel()
242
+ raise RuntimeError(
243
+ "GraphQL execution failed to complete synchronously."
232
244
  )
233
245
 
234
- if isawaitable(result):
235
- result = cast(Awaitable["GraphQLExecutionResult"], result)
236
- ensure_future(result).cancel()
237
- raise RuntimeError(
238
- "GraphQL execution failed to complete synchronously."
239
- )
240
-
241
- result = cast("GraphQLExecutionResult", result)
242
- execution_context.result = result
243
- # Also set errors on the execution_context so that it's easier
244
- # to access in extensions
245
- if result.errors:
246
- execution_context.errors = result.errors
247
-
248
- # Run the `Schema.process_errors` function here before
249
- # extensions have a chance to modify them (see the MaskErrors
250
- # extension). That way we can log the original errors but
251
- # only return a sanitised version to the client.
252
- process_errors(result.errors, execution_context)
253
-
254
- except (MissingQueryError, InvalidOperationTypeError) as e:
255
- raise e
256
- except Exception as exc:
257
- error = (
258
- exc
259
- if isinstance(exc, GraphQLError)
260
- else GraphQLError(str(exc), original_error=exc)
261
- )
262
- execution_context.errors = [error]
263
- process_errors([error], execution_context)
264
- return ExecutionResult(
265
- data=None,
266
- errors=[error],
267
- extensions=extensions_runner.get_extensions_results_sync(),
268
- )
246
+ result = cast("GraphQLExecutionResult", result)
247
+ execution_context.result = result
248
+ # Also set errors on the execution_context so that it's easier
249
+ # to access in extensions
250
+ if result.errors:
251
+ execution_context.errors = result.errors
252
+
253
+ # Run the `Schema.process_errors` function here before
254
+ # extensions have a chance to modify them (see the MaskErrors
255
+ # extension). That way we can log the original errors but
256
+ # only return a sanitised version to the client.
257
+ process_errors(result.errors, execution_context)
269
258
 
270
259
  return ExecutionResult(
271
260
  data=execution_context.result.data,
@@ -2,6 +2,9 @@ from __future__ import annotations
2
2
 
3
3
  import dataclasses
4
4
  import keyword
5
+ from collections import defaultdict
6
+ from typing import TYPE_CHECKING, Tuple
7
+ from typing_extensions import Protocol
5
8
 
6
9
  import libcst as cst
7
10
  from graphql import (
@@ -19,6 +22,7 @@ from graphql import (
19
22
  OperationType,
20
23
  ScalarTypeDefinitionNode,
21
24
  SchemaDefinitionNode,
25
+ SchemaExtensionNode,
22
26
  StringValueNode,
23
27
  TypeNode,
24
28
  UnionTypeDefinitionNode,
@@ -27,6 +31,14 @@ from graphql import (
27
31
 
28
32
  from strawberry.utils.str_converters import to_snake_case
29
33
 
34
+ if TYPE_CHECKING:
35
+ from graphql.language.ast import ConstDirectiveNode
36
+
37
+
38
+ class HasDirectives(Protocol):
39
+ directives: Tuple[ConstDirectiveNode]
40
+
41
+
30
42
  _SCALAR_MAP = {
31
43
  "Int": cst.Name("int"),
32
44
  "Float": cst.Name("float"),
@@ -48,6 +60,19 @@ _SCALAR_MAP = {
48
60
  }
49
61
 
50
62
 
63
+ def _is_federation_link_directive(directive: ConstDirectiveNode) -> bool:
64
+ if directive.name.value != "link":
65
+ return False
66
+
67
+ for argument in directive.arguments:
68
+ if argument.name.value == "url":
69
+ return argument.value.value.startswith(
70
+ "https://specs.apollo.dev/federation"
71
+ )
72
+
73
+ return False
74
+
75
+
51
76
  def _get_field_type(
52
77
  field_type: TypeNode, was_non_nullable: bool = False
53
78
  ) -> cst.BaseExpression:
@@ -85,7 +110,10 @@ def _get_field_type(
85
110
  )
86
111
 
87
112
 
88
- def _get_argument(name: str, value: str) -> cst.Arg:
113
+ def _sanitize_argument(value: str | bool) -> cst.SimpleString | cst.Name:
114
+ if isinstance(value, bool):
115
+ return cst.Name(value=str(value))
116
+
89
117
  if "\n" in value:
90
118
  argument_value = cst.SimpleString(f'"""\n{value}\n"""')
91
119
  elif '"' in value:
@@ -93,6 +121,12 @@ def _get_argument(name: str, value: str) -> cst.Arg:
93
121
  else:
94
122
  argument_value = cst.SimpleString(f'"{value}"')
95
123
 
124
+ return argument_value
125
+
126
+
127
+ def _get_argument(name: str, value: str | bool) -> cst.Arg:
128
+ argument_value = _sanitize_argument(value)
129
+
96
130
  return cst.Arg(
97
131
  value=argument_value,
98
132
  keyword=cst.Name(name),
@@ -100,7 +134,25 @@ def _get_argument(name: str, value: str) -> cst.Arg:
100
134
  )
101
135
 
102
136
 
103
- def _get_field_value(description: str | None, alias: str | None) -> cst.Call | None:
137
+ def _get_argument_list(name: str, values: list[str]) -> cst.Arg:
138
+ value = cst.List(
139
+ elements=[cst.Element(value=_sanitize_argument(value)) for value in values],
140
+ )
141
+
142
+ return cst.Arg(
143
+ value=value,
144
+ keyword=cst.Name(name),
145
+ equal=cst.AssignEqual(cst.SimpleWhitespace(""), cst.SimpleWhitespace("")),
146
+ )
147
+
148
+
149
+ def _get_field_value(
150
+ field: FieldDefinitionNode | InputValueDefinitionNode,
151
+ alias: str | None,
152
+ is_apollo_federation: bool,
153
+ ) -> cst.Call | None:
154
+ description = field.description.value if field.description else None
155
+
104
156
  args = list(
105
157
  filter(
106
158
  None,
@@ -111,6 +163,24 @@ def _get_field_value(description: str | None, alias: str | None) -> cst.Call | N
111
163
  )
112
164
  )
113
165
 
166
+ directives = _get_directives(field)
167
+
168
+ apollo_federation_args = _get_federation_arguments(directives)
169
+
170
+ if is_apollo_federation and apollo_federation_args:
171
+ args.extend(apollo_federation_args)
172
+
173
+ return cst.Call(
174
+ func=cst.Attribute(
175
+ value=cst.Attribute(
176
+ value=cst.Name("strawberry"),
177
+ attr=cst.Name("federation"),
178
+ ),
179
+ attr=cst.Name("field"),
180
+ ),
181
+ args=args,
182
+ )
183
+
114
184
  if args:
115
185
  return cst.Call(
116
186
  func=cst.Attribute(
@@ -125,6 +195,7 @@ def _get_field_value(description: str | None, alias: str | None) -> cst.Call | N
125
195
 
126
196
  def _get_field(
127
197
  field: FieldDefinitionNode | InputValueDefinitionNode,
198
+ is_apollo_federation: bool,
128
199
  ) -> cst.SimpleStatementLine:
129
200
  name = to_snake_case(field.name.value)
130
201
  alias: str | None = None
@@ -141,19 +212,67 @@ def _get_field(
141
212
  _get_field_type(field.type),
142
213
  ),
143
214
  value=_get_field_value(
144
- description=field.description.value if field.description else None,
145
- alias=alias if alias != name else None,
215
+ field, alias=alias, is_apollo_federation=is_apollo_federation
146
216
  ),
147
217
  )
148
218
  ]
149
219
  )
150
220
 
151
221
 
222
+ def _get_directives(definition: HasDirectives) -> dict[str, list[dict[str, str]]]:
223
+ directives = defaultdict(list)
224
+
225
+ for directive in definition.directives:
226
+ directive_name = directive.name.value
227
+
228
+ directives[directive_name].append(
229
+ {
230
+ argument.name.value: argument.value.value
231
+ for argument in directive.arguments
232
+ }
233
+ )
234
+
235
+ return directives
236
+
237
+
238
+ def _get_federation_arguments(
239
+ directives: dict[str, list[dict[str, str]]],
240
+ ) -> list[cst.Arg]:
241
+ def append_arg_from_directive(
242
+ directive: str, argument_name: str, keyword_name: str | None = None
243
+ ):
244
+ keyword_name = keyword_name or directive
245
+
246
+ if directive in directives:
247
+ arguments.append(
248
+ _get_argument_list(
249
+ keyword_name,
250
+ [item[argument_name] for item in directives[directive]],
251
+ )
252
+ )
253
+
254
+ arguments: list[cst.Arg] = []
255
+
256
+ append_arg_from_directive("key", "fields", "keys")
257
+ append_arg_from_directive("requires", "fields")
258
+ append_arg_from_directive("provides", "fields")
259
+ append_arg_from_directive("tag", "name", "tags")
260
+
261
+ boolean_keys = ("shareable", "inaccessible", "external", "override")
262
+
263
+ arguments.extend(
264
+ _get_argument(key, True) for key in boolean_keys if directives.get(key, False)
265
+ )
266
+
267
+ return arguments
268
+
269
+
152
270
  def _get_strawberry_decorator(
153
271
  definition: ObjectTypeDefinitionNode
154
272
  | ObjectTypeExtensionNode
155
273
  | InterfaceTypeDefinitionNode
156
274
  | InputObjectTypeDefinitionNode,
275
+ is_apollo_federation: bool,
157
276
  ) -> cst.Decorator:
158
277
  type_ = {
159
278
  ObjectTypeDefinitionNode: "type",
@@ -168,15 +287,36 @@ def _get_strawberry_decorator(
168
287
  else None
169
288
  )
170
289
 
290
+ directives = _get_directives(definition)
291
+
171
292
  decorator: cst.BaseExpression = cst.Attribute(
172
293
  value=cst.Name("strawberry"),
173
294
  attr=cst.Name(type_),
174
295
  )
175
296
 
297
+ arguments: list[cst.Arg] = []
298
+
176
299
  if description is not None:
300
+ arguments.append(_get_argument("description", description.value))
301
+
302
+ federation_arguments = _get_federation_arguments(directives)
303
+
304
+ # and has any directive that is a federation directive
305
+ if is_apollo_federation and federation_arguments:
306
+ decorator = cst.Attribute(
307
+ value=cst.Attribute(
308
+ value=cst.Name("strawberry"),
309
+ attr=cst.Name("federation"),
310
+ ),
311
+ attr=cst.Name(type_),
312
+ )
313
+
314
+ arguments.extend(federation_arguments)
315
+
316
+ if arguments:
177
317
  decorator = cst.Call(
178
318
  func=decorator,
179
- args=[_get_argument("description", description.value)],
319
+ args=arguments,
180
320
  )
181
321
 
182
322
  return cst.Decorator(
@@ -189,8 +329,9 @@ def _get_class_definition(
189
329
  | ObjectTypeExtensionNode
190
330
  | InterfaceTypeDefinitionNode
191
331
  | InputObjectTypeDefinitionNode,
332
+ is_apollo_federation: bool,
192
333
  ) -> cst.ClassDef:
193
- decorator = _get_strawberry_decorator(definition)
334
+ decorator = _get_strawberry_decorator(definition, is_apollo_federation)
194
335
 
195
336
  bases = (
196
337
  [cst.Arg(cst.Name(interface.name.value)) for interface in definition.interfaces]
@@ -204,7 +345,11 @@ def _get_class_definition(
204
345
  return cst.ClassDef(
205
346
  name=cst.Name(definition.name.value),
206
347
  bases=bases,
207
- body=cst.IndentedBlock(body=[_get_field(field) for field in definition.fields]),
348
+ body=cst.IndentedBlock(
349
+ body=[
350
+ _get_field(field, is_apollo_federation) for field in definition.fields
351
+ ]
352
+ ),
208
353
  decorators=[decorator],
209
354
  )
210
355
 
@@ -243,6 +388,7 @@ def _get_schema_definition(
243
388
  root_query_name: str | None,
244
389
  root_mutation_name: str | None,
245
390
  root_subscription_name: str | None,
391
+ is_apollo_federation: bool,
246
392
  ) -> cst.SimpleStatementLine | None:
247
393
  if not any([root_query_name, root_mutation_name, root_subscription_name]):
248
394
  return None
@@ -265,17 +411,40 @@ def _get_schema_definition(
265
411
  if root_subscription_name:
266
412
  args.append(_get_arg("subscription", root_subscription_name))
267
413
 
414
+ schema_call = cst.Call(
415
+ func=cst.Attribute(
416
+ value=cst.Name("strawberry"),
417
+ attr=cst.Name("Schema"),
418
+ ),
419
+ args=args,
420
+ )
421
+
422
+ if is_apollo_federation:
423
+ args.append(
424
+ cst.Arg(
425
+ keyword=cst.Name("enable_federation_2"),
426
+ value=cst.Name("True"),
427
+ equal=cst.AssignEqual(
428
+ cst.SimpleWhitespace(""), cst.SimpleWhitespace("")
429
+ ),
430
+ )
431
+ )
432
+ schema_call = cst.Call(
433
+ func=cst.Attribute(
434
+ value=cst.Attribute(
435
+ value=cst.Name(value="strawberry"),
436
+ attr=cst.Name(value="federation"),
437
+ ),
438
+ attr=cst.Name(value="Schema"),
439
+ ),
440
+ args=args,
441
+ )
442
+
268
443
  return cst.SimpleStatementLine(
269
444
  body=[
270
445
  cst.Assign(
271
446
  targets=[cst.AssignTarget(cst.Name("schema"))],
272
- value=cst.Call(
273
- func=cst.Attribute(
274
- value=cst.Name("strawberry"),
275
- attr=cst.Name("Schema"),
276
- ),
277
- args=args,
278
- ),
447
+ value=schema_call,
279
448
  )
280
449
  ]
281
450
  )
@@ -430,6 +599,12 @@ def codegen(schema: str) -> str:
430
599
 
431
600
  object_types: dict[str, cst.ClassDef] = {}
432
601
 
602
+ # when we encounter a extend schema @link ..., we check if is an apollo federation schema
603
+ # and we use this variable to keep track of it, but at the moment the assumption is that
604
+ # the schema extension is always done at the top, this might not be the case all the
605
+ # time
606
+ is_apollo_federation = False
607
+
433
608
  for definition in document.definitions:
434
609
  if isinstance(
435
610
  definition,
@@ -440,7 +615,7 @@ def codegen(schema: str) -> str:
440
615
  ObjectTypeExtensionNode,
441
616
  ),
442
617
  ):
443
- class_definition = _get_class_definition(definition)
618
+ class_definition = _get_class_definition(definition, is_apollo_federation)
444
619
 
445
620
  object_types[definition.name.value] = class_definition
446
621
 
@@ -478,6 +653,11 @@ def codegen(schema: str) -> str:
478
653
  definitions.append(cst.EmptyLine())
479
654
  definitions.append(scalar_definition)
480
655
  definitions.append(cst.EmptyLine())
656
+ elif isinstance(definition, SchemaExtensionNode):
657
+ is_apollo_federation = any(
658
+ _is_federation_link_directive(directive)
659
+ for directive in definition.directives
660
+ )
481
661
  else:
482
662
  raise NotImplementedError(f"Unknown definition {definition}")
483
663
 
@@ -496,6 +676,7 @@ def codegen(schema: str) -> str:
496
676
  root_query_name=root_query_name,
497
677
  root_mutation_name=root_mutation_name,
498
678
  root_subscription_name=root_subscription_name,
679
+ is_apollo_federation=is_apollo_federation,
499
680
  )
500
681
 
501
682
  if schema_definition:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: strawberry-graphql
3
- Version: 0.220.0.dev1709543239
3
+ Version: 0.221.0.dev1710955937
4
4
  Summary: A library for creating GraphQL APIs
5
5
  Home-page: https://strawberry.rocks/
6
6
  License: MIT
@@ -25,7 +25,7 @@ strawberry/channels/handlers/graphql_ws_handler.py,sha256=PHRkwnXt3tY4E0XBVHh4hp
25
25
  strawberry/channels/handlers/http_handler.py,sha256=9pW978XaeF-aFWM9WMaSHCOWmcWoIJCNkW8X3lKJcws,9560
26
26
  strawberry/channels/handlers/ws_handler.py,sha256=sHL44eay4tNoKzkrRn3WewSYH-3ZSJzxJpmBJ-aTkeM,4650
27
27
  strawberry/channels/router.py,sha256=dyOBbSF8nFiygP0zz6MM14mhkvFQAEbbLBXzcpubSHM,1927
28
- strawberry/channels/testing.py,sha256=he9cdsu5KxoPfpR8z2E6kwr2hwUjVhACTrTIoIbsSkQ,5519
28
+ strawberry/channels/testing.py,sha256=IWj1CuIS3vOo2f2fw0W-0GCz-YSs7QSAAscC6suqtiI,5668
29
29
  strawberry/cli/__init__.py,sha256=OkUYNyurO-TyHcD_RsP1bjtNrxlM5gV6Ri6vs4asvvc,374
30
30
  strawberry/cli/app.py,sha256=tTMBV1pdWqMcwjWO2yn-8oLDhMhfJvUzyQtWs75LWJ0,54
31
31
  strawberry/cli/commands/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -186,7 +186,7 @@ strawberry/schema/base.py,sha256=lQBJyzG2ZhKc544oLbXEbpYOPOjaXBop3lxp68h_lfI,297
186
186
  strawberry/schema/compat.py,sha256=n0r3UPUcGemMqK8vklgtCkkuCA1p6tWAYbc6Vl4iNOw,1684
187
187
  strawberry/schema/config.py,sha256=N9KsqkSTnm5snizoBZAVneaWQprhaTd8PFyMuR-jfUs,632
188
188
  strawberry/schema/exceptions.py,sha256=_9Ua-lLRCJfbtd8B6MXILNKGI2SwHkcBNkAHYM7ITp8,534
189
- strawberry/schema/execute.py,sha256=6OE7_5v4G3t_wxp1_mfwu8TTiIkTJNBQaeGCVAljUYw,10982
189
+ strawberry/schema/execute.py,sha256=MAfuE2o89qBmnBpgaTmUA9xJ682KvEQw52EOh4XVE00,10415
190
190
  strawberry/schema/name_converter.py,sha256=UdNyd-QtqF2HsDCQK-nsOcLGxDTj4hJwYFNvMtZnpq4,6533
191
191
  strawberry/schema/schema.py,sha256=HvdOmXK2yHCliD-xyAV01rw3PoGW2f8oEzOYo-cA0AM,13726
192
192
  strawberry/schema/schema_converter.py,sha256=r5T0mwM4qY51yq_0bWDKRqeod9hzJJg3OGzV8siXnRU,34730
@@ -194,7 +194,7 @@ strawberry/schema/types/__init__.py,sha256=oHO3COWhL3L1KLYCJNY1XFf5xt2GGtHiMC-Ua
194
194
  strawberry/schema/types/base_scalars.py,sha256=Z_BmgwLicNexLipGyw6MmZ7OBnkGJU3ySgaY9SwBWrw,1837
195
195
  strawberry/schema/types/concrete_type.py,sha256=HB30G1hMUuuvjAvfSe6ADS35iI_T_wKO-EprVOWTMSs,746
196
196
  strawberry/schema/types/scalar.py,sha256=SVJ8HiKncCvOw2xwABI5xYaHcC7KkGHG-tx2WDtSoCA,2802
197
- strawberry/schema_codegen/__init__.py,sha256=U-ABa02BAfCN6zEGlc2LJEju18ErtGjbf0S8BhJXc6w,15592
197
+ strawberry/schema_codegen/__init__.py,sha256=bPWBUwYPGYRtLwLu1QtDWa0Knt0IgF0FcZnePvDkveI,20837
198
198
  strawberry/schema_directive.py,sha256=GxiOedFB-RJAflpQNUZv00C5Z6gavR-AYdsvoCA_0jc,1963
199
199
  strawberry/starlite/__init__.py,sha256=v209swT8H9MljVL-npvANhEO1zz3__PSfxb_Ix-NoeE,134
200
200
  strawberry/starlite/controller.py,sha256=x6Mm3r36cRfzo6hz9B4AYWbVh2QlYtndYcXFOr_3THM,11860
@@ -241,8 +241,8 @@ strawberry/utils/logging.py,sha256=flS7hV0JiIOEdXcrIjda4WyIWix86cpHHFNJL8gl1y4,7
241
241
  strawberry/utils/operation.py,sha256=Um-tBCPl3_bVFN2Ph7o1mnrxfxBes4HFCj6T0x4kZxE,1135
242
242
  strawberry/utils/str_converters.py,sha256=avIgPVLg98vZH9mA2lhzVdyyjqzLsK2NdBw9mJQ02Xk,813
243
243
  strawberry/utils/typing.py,sha256=Qxz1LwyVsNGV7LQW1dFsaUbsswj5LHBOdKLMom5eyEA,13491
244
- strawberry_graphql-0.220.0.dev1709543239.dist-info/LICENSE,sha256=m-XnIVUKqlG_AWnfi9NReh9JfKhYOB-gJfKE45WM1W8,1072
245
- strawberry_graphql-0.220.0.dev1709543239.dist-info/METADATA,sha256=jkAZjN8RDuaeEREHhENLaroXtadRJqyEHGmu1PGHaHI,7754
246
- strawberry_graphql-0.220.0.dev1709543239.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
247
- strawberry_graphql-0.220.0.dev1709543239.dist-info/entry_points.txt,sha256=Nk7-aT3_uEwCgyqtHESV9H6Mc31cK-VAvhnQNTzTb4k,49
248
- strawberry_graphql-0.220.0.dev1709543239.dist-info/RECORD,,
244
+ strawberry_graphql-0.221.0.dev1710955937.dist-info/LICENSE,sha256=m-XnIVUKqlG_AWnfi9NReh9JfKhYOB-gJfKE45WM1W8,1072
245
+ strawberry_graphql-0.221.0.dev1710955937.dist-info/METADATA,sha256=l7sfrGtQwOGjYGDh5RVMKFg7bL4dvPYKnTzz6xo-eFU,7754
246
+ strawberry_graphql-0.221.0.dev1710955937.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
247
+ strawberry_graphql-0.221.0.dev1710955937.dist-info/entry_points.txt,sha256=Nk7-aT3_uEwCgyqtHESV9H6Mc31cK-VAvhnQNTzTb4k,49
248
+ strawberry_graphql-0.221.0.dev1710955937.dist-info/RECORD,,