strawberry-graphql 0.243.1__py3-none-any.whl → 0.244.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. strawberry/aiohttp/views.py +58 -44
  2. strawberry/asgi/__init__.py +62 -61
  3. strawberry/channels/__init__.py +1 -6
  4. strawberry/channels/handlers/base.py +2 -2
  5. strawberry/channels/handlers/http_handler.py +18 -1
  6. strawberry/channels/handlers/ws_handler.py +113 -58
  7. strawberry/codegen/query_codegen.py +8 -6
  8. strawberry/django/views.py +19 -1
  9. strawberry/fastapi/router.py +27 -45
  10. strawberry/flask/views.py +15 -1
  11. strawberry/http/async_base_view.py +136 -16
  12. strawberry/http/exceptions.py +4 -0
  13. strawberry/http/typevars.py +11 -1
  14. strawberry/litestar/controller.py +77 -86
  15. strawberry/quart/views.py +15 -1
  16. strawberry/sanic/views.py +21 -1
  17. strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +41 -40
  18. strawberry/subscriptions/protocols/graphql_ws/handlers.py +46 -45
  19. {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.1.dist-info}/METADATA +1 -1
  20. {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.1.dist-info}/RECORD +23 -37
  21. strawberry/aiohttp/handlers/__init__.py +0 -6
  22. strawberry/aiohttp/handlers/graphql_transport_ws_handler.py +0 -62
  23. strawberry/aiohttp/handlers/graphql_ws_handler.py +0 -69
  24. strawberry/asgi/handlers/__init__.py +0 -6
  25. strawberry/asgi/handlers/graphql_transport_ws_handler.py +0 -66
  26. strawberry/asgi/handlers/graphql_ws_handler.py +0 -71
  27. strawberry/channels/handlers/graphql_transport_ws_handler.py +0 -62
  28. strawberry/channels/handlers/graphql_ws_handler.py +0 -72
  29. strawberry/fastapi/handlers/__init__.py +0 -6
  30. strawberry/fastapi/handlers/graphql_transport_ws_handler.py +0 -20
  31. strawberry/fastapi/handlers/graphql_ws_handler.py +0 -18
  32. strawberry/litestar/handlers/__init__.py +0 -0
  33. strawberry/litestar/handlers/graphql_transport_ws_handler.py +0 -60
  34. strawberry/litestar/handlers/graphql_ws_handler.py +0 -66
  35. {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.1.dist-info}/LICENSE +0 -0
  36. {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.1.dist-info}/WHEEL +0 -0
  37. {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.1.dist-info}/entry_points.txt +0 -0
strawberry/litestar/controller.py CHANGED
@@ -7,19 +7,19 @@ from datetime import timedelta
7
7
  from typing import (
8
8
  TYPE_CHECKING,
9
9
  Any,
10
+ AsyncGenerator,
10
11
  AsyncIterator,
11
12
  Callable,
12
13
  Dict,
13
14
  FrozenSet,
14
- List,
15
15
  Optional,
16
- Set,
17
16
  Tuple,
18
17
  Type,
19
18
  TypedDict,
20
19
  Union,
21
20
  cast,
22
21
  )
22
+ from typing_extensions import TypeGuard
23
23
 
24
24
  from msgspec import Struct
25
25
 
@@ -35,23 +35,24 @@ from litestar import (
35
35
  )
36
36
  from litestar.background_tasks import BackgroundTasks
37
37
  from litestar.di import Provide
38
- from litestar.exceptions import NotFoundException, ValidationException
38
+ from litestar.exceptions import (
39
+ NotFoundException,
40
+ SerializationException,
41
+ ValidationException,
42
+ WebSocketDisconnect,
43
+ )
39
44
  from litestar.response.streaming import Stream
40
45
  from litestar.status_codes import HTTP_200_OK
41
46
  from strawberry.exceptions import InvalidCustomContext
