strawberry-graphql 0.275.7__py3-none-any.whl → 0.284.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of strawberry-graphql might be problematic. Click here for more details.

Files changed (161)
  1. strawberry/__init__.py +2 -0
  2. strawberry/aiohttp/test/client.py +8 -15
  3. strawberry/aiohttp/views.py +15 -64
  4. strawberry/annotation.py +70 -25
  5. strawberry/asgi/__init__.py +22 -56
  6. strawberry/asgi/test/client.py +6 -6
  7. strawberry/chalice/views.py +13 -79
  8. strawberry/channels/handlers/base.py +7 -8
  9. strawberry/channels/handlers/http_handler.py +50 -32
  10. strawberry/channels/handlers/ws_handler.py +12 -14
  11. strawberry/channels/router.py +3 -4
  12. strawberry/channels/testing.py +7 -9
  13. strawberry/cli/__init__.py +7 -6
  14. strawberry/cli/commands/codegen.py +7 -7
  15. strawberry/cli/commands/dev.py +72 -0
  16. strawberry/cli/commands/schema_codegen.py +1 -2
  17. strawberry/cli/commands/server.py +3 -44
  18. strawberry/cli/commands/upgrade/__init__.py +3 -3
  19. strawberry/cli/commands/upgrade/_run_codemod.py +2 -2
  20. strawberry/cli/constants.py +1 -2
  21. strawberry/cli/{debug_server.py → dev_server.py} +3 -7
  22. strawberry/codegen/plugins/print_operation.py +2 -2
  23. strawberry/codegen/plugins/python.py +2 -2
  24. strawberry/codegen/query_codegen.py +20 -30
  25. strawberry/codegen/types.py +32 -32
  26. strawberry/codemods/__init__.py +9 -0
  27. strawberry/codemods/annotated_unions.py +2 -2
  28. strawberry/codemods/maybe_optional.py +118 -0
  29. strawberry/dataloader.py +28 -24
  30. strawberry/directive.py +6 -7
  31. strawberry/django/test/client.py +3 -3
  32. strawberry/django/views.py +21 -91
  33. strawberry/exceptions/__init__.py +4 -4
  34. strawberry/exceptions/conflicting_arguments.py +2 -2
  35. strawberry/exceptions/duplicated_type_name.py +4 -4
  36. strawberry/exceptions/exception.py +3 -3
  37. strawberry/exceptions/handler.py +8 -7
  38. strawberry/exceptions/invalid_argument_type.py +2 -2
  39. strawberry/exceptions/invalid_superclass_interface.py +2 -2
  40. strawberry/exceptions/invalid_union_type.py +4 -4
  41. strawberry/exceptions/missing_arguments_annotations.py +2 -2
  42. strawberry/exceptions/missing_dependencies.py +2 -4
  43. strawberry/exceptions/missing_field_annotation.py +2 -2
  44. strawberry/exceptions/missing_return_annotation.py +2 -2
  45. strawberry/exceptions/object_is_not_a_class.py +2 -2
  46. strawberry/exceptions/object_is_not_an_enum.py +2 -2
  47. strawberry/exceptions/permission_fail_silently_requires_optional.py +2 -2
  48. strawberry/exceptions/private_strawberry_field.py +2 -2
  49. strawberry/exceptions/scalar_already_registered.py +2 -2
  50. strawberry/exceptions/syntax.py +3 -3
  51. strawberry/exceptions/unresolved_field_type.py +2 -2
  52. strawberry/exceptions/utils/source_finder.py +25 -25
  53. strawberry/experimental/pydantic/_compat.py +8 -7
  54. strawberry/experimental/pydantic/conversion.py +2 -2
  55. strawberry/experimental/pydantic/conversion_types.py +2 -2
  56. strawberry/experimental/pydantic/error_type.py +10 -12
  57. strawberry/experimental/pydantic/fields.py +9 -15
  58. strawberry/experimental/pydantic/object_type.py +17 -25
  59. strawberry/experimental/pydantic/utils.py +1 -2
  60. strawberry/ext/mypy_plugin.py +12 -14
  61. strawberry/extensions/base_extension.py +2 -1
  62. strawberry/extensions/context.py +13 -18
  63. strawberry/extensions/directives.py +9 -3
  64. strawberry/extensions/field_extension.py +4 -4
  65. strawberry/extensions/mask_errors.py +24 -13
  66. strawberry/extensions/max_aliases.py +1 -3
  67. strawberry/extensions/parser_cache.py +1 -2
  68. strawberry/extensions/query_depth_limiter.py +18 -14
  69. strawberry/extensions/runner.py +2 -2
  70. strawberry/extensions/tracing/apollo.py +3 -3
  71. strawberry/extensions/tracing/datadog.py +3 -3
  72. strawberry/extensions/tracing/opentelemetry.py +6 -8
  73. strawberry/extensions/tracing/utils.py +3 -1
  74. strawberry/extensions/utils.py +2 -2
  75. strawberry/extensions/validation_cache.py +2 -3
  76. strawberry/fastapi/context.py +6 -6
  77. strawberry/fastapi/router.py +43 -42
  78. strawberry/federation/argument.py +4 -5
  79. strawberry/federation/enum.py +18 -21
  80. strawberry/federation/field.py +94 -97
  81. strawberry/federation/object_type.py +56 -58
  82. strawberry/federation/scalar.py +27 -35
  83. strawberry/federation/schema.py +15 -16
  84. strawberry/federation/schema_directive.py +7 -6
  85. strawberry/federation/schema_directives.py +11 -11
  86. strawberry/federation/union.py +4 -4
  87. strawberry/flask/views.py +16 -85
  88. strawberry/http/__init__.py +30 -20
  89. strawberry/http/async_base_view.py +208 -89
  90. strawberry/http/base.py +28 -11
  91. strawberry/http/exceptions.py +5 -7
  92. strawberry/http/ides.py +2 -3
  93. strawberry/http/sync_base_view.py +115 -69
  94. strawberry/http/types.py +3 -3
  95. strawberry/litestar/controller.py +43 -77
  96. strawberry/permission.py +4 -6
  97. strawberry/printer/ast_from_value.py +3 -5
  98. strawberry/printer/printer.py +18 -15
  99. strawberry/quart/views.py +16 -48
  100. strawberry/relay/exceptions.py +4 -4
  101. strawberry/relay/fields.py +33 -32
  102. strawberry/relay/types.py +32 -35
  103. strawberry/relay/utils.py +11 -23
  104. strawberry/resolvers.py +2 -1
  105. strawberry/sanic/context.py +1 -0
  106. strawberry/sanic/utils.py +3 -3
  107. strawberry/sanic/views.py +15 -54
  108. strawberry/scalars.py +2 -2
  109. strawberry/schema/_graphql_core.py +55 -0
  110. strawberry/schema/base.py +32 -33
  111. strawberry/schema/compat.py +9 -9
  112. strawberry/schema/config.py +10 -1
  113. strawberry/schema/exceptions.py +1 -3
  114. strawberry/schema/name_converter.py +9 -8
  115. strawberry/schema/schema.py +133 -100
  116. strawberry/schema/schema_converter.py +96 -58
  117. strawberry/schema/types/base_scalars.py +1 -1
  118. strawberry/schema/types/concrete_type.py +5 -5
  119. strawberry/schema/validation_rules/maybe_null.py +136 -0
  120. strawberry/schema_codegen/__init__.py +3 -3
  121. strawberry/schema_directive.py +7 -6
  122. strawberry/static/graphiql.html +5 -5
  123. strawberry/streamable.py +35 -0
  124. strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +5 -16
  125. strawberry/subscriptions/protocols/graphql_transport_ws/types.py +20 -20
  126. strawberry/subscriptions/protocols/graphql_ws/handlers.py +5 -12
  127. strawberry/subscriptions/protocols/graphql_ws/types.py +14 -14
  128. strawberry/test/client.py +18 -18
  129. strawberry/tools/create_type.py +2 -3
  130. strawberry/types/arguments.py +41 -28
  131. strawberry/types/auto.py +3 -4
  132. strawberry/types/base.py +25 -27
  133. strawberry/types/enum.py +22 -25
  134. strawberry/types/execution.py +21 -16
  135. strawberry/types/field.py +109 -130
  136. strawberry/types/fields/resolver.py +19 -21
  137. strawberry/types/info.py +5 -11
  138. strawberry/types/lazy_type.py +2 -3
  139. strawberry/types/maybe.py +12 -3
  140. strawberry/types/mutation.py +115 -118
  141. strawberry/types/nodes.py +2 -2
  142. strawberry/types/object_type.py +43 -63
  143. strawberry/types/scalar.py +37 -43
  144. strawberry/types/union.py +12 -14
  145. strawberry/utils/aio.py +12 -9
  146. strawberry/utils/await_maybe.py +3 -3
  147. strawberry/utils/deprecations.py +2 -2
  148. strawberry/utils/importer.py +1 -2
  149. strawberry/utils/inspect.py +4 -6
  150. strawberry/utils/logging.py +2 -2
  151. strawberry/utils/operation.py +4 -4
  152. strawberry/utils/typing.py +18 -83
  153. {strawberry_graphql-0.275.7.dist-info → strawberry_graphql-0.284.3.dist-info}/METADATA +14 -8
  154. strawberry_graphql-0.284.3.dist-info/RECORD +243 -0
  155. {strawberry_graphql-0.275.7.dist-info → strawberry_graphql-0.284.3.dist-info}/WHEEL +1 -1
  156. strawberry/utils/dataclasses.py +0 -37
  157. strawberry/utils/debug.py +0 -46
  158. strawberry/utils/graphql_lexer.py +0 -35
  159. strawberry_graphql-0.275.7.dist-info/RECORD +0 -241
  160. {strawberry_graphql-0.275.7.dist-info → strawberry_graphql-0.284.3.dist-info}/entry_points.txt +0 -0
  161. {strawberry_graphql-0.275.7.dist-info → strawberry_graphql-0.284.3.dist-info/licenses}/LICENSE +0 -0