42
- from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter
43
- from strawberry.http.exceptions import HTTPException
47
+ from strawberry.http.async_base_view import (
48
+ AsyncBaseHTTPView,
49
+ AsyncHTTPRequestAdapter,
50
+ AsyncWebSocketAdapter,
51
+ )
52
+ from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived
44
53
  from strawberry.http.types import FormData, HTTPMethod, QueryParams
45
54
  from strawberry.http.typevars import Context, RootValue
46
55
  from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
47
- from strawberry.subscriptions.protocols.graphql_transport_ws import (
48
- WS_4406_PROTOCOL_NOT_ACCEPTABLE,
49
- )
50
-
51
- from .handlers.graphql_transport_ws_handler import (
52
- GraphQLTransportWSHandler as BaseGraphQLTransportWSHandler,
53
- )
54
- from .handlers.graphql_ws_handler import GraphQLWSHandler as BaseGraphQLWSHandler
55
56
 
56
57
  if TYPE_CHECKING:
57
58
  from collections.abc import Mapping
@@ -152,22 +153,6 @@ class GraphQLResource(Struct):
152
153
  extensions: Optional[dict[str, object]]
153
154
 
154
155
 
155
- class GraphQLWSHandler(BaseGraphQLWSHandler):
156
- async def get_context(self) -> Any:
157
- return await self._get_context()
158
-
159
- async def get_root_value(self) -> Any:
160
- return await self._get_root_value()
161
-
162
-
163
- class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler):
164
- async def get_context(self) -> Any:
165
- return await self._get_context()
166
-
167
- async def get_root_value(self) -> Any:
168
- return await self._get_root_value()
169
-
170
-
171
156
  class LitestarRequestAdapter(AsyncHTTPRequestAdapter):
172
157
  def __init__(self, request: Request[Any, Any, Any]) -> None:
173
158
  self.request = request
@@ -203,10 +188,37 @@ class LitestarRequestAdapter(AsyncHTTPRequestAdapter):
203
188
  return FormData(form=multipart_data, files=multipart_data)
204
189
 
205
190
 
191
+ class LitestarWebSocketAdapter(AsyncWebSocketAdapter):
192
+ def __init__(self, request: WebSocket, response: WebSocket) -> None:
193
+ self.ws = response
194
+
195
+ async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
196
+ try:
197
+ try:
198
+ while self.ws.connection_state != "disconnect":
199
+ yield await self.ws.receive_json()
200
+ except (SerializationException, ValueError):
201
+ raise NonJsonMessageReceived()
202
+ except WebSocketDisconnect:
203
+ pass
204
+
205
+ async def send_json(self, message: Mapping[str, object]) -> None:
206
+ await self.ws.send_json(message)
207
+
208
+ async def close(self, code: int, reason: str) -> None:
209
+ await self.ws.close(code=code, reason=reason)
210
+
211
+
206
212
  class GraphQLController(
207
213
  Controller,
208
214
  AsyncBaseHTTPView[
209
- Request[Any, Any, Any], Response[Any], Response[Any], Context, RootValue
215
+ Request[Any, Any, Any],
216
+ Response[Any],
217
+ Response[Any],
218
+ WebSocket,
219
+ WebSocket,
220
+ Context,
221
+ RootValue,
210
222
  ],
211
223
  ):
212
224
  path: str = ""
@@ -219,10 +231,7 @@ class GraphQLController(
219
231
  }
220
232
 
221
233
  request_adapter_class = LitestarRequestAdapter
222
- graphql_ws_handler_class: Type[GraphQLWSHandler] = GraphQLWSHandler
223
- graphql_transport_ws_handler_class: Type[GraphQLTransportWSHandler] = (
224
- GraphQLTransportWSHandler
225
- )
234
+ websocket_adapter_class = LitestarWebSocketAdapter
226
235
 
227
236
  allow_queries_via_get: bool = True
228
237
  graphiql_allowed_accept: FrozenSet[str] = frozenset({"text/html", "*/*"})