strawberry/http/async_base_view.py CHANGED
@@ -2,20 +2,19 @@ import abc
2
2
  import asyncio
3
3
  import contextlib
4
4
  import json
5
- from collections.abc import AsyncGenerator, Mapping
5
+ from collections.abc import AsyncGenerator, Callable, Mapping
6
6
  from datetime import timedelta
7
7
  from typing import (
8
8
  Any,
9
- Callable,
10
9
  Generic,
11
- Optional,
12
- Union,
10
+ Literal,
11
+ TypeGuard,
13
12
  cast,
14
13
  overload,
15
14
  )
16
- from typing_extensions import Literal, TypeGuard
17
15
 
18
16
  from graphql import GraphQLError
17
+ from lia import AsyncHTTPRequestAdapter, HTTPException
19
18
 
20
19
  from strawberry.exceptions import MissingQueryError
21
20
  from strawberry.file_uploads.utils import replace_placeholders_with_files
@@ -25,6 +24,9 @@ from strawberry.http import (
25
24
  process_result,
26
25
  )
27
26
  from strawberry.http.ides import GraphQL_IDE
27
+ from strawberry.schema._graphql_core import (
28
+ GraphQLIncrementalExecutionResults,
29
+ )
28
30
  from strawberry.schema.base import BaseSchema
29
31
  from strawberry.schema.exceptions import (
30
32
  CannotGetOperationTypeError,
@@ -40,9 +42,7 @@ from strawberry.types.graphql import OperationType
40
42
  from strawberry.types.unset import UNSET, UnsetType
41
43
 
42
44
  from .base import BaseView
43
- from .exceptions import HTTPException
44
45
  from .parse_content_type import parse_content_type
45
- from .types import FormData, HTTPMethod, QueryParams
46
46
  from .typevars import (
47
47
  Context,
48
48
  Request,
@@ -54,30 +54,6 @@ from .typevars import (
54
54
  )
55
55
 
56
56
 
57
- class AsyncHTTPRequestAdapter(abc.ABC):
58
- @property
59
- @abc.abstractmethod
60
- def query_params(self) -> QueryParams: ...
61
-
62
- @property
63
- @abc.abstractmethod
64
- def method(self) -> HTTPMethod: ...
65
-
66
- @property
67
- @abc.abstractmethod
68
- def headers(self) -> Mapping[str, str]: ...
69
-
70
- @property
71
- @abc.abstractmethod
72
- def content_type(self) -> Optional[str]: ...
73
-
74
- @abc.abstractmethod
75
- async def get_body(self) -> Union[str, bytes]: ...
76
-
77
- @abc.abstractmethod
78
- async def get_form_data(self) -> FormData: ...
79
-
80
-
81
57
  class AsyncWebSocketAdapter(abc.ABC):
82
58
  def __init__(self, view: "AsyncBaseHTTPView") -> None:
83
59
  self.view = view
@@ -108,10 +84,9 @@ class AsyncBaseHTTPView(
108
84
  ],
109
85
  ):
110
86
  schema: BaseSchema
111
- graphql_ide: Optional[GraphQL_IDE]
112
- debug: bool
87
+ graphql_ide: GraphQL_IDE | None
113
88
  keep_alive = False
114
- keep_alive_interval: Optional[float] = None
89
+ keep_alive_interval: float | None = None
115
90
  connection_init_wait_timeout: timedelta = timedelta(minutes=1)
116
91
  request_adapter_class: Callable[[Request], AsyncHTTPRequestAdapter]
117
92
  websocket_adapter_class: Callable[
@@ -139,18 +114,20 @@ class AsyncBaseHTTPView(
139
114
  @abc.abstractmethod
140
115
  async def get_context(
141
116
  self,
142
- request: Union[Request, WebSocketRequest],
143
- response: Union[SubResponse, WebSocketResponse],
117
+ request: Request | WebSocketRequest,
118
+ response: SubResponse | WebSocketResponse,
144
119
  ) -> Context: ...
145
120
 
146
121
  @abc.abstractmethod
147
122
  async def get_root_value(
148
- self, request: Union[Request, WebSocketRequest]
149
- ) -> Optional[RootValue]: ...
123
+ self, request: Request | WebSocketRequest
124
+ ) -> RootValue | None: ...
150
125
 
151
126
  @abc.abstractmethod
152
127
  def create_response(
153
- self, response_data: GraphQLHTTPResponse, sub_response: SubResponse
128
+ self,
129
+ response_data: GraphQLHTTPResponse | list[GraphQLHTTPResponse],
130
+ sub_response: SubResponse,
154
131
  ) -> Response: ...
155
132
 
156
133
  @abc.abstractmethod
@@ -167,22 +144,26 @@ class AsyncBaseHTTPView(
167
144
 
168
145
  @abc.abstractmethod
169
146
  def is_websocket_request(
170
- self, request: Union[Request, WebSocketRequest]
147
+ self, request: Request | WebSocketRequest
171
148
  ) -> TypeGuard[WebSocketRequest]: ...
172
149
 
173
150
  @abc.abstractmethod
174
151
  async def pick_websocket_subprotocol(
175
152
  self, request: WebSocketRequest
176
- ) -> Optional[str]: ...
153
+ ) -> str | None: ...
177
154
 
178
155
  @abc.abstractmethod
179
156
  async def create_websocket_response(
180
- self, request: WebSocketRequest, subprotocol: Optional[str]
157
+ self, request: WebSocketRequest, subprotocol: str | None
181
158
  ) -> WebSocketResponse: ...
182
159
 
183
160
  async def execute_operation(
184
- self, request: Request, context: Context, root_value: Optional[RootValue]
185
- ) -> Union[ExecutionResult, SubscriptionExecutionResult]:
161
+ self,
162
+ request: Request,
163
+ context: Context,
164
+ root_value: RootValue | None,
165
+ sub_response: SubResponse,
166
+ ) -> ExecutionResult | list[ExecutionResult] | SubscriptionExecutionResult:
186
167
  request_adapter = self.request_adapter_class(request)
187
168
 
188
169
  try:
@@ -198,6 +179,22 @@ class AsyncBaseHTTPView(
198
179
  if not self.allow_queries_via_get and request_adapter.method == "GET":
199
180
  allowed_operation_types = allowed_operation_types - {OperationType.QUERY}
200
181
 
182
+ if isinstance(request_data, list):
183
+ # batch GraphQL requests
184
+ return await asyncio.gather(
185
+ *[
186
+ self.execute_single(
187
+ request=request,
188
+ request_adapter=request_adapter,
189
+ sub_response=sub_response,
190
+ context=context,
191
+ root_value=root_value,
192
+ request_data=data,
193
+ )
194
+ for data in request_data
195
+ ]
196
+ )
197
+
201
198
  if request_data.protocol == "multipart-subscription":
202
199
  return await self.schema.subscribe(
203
200
  request_data.query, # type: ignore
@@ -208,24 +205,59 @@ class AsyncBaseHTTPView(
208
205
  operation_extensions=request_data.extensions,
209
206
  )
210
207
 
211
- return await self.schema.execute(
212
- request_data.query,
208
+ return await self.execute_single(
209
+ request=request,
210
+ request_adapter=request_adapter,
211
+ sub_response=sub_response,
212
+ context=context,
213
213
  root_value=root_value,
214
- variable_values=request_data.variables,
215
- context_value=context,
216
- operation_name=request_data.operation_name,
217
- allowed_operation_types=allowed_operation_types,
218
- operation_extensions=request_data.extensions,
214
+ request_data=request_data,
219
215
  )
220
216
 
217
+ async def execute_single(
218
+ self,
219
+ request: Request,
220
+ request_adapter: AsyncHTTPRequestAdapter,
221
+ sub_response: SubResponse,
222
+ context: Context,
223
+ root_value: RootValue | None,
224
+ request_data: GraphQLRequestData,
225
+ ) -> ExecutionResult:
226
+ allowed_operation_types = OperationType.from_http(request_adapter.method)
227
+
228
+ if not self.allow_queries_via_get and request_adapter.method == "GET":
229
+ allowed_operation_types = allowed_operation_types - {OperationType.QUERY}
230
+
231
+ try:
232
+ result = await self.schema.execute(
233
+ request_data.query,
234
+ root_value=root_value,
235
+ variable_values=request_data.variables,
236
+ context_value=context,
237
+ operation_name=request_data.operation_name,
238
+ allowed_operation_types=allowed_operation_types,
239
+ operation_extensions=request_data.extensions,
240
+ )
241
+ except CannotGetOperationTypeError as e:
242
+ raise HTTPException(400, e.as_http_error_reason()) from e
243
+ except InvalidOperationTypeError as e:
244
+ raise HTTPException(
245
+ 400, e.as_http_error_reason(request_adapter.method)
246
+ ) from e
247
+ except MissingQueryError as e:
248
+ raise HTTPException(400, "No GraphQL query found in the request") from e
249
+
250
+ return result
251
+
221
252
  async def parse_multipart(self, request: AsyncHTTPRequestAdapter) -> dict[str, str]:
222
253
  try:
223
254
  form_data = await request.get_form_data()
224
255
  except ValueError as e:
225
256
  raise HTTPException(400, "Unable to parse the multipart body") from e
226
257
 
227
- operations = form_data["form"].get("operations", "{}")
228
- files_map = form_data["form"].get("map", "{}")
258
+ operations = form_data.form.get("operations", "{}")
259
+ files_map = form_data.form.get("map", "{}")
260
+ files = form_data.files
229
261
 
230
262
  if isinstance(operations, (bytes, str)):
231
263
  operations = self.parse_json(operations)
@@ -234,9 +266,7 @@ class AsyncBaseHTTPView(
234
266
  files_map = self.parse_json(files_map)
235
267
 
236
268
  try:
237
- return replace_placeholders_with_files(
238
- operations, files_map, form_data["files"]
239
- )
269
+ return replace_placeholders_with_files(operations, files_map, files)
240
270
  except KeyError as e:
241
271
  raise HTTPException(400, "File(s) missing in form data") from e
242
272
 
@@ -250,7 +280,7 @@ class AsyncBaseHTTPView(
250
280
  self,
251
281
  request: Request,
252
282
  context: Context = UNSET,
253
- root_value: Optional[RootValue] = UNSET,
283
+ root_value: RootValue | None = UNSET,
254
284
  ) -> Response: ...
255
285
 
256
286
  @overload
@@ -258,15 +288,15 @@ class AsyncBaseHTTPView(
258
288
  self,
259
289
  request: WebSocketRequest,
260
290
  context: Context = UNSET,
261
- root_value: Optional[RootValue] = UNSET,
291
+ root_value: RootValue | None = UNSET,
262
292
  ) -> WebSocketResponse: ...
263
293
 
264
294
  async def run(
265
295
  self,
266
- request: Union[Request, WebSocketRequest],
296
+ request: Request | WebSocketRequest,
267
297
  context: Context = UNSET,
268
- root_value: Optional[RootValue] = UNSET,
269
- ) -> Union[Response, WebSocketResponse]:
298
+ root_value: RootValue | None = UNSET,
299
+ ) -> Response | WebSocketResponse:
270
300
  root_value = (
271
301
  await self.get_root_value(request) if root_value is UNSET else root_value
272
302
  )
@@ -291,7 +321,6 @@ class AsyncBaseHTTPView(
291
321
  context=context,
292
322
  root_value=root_value,
293
323
  schema=self.schema,
294
- debug=self.debug,
295
324
  connection_init_wait_timeout=self.connection_init_wait_timeout,
296
325
  ).handle()
297
326
  elif websocket_subprotocol == GRAPHQL_WS_PROTOCOL:
@@ -301,7 +330,6 @@ class AsyncBaseHTTPView(
301
330
  context=context,
302
331
  root_value=root_value,
303
332
  schema=self.schema,
304
- debug=self.debug,
305
333
  keep_alive=self.keep_alive,
306
334
  keep_alive_interval=self.keep_alive_interval,
307
335
  ).handle()
@@ -327,18 +355,12 @@ class AsyncBaseHTTPView(
327
355
  return await self.render_graphql_ide(request)
328
356
  raise HTTPException(404, "Not Found")
329
357
 
330
- try:
331
- result = await self.execute_operation(
332
- request=request, context=context, root_value=root_value
333
- )
334
- except CannotGetOperationTypeError as e:
335
- raise HTTPException(400, e.as_http_error_reason()) from e
336
- except InvalidOperationTypeError as e:
337
- raise HTTPException(
338
- 400, e.as_http_error_reason(request_adapter.method)
339
- ) from e
340
- except MissingQueryError as e:
341
- raise HTTPException(400, "No GraphQL query found in the request") from e
358
+ result = await self.execute_operation(
359
+ request=request,
360
+ context=context,
361
+ root_value=root_value,
362
+ sub_response=sub_response,
363
+ )
342
364
 
343
365
  if isinstance(result, SubscriptionExecutionResult):
344
366
  stream = self._get_stream(request, result)
@@ -348,27 +370,110 @@ class AsyncBaseHTTPView(
348
370
  stream,
349
371
  sub_response,
350
372
  headers={
351
- "Transfer-Encoding": "chunked",
352
373
  "Content-Type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json",
353
374
  },
354
375
  )
376
+ if isinstance(result, GraphQLIncrementalExecutionResults):
377
+
378
+ async def stream() -> AsyncGenerator[str, None]:
379
+ yield "---"
380
+
381
+ response = await self.process_result(request, result.initial_result)
382
+
383
+ response["hasNext"] = result.initial_result.has_next
384
+ response["pending"] = [
385
+ p.formatted for p in result.initial_result.pending
386
+ ]
387
+ response["extensions"] = result.initial_result.extensions
388
+
389
+ yield self.encode_multipart_data(response, "-")
390
+
391
+ all_pending = result.initial_result.pending
392
+
393
+ async for value in result.subsequent_results:
394
+ response = {
395
+ "hasNext": value.has_next,
396
+ "extensions": value.extensions,
397
+ }
398
+
399
+ if value.pending:
400
+ response["pending"] = [p.formatted for p in value.pending]
401
+
402
+ if value.completed:
403
+ response["completed"] = [p.formatted for p in value.completed]
404
+
405
+ if value.incremental:
406
+ incremental = []
407
+
408
+ all_pending.extend(value.pending)
409
+
410
+ for incremental_value in value.incremental:
411
+ pending_value = next(
412
+ (
413
+ v
414
+ for v in all_pending
415
+ if v.id == incremental_value.id
416
+ ),
417
+ None,
418
+ )
419
+
420
+ assert pending_value
421
+
422
+ incremental.append(
423
+ {
424
+ **incremental_value.formatted,
425
+ "path": pending_value.path,
426
+ "label": pending_value.label,
427
+ }
428
+ )
429
+
430
+ response["incremental"] = incremental
431
+
432
+ yield self.encode_multipart_data(response, "-")
433
+
434
+ yield "--\r\n"
435
+
436
+ return await self.create_streaming_response(
437
+ request,
438
+ stream,
439
+ sub_response,
440
+ headers={
441
+ "Content-Type": 'multipart/mixed; boundary="-"',
442
+ },
443
+ )
355
444
 
356
- response_data = await self.process_result(request=request, result=result)
445
+ response_data: GraphQLHTTPResponse | list[GraphQLHTTPResponse]
446
+
447
+ if isinstance(result, list):
448
+ response_data = []
449
+ for execution_result in result:
450
+ processed_result = await self.process_result(
451
+ request=request, result=execution_result
452
+ )
453
+ if execution_result.errors:
454
+ self._handle_errors(execution_result.errors, processed_result)
455
+ response_data.append(processed_result)
456
+ else:
457
+ response_data = await self.process_result(request=request, result=result)
357
458
 
358
- if result.errors:
359
- self._handle_errors(result.errors, response_data)
459
+ if result.errors:
460
+ self._handle_errors(result.errors, response_data)
360
461
 
361
462
  return self.create_response(
362
463
  response_data=response_data, sub_response=sub_response
363
464
  )
364
465
 
365
466
  def encode_multipart_data(self, data: Any, separator: str) -> str:
467
+ encoded_data = self.encode_json(data)
468
+
366
469
  return "".join(
367
470
  [
368
- f"\r\n--{separator}\r\n",
369
- "Content-Type: application/json\r\n\r\n",
370
- self.encode_json(data),
371
- "\n",
471
+ "\r\n",
472
+ "Content-Type: application/json; charset=utf-8\r\n",
473
+ "Content-Length: " + str(len(encoded_data)) + "\r\n",
474
+ "\r\n",
475
+ encoded_data,
476
+ f"\r\n--{separator}",
372
477
  ]
373
478
  )
374
479
 
@@ -508,15 +613,16 @@ class AsyncBaseHTTPView(
508
613
 
509
614
  async def parse_http_body(
510
615
  self, request: AsyncHTTPRequestAdapter
511
- ) -> GraphQLRequestData:
616
+ ) -> GraphQLRequestData | list[GraphQLRequestData]:
512
617
  headers = {key.lower(): value for key, value in request.headers.items()}
513
618
  content_type, _ = parse_content_type(request.content_type or "")
514
619
  accept = headers.get("accept", "")
515
620
 
516
- protocol: Literal["http", "multipart-subscription"] = "http"
517
-
518
- if self._is_multipart_subscriptions(*parse_content_type(accept)):
519
- protocol = "multipart-subscription"
621
+ protocol: Literal["http", "multipart-subscription"] = (
622
+ "multipart-subscription"
623
+ if self._is_multipart_subscriptions(*parse_content_type(accept))
624
+ else "http"
625
+ )
520
626
 
521
627
  if request.method == "GET":
522
628
  data = self.parse_query_params(request.query_params)
@@ -527,6 +633,19 @@ class AsyncBaseHTTPView(
527
633
  else:
528
634
  raise HTTPException(400, "Unsupported content type")
529
635
 
636
+ if isinstance(data, list):
637
+ self._validate_batch_request(data, protocol=protocol)
638
+ return [
639
+ GraphQLRequestData(
640
+ query=item.get("query"),
641
+ variables=item.get("variables"),
642
+ operation_name=item.get("operationName"),
643
+ extensions=item.get("extensions"),
644
+ protocol=protocol,
645
+ )
646
+ for item in data
647
+ ]
648
+
530
649
  query = data.get("query")
531
650
  if not isinstance(query, (str, type(None))):
532
651
  raise HTTPException(
@@ -563,7 +682,7 @@ class AsyncBaseHTTPView(
563
682
 
564
683
  async def on_ws_connect(
565
684
  self, context: Context
566
- ) -> Union[UnsetType, None, dict[str, object]]:
685
+ ) -> UnsetType | None | dict[str, object]:
567
686
  return UNSET
568
687
 
569
688
 
strawberry/http/base.py CHANGED
@@ -1,18 +1,21 @@
1
1
  import json
2
2
  from collections.abc import Mapping
3
- from typing import Any, Generic, Optional, Union
3
+ from typing import Any, Generic
4
4
  from typing_extensions import Protocol
5
5
 
6
+ from lia import HTTPException
7
+
8
+ from strawberry.http import GraphQLRequestData
6
9
  from strawberry.http.ides import GraphQL_IDE, get_graphql_ide_html
7
10
  from strawberry.http.types import HTTPMethod, QueryParams
11
+ from strawberry.schema.base import BaseSchema
8
12
 
9
- from .exceptions import HTTPException
10
13
  from .typevars import Request
11
14
 
12
15
 
13
16
  class BaseRequestProtocol(Protocol):
14
17
  @property
15
- def query_params(self) -> Mapping[str, Optional[Union[str, list[str]]]]: ...
18
+ def query_params(self) -> Mapping[str, str | list[str] | None]: ...
16
19
 
17
20
  @property
18
21
  def method(self) -> HTTPMethod: ...
@@ -22,8 +25,9 @@ class BaseRequestProtocol(Protocol):
22
25
 
23
26
 
24
27
  class BaseView(Generic[Request]):
25
- graphql_ide: Optional[GraphQL_IDE]
28
+ graphql_ide: GraphQL_IDE | None
26
29
  multipart_uploads_enabled: bool = False
30
+ schema: BaseSchema
27
31
 
28
32
  def should_render_graphql_ide(self, request: BaseRequestProtocol) -> bool:
29
33
  return (
@@ -38,13 +42,13 @@ class BaseView(Generic[Request]):
38
42
  def is_request_allowed(self, request: BaseRequestProtocol) -> bool:
39
43
  return request.method in ("GET", "POST")
40
44
 
41
- def parse_json(self, data: Union[str, bytes]) -> Any:
45
+ def parse_json(self, data: str | bytes) -> Any:
42
46
  try:
43
47
  return self.decode_json(data)
44
48
  except json.JSONDecodeError as e:
45
49
  raise HTTPException(400, "Unable to parse request body as JSON") from e
46
50
 
47
- def decode_json(self, data: Union[str, bytes]) -> object:
51
+ def decode_json(self, data: str | bytes) -> object:
48
52
  return json.loads(data)
49
53
 
50
54
  def encode_json(self, data: object) -> str:
@@ -74,13 +78,26 @@ class BaseView(Generic[Request]):
74
78
  def _is_multipart_subscriptions(
75
79
  self, content_type: str, params: dict[str, str]
76
80
  ) -> bool:
77
- if content_type != "multipart/mixed":
78
- return False
81
+ subscription_spec = params.get("subscriptionspec", "").strip("'\"")
82
+ return (
83
+ content_type == "multipart/mixed"
84
+ and ("boundary" not in params or params["boundary"] == "graphql")
85
+ and subscription_spec.startswith("1.0")
86
+ )
79
87
 
80
- if params.get("boundary") != "graphql":
81
- return False
88
+ def _validate_batch_request(
89
+ self, request_data: list[GraphQLRequestData], protocol: str
90
+ ) -> None:
91
+ if self.schema.config.batching_config is None:
92
+ raise HTTPException(400, "Batching is not enabled")
93
+
94
+ if protocol == "multipart-subscription":
95
+ raise HTTPException(
96
+ 400, "Batching is not supported for multipart subscriptions"
97
+ )
82
98
 
83
- return params.get("subscriptionspec", "").startswith("1.0")
99
+ if len(request_data) > self.schema.config.batching_config["max_operations"]:
100
+ raise HTTPException(400, "Too many operations")
84
101
 
85
102
 
86
103
  __all__ = ["BaseView"]
strawberry/http/exceptions.py CHANGED
@@ -1,9 +1,3 @@
1
- class HTTPException(Exception):
2
- def __init__(self, status_code: int, reason: str) -> None:
3
- self.status_code = status_code
4
- self.reason = reason
5
-
6
-
7
1
  class NonTextMessageReceived(Exception):
8
2
  pass
9
3
 
@@ -16,4 +10,8 @@ class WebSocketDisconnected(Exception):
16
10
  pass
17
11
 
18
12
 
19
- __all__ = ["HTTPException"]
13
+ __all__ = [
14
+ "NonJsonMessageReceived",
15
+ "NonTextMessageReceived",
16
+ "WebSocketDisconnected",
17
+ ]
strawberry/http/ides.py CHANGED
@@ -1,12 +1,11 @@
1
1
  import pathlib
2
- from typing import Optional
3
- from typing_extensions import Literal
2
+ from typing import Literal
4
3
 
5
4
  GraphQL_IDE = Literal["graphiql", "apollo-sandbox", "pathfinder"]
6
5
 
7
6
 
8
7
  def get_graphql_ide_html(
9
- graphql_ide: Optional[GraphQL_IDE] = "graphiql",
8
+ graphql_ide: GraphQL_IDE | None = "graphiql",
10
9
  ) -> str:
11
10
  here = pathlib.Path(__file__).parents[1]
12
11