@@ -236,6 +245,23 @@ class GraphQLController(
236
245
  keep_alive: bool = False
237
246
  keep_alive_interval: float = 1
238
247
 
248
+ def is_websocket_request(
249
+ self, request: Union[Request, WebSocket]
250
+ ) -> TypeGuard[WebSocket]:
251
+ return isinstance(request, WebSocket)
252
+
253
+ async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]:
254
+ subprotocols = request.scope["subprotocols"]
255
+ intersection = set(subprotocols) & set(self.protocols)
256
+ sorted_intersection = sorted(intersection, key=subprotocols.index)
257
+ return next(iter(sorted_intersection), None)
258
+
259
+ async def create_websocket_response(
260
+ self, request: WebSocket, subprotocol: Optional[str]
261
+ ) -> WebSocket:
262
+ await request.accept(subprotocols=subprotocol)
263
+ return request
264
+
239
265
  async def execute_request(
240
266
  self,
241
267
  request: Request[Any, Any, Any],
@@ -245,8 +271,6 @@ class GraphQLController(
245
271
  try:
246
272
  return await self.run(
247
273
  request,
248
- # TODO: check the dependency, above, can we make it so that
249
- # we don't need to type ignore here?
250
274
  context=context,
251
275
  root_value=root_value,
252
276
  )
@@ -328,14 +352,29 @@ class GraphQLController(
328
352
  root_value=root_value,
329
353
  )
330
354
 
355
+ @websocket()
356
+ async def websocket_endpoint(
357
+ self,
358
+ socket: WebSocket,
359
+ context_ws: Any,
360
+ root_value: Any,
361
+ ) -> None:
362
+ await self.run(
363
+ request=socket,
364
+ context=context_ws,
365
+ root_value=root_value,
366
+ )
367
+
331
368
  async def get_context(
332
- self, request: Request[Any, Any, Any], response: Response
369
+ self,
370
+ request: Union[Request[Any, Any, Any], WebSocket],
371
+ response: Union[Response, WebSocket],
333
372
  ) -> Context: # pragma: no cover
334
373
  msg = "`get_context` is not used by Litestar's controller"
335
374
  raise ValueError(msg)
336
375
 
337
376
  async def get_root_value(
338
- self, request: Request[Any, Any, Any]
377
+ self, request: Union[Request[Any, Any, Any], WebSocket]
339
378
  ) -> RootValue | None: # pragma: no cover
340
379
  msg = "`get_root_value` is not used by Litestar's controller"
341
380
  raise ValueError(msg)
@@ -343,54 +382,6 @@ class GraphQLController(
343
382
  async def get_sub_response(self, request: Request[Any, Any, Any]) -> Response:
344
383
  return self.temporal_response
345
384
 
346
- @websocket()
347
- async def websocket_endpoint(
348
- self,
349
- socket: WebSocket,
350
- context_ws: Any,
351
- root_value: Any,
352
- ) -> None:
353
- async def _get_context() -> Any:
354
- return context_ws
355
-
356
- async def _get_root_value() -> Any:
357
- return root_value
358
-
359
- preferred_protocol = self.pick_preferred_protocol(socket)
360
- if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
361
- await self.graphql_transport_ws_handler_class(
362
- schema=self.schema,
363
- debug=self.debug,
364
- connection_init_wait_timeout=self.connection_init_wait_timeout,
365
- get_context=_get_context,
366
- get_root_value=_get_root_value,
367
- ws=socket,
368
- ).handle()
369
- elif preferred_protocol == GRAPHQL_WS_PROTOCOL:
370
- await self.graphql_ws_handler_class(
371
- schema=self.schema,
372
- debug=self.debug,
373
- keep_alive=self.keep_alive,
374
- keep_alive_interval=self.keep_alive_interval,
375
- get_context=_get_context,
376
- get_root_value=_get_root_value,
377
- ws=socket,
378
- ).handle()
379
- else:
380
- await socket.close(code=WS_4406_PROTOCOL_NOT_ACCEPTABLE)
381
-
382
- def pick_preferred_protocol(self, socket: WebSocket) -> str | None:
383
- protocols: List[str] = socket.scope["subprotocols"]
384
- intersection: Set[str] = set(protocols) & set(self.protocols)
385
- return (
386
- min(
387
- intersection,
388
- key=lambda i: protocols.index(i) if i else "",
389
- default=None,
390
- )
391
- or None
392
- )
393
-
394
385
 
395
386
  def make_graphql_controller(
396
387
  schema: BaseSchema,
strawberry/quart/views.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import warnings
2
2
  from collections.abc import Mapping
3
3
  from typing import TYPE_CHECKING, AsyncGenerator, Callable, Dict, Optional, cast
4
+ from typing_extensions import TypeGuard
4
5
 
5
6
  from quart import Request, Response, request
6
7
  from quart.views import View
@@ -46,7 +47,9 @@ class QuartHTTPRequestAdapter(AsyncHTTPRequestAdapter):
46
47
 
47
48
 
48
49
  class GraphQLView(
49
- AsyncBaseHTTPView[Request, Response, Response, Context, RootValue],
50
+ AsyncBaseHTTPView[
51
+ Request, Response, Response, Request, Response, Context, RootValue
52
+ ],
50
53
  View,
51
54
  ):
52
55
  _ide_subscription_enabled = False
@@ -121,5 +124,16 @@ class GraphQLView(
121
124
  },
122
125
  )
123
126
 
127
+ def is_websocket_request(self, request: Request) -> TypeGuard[Request]:
128
+ return False
129
+
130
+ async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]:
131
+ raise NotImplementedError
132
+
133
+ async def create_websocket_response(
134
+ self, request: Request, subprotocol: Optional[str]
135
+ ) -> Response:
136
+ raise NotImplementedError
137
+
124
138
 
125
139
  __all__ = ["GraphQLView"]
strawberry/sanic/views.py CHANGED
@@ -13,6 +13,7 @@ from typing import (
13
13
  Type,
14
14
  cast,
15
15
  )
16
+ from typing_extensions import TypeGuard
16
17
 
17
18
  from sanic.request import Request
18
19
  from sanic.response import HTTPResponse, html
@@ -71,7 +72,15 @@ class SanicHTTPRequestAdapter(AsyncHTTPRequestAdapter):
71
72
 
72
73
 
73
74
  class GraphQLView(
74
- AsyncBaseHTTPView[Request, HTTPResponse, TemporalResponse, Context, RootValue],
75
+ AsyncBaseHTTPView[
76
+ Request,
77
+ HTTPResponse,
78
+ TemporalResponse,
79
+ Request,
80
+ TemporalResponse,
81
+ Context,
82
+ RootValue,
83
+ ],
75
84
  HTTPMethodView,
76
85
  ):
77
86
  """Class based view to handle GraphQL HTTP Requests.
@@ -206,5 +215,16 @@ class GraphQLView(
206
215
  # corner case
207
216
  return None # type: ignore
208
217
 
218
+ def is_websocket_request(self, request: Request) -> TypeGuard[Request]:
219
+ return False
220
+
221
+ async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]:
222
+ raise NotImplementedError
223
+
224
+ async def create_websocket_response(
225
+ self, request: Request, subprotocol: Optional[str]
226
+ ) -> TemporalResponse:
227
+ raise NotImplementedError
228
+
209
229
 
210
230
  __all__ = ["GraphQLView"]
strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
2
2
 
3
3
  import asyncio
4
4
  import logging
5
- from abc import ABC, abstractmethod
6
5
  from contextlib import suppress
7
6
  from typing import (
8
7
  TYPE_CHECKING,
@@ -16,6 +15,7 @@ from typing import (
16
15
 
17
16
  from graphql import GraphQLError, GraphQLSyntaxError, parse
18
17
 
18
+ from strawberry.http.exceptions import NonJsonMessageReceived
19
19
  from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
20
20
  CompleteMessage,
21
21
  ConnectionAckMessage,
@@ -38,6 +38,7 @@ from strawberry.utils.operation import get_operation_type
38
38
  if TYPE_CHECKING:
39
39
  from datetime import timedelta
40
40
 
41
+ from strawberry.http.async_base_view import AsyncWebSocketAdapter
41
42
  from strawberry.schema import BaseSchema
42
43
  from strawberry.schema.subscribe import SubscriptionResult
43
44
  from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
@@ -45,15 +46,21 @@ if TYPE_CHECKING:
45
46
  )
46
47
 
47
48
 
48
- class BaseGraphQLTransportWSHandler(ABC):
49
+ class BaseGraphQLTransportWSHandler:
49
50
  task_logger: logging.Logger = logging.getLogger("strawberry.ws.task")
50
51
 
51
52
  def __init__(
52
53
  self,
54
+ websocket: AsyncWebSocketAdapter,
55
+ context: object,
56
+ root_value: object,
53
57
  schema: BaseSchema,
54
58
  debug: bool,
55
59
  connection_init_wait_timeout: timedelta,
56
60
  ) -> None:
61
+ self.websocket = websocket
62
+ self.context = context
63
+ self.root_value = root_value
57
64
  self.schema = schema
58
65
  self.debug = debug
59
66
  self.connection_init_wait_timeout = connection_init_wait_timeout
@@ -65,28 +72,16 @@ class BaseGraphQLTransportWSHandler(ABC):
65
72
  self.completed_tasks: List[asyncio.Task] = []
66
73
  self.connection_params: Optional[Dict[str, Any]] = None
67
74
 
68
- @abstractmethod
69
- async def get_context(self) -> Any:
70
- """Return the operations context."""
71
-
72
- @abstractmethod
73
- async def get_root_value(self) -> Any:
74
- """Return the schemas root value."""
75
-
76
- @abstractmethod
77
- async def send_json(self, data: dict) -> None:
78
- """Send the data JSON encoded to the WebSocket client."""
79
-
80
- @abstractmethod
81
- async def close(self, code: int, reason: str) -> None:
82
- """Close the WebSocket with the passed code and reason."""
83
-
84
- @abstractmethod
85
- async def handle_request(self) -> Any:
86
- """Handle the request this instance was created for."""
87
-
88
75
  async def handle(self) -> Any:
89
- return await self.handle_request()
76
+ self.on_request_accepted()
77
+
78
+ try:
79
+ async for message in self.websocket.iter_json():
80
+ await self.handle_message(message)
81
+ except NonJsonMessageReceived:
82
+ await self.handle_invalid_message("WebSocket message type must be text")
83
+ finally:
84
+ await self.shutdown()
90
85
 
91
86
  async def shutdown(self) -> None:
92
87
  if self.connection_init_timeout_task:
@@ -118,7 +113,7 @@ class BaseGraphQLTransportWSHandler(ABC):
118
113
 
119
114
  self.connection_timed_out = True
120
115
  reason = "Connection initialisation timeout"
121
- await self.close(code=4408, reason=reason)
116
+ await self.websocket.close(code=4408, reason=reason)
122
117
  except Exception as error:
123
118
  await self.handle_task_exception(error) # pragma: no cover
124
119
  finally:
@@ -189,14 +184,16 @@ class BaseGraphQLTransportWSHandler(ABC):
189
184
  )
190
185
 
191
186
  if not isinstance(payload, dict):
192
- await self.close(code=4400, reason="Invalid connection init payload")
187
+ await self.websocket.close(
188
+ code=4400, reason="Invalid connection init payload"
189
+ )
193
190
  return
194
191
 
195
192
  self.connection_params = payload
196
193
 
197
194
  if self.connection_init_received:
198
195
  reason = "Too many initialisation requests"
199
- await self.close(code=4429, reason=reason)
196
+ await self.websocket.close(code=4429, reason=reason)
200
197
  return
201
198
 
202
199
  self.connection_init_received = True
@@ -211,13 +208,13 @@ class BaseGraphQLTransportWSHandler(ABC):
211
208
 
212
209
  async def handle_subscribe(self, message: SubscribeMessage) -> None:
213
210
  if not self.connection_acknowledged:
214
- await self.close(code=4401, reason="Unauthorized")
211
+ await self.websocket.close(code=4401, reason="Unauthorized")
215
212
  return
216
213
 
217
214
  try:
218
215
  graphql_document = parse(message.payload.query)
219
216
  except GraphQLSyntaxError as exc:
220
- await self.close(code=4400, reason=exc.message)
217
+ await self.websocket.close(code=4400, reason=exc.message)
221
218
  return
222
219
 
223
220
  try:
@@ -225,12 +222,14 @@ class BaseGraphQLTransportWSHandler(ABC):
225
222
  graphql_document, message.payload.operationName
226
223
  )
227
224
  except RuntimeError:
228
- await self.close(code=4400, reason="Can't get GraphQL operation type")
225
+ await self.websocket.close(
226
+ code=4400, reason="Can't get GraphQL operation type"
227
+ )
229
228
  return
230
229
 
231
230
  if message.id in self.operations:
232
231
  reason = f"Subscriber for {message.id} already exists"
233
- await self.close(code=4409, reason=reason)
232
+ await self.websocket.close(code=4409, reason=reason)
234
233
  return
235
234
 
236
235
  if self.debug: # pragma: no cover
@@ -240,26 +239,28 @@ class BaseGraphQLTransportWSHandler(ABC):
240
239
  message.payload.variables,
241
240
  )
242
241
 
243
- context = await self.get_context()
244
- if isinstance(context, dict):
245
- context["connection_params"] = self.connection_params
246
- root_value = await self.get_root_value()
242
+ if isinstance(self.context, dict):
243
+ self.context["connection_params"] = self.connection_params
244
+ elif hasattr(self.context, "connection_params"):
245
+ self.context.connection_params = self.connection_params
246
+
247
247
  result_source: Awaitable[ExecutionResult] | Awaitable[SubscriptionResult]
248
+
248
249
  # Get an AsyncGenerator yielding the results
249
250
  if operation_type == OperationType.SUBSCRIPTION:
250
251
  result_source = self.schema.subscribe(
251
252
  query=message.payload.query,
252
253
  variable_values=message.payload.variables,
253
254
  operation_name=message.payload.operationName,
254
- context_value=context,
255
- root_value=root_value,
255
+ context_value=self.context,
256
+ root_value=self.root_value,
256
257
  )
257
258
  else:
258
259
  result_source = self.schema.execute(
259
260
  query=message.payload.query,
260
261
  variable_values=message.payload.variables,
261
- context_value=context,
262
- root_value=root_value,
262
+ context_value=self.context,
263
+ root_value=self.root_value,
263
264
  operation_name=message.payload.operationName,
264
265
  )
265
266
 
@@ -312,11 +313,11 @@ class BaseGraphQLTransportWSHandler(ABC):
312
313
  await self.cleanup_operation(operation_id=message.id)
313
314
 
314
315
  async def handle_invalid_message(self, error_message: str) -> None:
315
- await self.close(code=4400, reason=error_message)
316
+ await self.websocket.close(code=4400, reason=error_message)
316
317
 
317
318
  async def send_message(self, message: GraphQLTransportMessage) -> None:
318
319
  data = message.as_dict()
319
- await self.send_json(data)
320
+ await self.websocket.send_json(data)
320
321
 
321
322
  async def cleanup_operation(self, operation_id: str) -> None:
322
323
  if operation_id not in self.operations: