strawberry-graphql 0.243.1__py3-none-any.whl → 0.244.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. strawberry/aiohttp/views.py +58 -44
  2. strawberry/asgi/__init__.py +62 -61
  3. strawberry/channels/__init__.py +1 -6
  4. strawberry/channels/handlers/base.py +2 -2
  5. strawberry/channels/handlers/http_handler.py +18 -1
  6. strawberry/channels/handlers/ws_handler.py +113 -58
  7. strawberry/django/views.py +19 -1
  8. strawberry/fastapi/router.py +27 -45
  9. strawberry/flask/views.py +15 -1
  10. strawberry/http/async_base_view.py +136 -16
  11. strawberry/http/exceptions.py +4 -0
  12. strawberry/http/typevars.py +11 -1
  13. strawberry/litestar/controller.py +77 -86
  14. strawberry/quart/views.py +15 -1
  15. strawberry/sanic/views.py +21 -1
  16. strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +41 -40
  17. strawberry/subscriptions/protocols/graphql_ws/handlers.py +46 -45
  18. {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.0.dist-info}/METADATA +1 -1
  19. {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.0.dist-info}/RECORD +22 -36
  20. strawberry/aiohttp/handlers/__init__.py +0 -6
  21. strawberry/aiohttp/handlers/graphql_transport_ws_handler.py +0 -62
  22. strawberry/aiohttp/handlers/graphql_ws_handler.py +0 -69
  23. strawberry/asgi/handlers/__init__.py +0 -6
  24. strawberry/asgi/handlers/graphql_transport_ws_handler.py +0 -66
  25. strawberry/asgi/handlers/graphql_ws_handler.py +0 -71
  26. strawberry/channels/handlers/graphql_transport_ws_handler.py +0 -62
  27. strawberry/channels/handlers/graphql_ws_handler.py +0 -72
  28. strawberry/fastapi/handlers/__init__.py +0 -6
  29. strawberry/fastapi/handlers/graphql_transport_ws_handler.py +0 -20
  30. strawberry/fastapi/handlers/graphql_ws_handler.py +0 -18
  31. strawberry/litestar/handlers/__init__.py +0 -0
  32. strawberry/litestar/handlers/graphql_transport_ws_handler.py +0 -60
  33. strawberry/litestar/handlers/graphql_ws_handler.py +0 -66
  34. {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.0.dist-info}/LICENSE +0 -0
  35. {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.0.dist-info}/WHEEL +0 -0
  36. {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.0.dist-info}/entry_points.txt +0 -0
@@ -17,6 +17,7 @@ from typing import (
17
17
  Union,
18
18
  cast,
19
19
  )
20
+ from typing_extensions import TypeGuard
20
21
 
21
22
  from starlette import status
22
23
  from starlette.background import BackgroundTasks # noqa: TCH002
@@ -34,10 +35,9 @@ from fastapi import APIRouter, Depends, params
34
35
  from fastapi.datastructures import Default
35
36
  from fastapi.routing import APIRoute
36
37
  from fastapi.utils import generate_unique_id
37
- from strawberry.asgi import ASGIRequestAdapter
38
+ from strawberry.asgi import ASGIRequestAdapter, ASGIWebSocketAdapter
38
39
  from strawberry.exceptions import InvalidCustomContext
39
40
  from strawberry.fastapi.context import BaseContext, CustomContext
40
- from strawberry.fastapi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler
41
41
  from strawberry.http import process_result
42
42
  from strawberry.http.async_base_view import AsyncBaseHTTPView
43
43
  from strawberry.http.exceptions import HTTPException
@@ -58,12 +58,14 @@ if TYPE_CHECKING:
58
58
 
59
59
 
60
60
  class GraphQLRouter(
61
- AsyncBaseHTTPView[Request, Response, Response, Context, RootValue], APIRouter
61
+ AsyncBaseHTTPView[
62
+ Request, Response, Response, WebSocket, WebSocket, Context, RootValue
63
+ ],
64
+ APIRouter,
62
65
  ):
63
- graphql_ws_handler_class = GraphQLWSHandler
64
- graphql_transport_ws_handler_class = GraphQLTransportWSHandler
65
66
  allow_queries_via_get = True
66
67
  request_adapter_class = ASGIRequestAdapter
68
+ websocket_adapter_class = ASGIWebSocketAdapter
67
69
 
68
70
  @staticmethod
69
71
  async def __get_root_value() -> None:
@@ -261,44 +263,7 @@ class GraphQLRouter(
261
263
  context: Context = Depends(self.context_getter),
262
264
  root_value: RootValue = Depends(self.root_value_getter),
263
265
  ) -> None:
264
- async def _get_context() -> Context:
265
- return context
266
-
267
- async def _get_root_value() -> RootValue:
268
- return root_value
269
-
270
- preferred_protocol = self.pick_preferred_protocol(websocket)
271
- if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
272
- await self.graphql_transport_ws_handler_class(
273
- schema=self.schema,
274
- debug=self.debug,
275
- connection_init_wait_timeout=self.connection_init_wait_timeout,
276
- get_context=_get_context,
277
- get_root_value=_get_root_value,
278
- ws=websocket,
279
- ).handle()
280
- elif preferred_protocol == GRAPHQL_WS_PROTOCOL:
281
- await self.graphql_ws_handler_class(
282
- schema=self.schema,
283
- debug=self.debug,
284
- keep_alive=self.keep_alive,
285
- keep_alive_interval=self.keep_alive_interval,
286
- get_context=_get_context,
287
- get_root_value=_get_root_value,
288
- ws=websocket,
289
- ).handle()
290
- else:
291
- # Code 4406 is "Subprotocol not acceptable"
292
- await websocket.close(code=4406)
293
-
294
- def pick_preferred_protocol(self, ws: WebSocket) -> Optional[str]:
295
- protocols = ws["subprotocols"]
296
- intersection = set(protocols) & set(self.protocols)
297
- return min(
298
- intersection,
299
- key=lambda i: protocols.index(i),
300
- default=None,
301
- )
266
+ await self.run(request=websocket, context=context, root_value=root_value)
302
267
 
303
268
  async def render_graphql_ide(self, request: Request) -> HTMLResponse:
304
269
  return HTMLResponse(self.graphql_ide_html)
@@ -309,12 +274,12 @@ class GraphQLRouter(
309
274
  return process_result(result)
310
275
 
311
276
  async def get_context(
312
- self, request: Request, response: Response
277
+ self, request: Union[Request, WebSocket], response: Union[Response, WebSocket]
313
278
  ) -> Context: # pragma: no cover
314
279
  raise ValueError("`get_context` is not used by FastAPI GraphQL Router")
315
280
 
316
281
  async def get_root_value(
317
- self, request: Request
282
+ self, request: Union[Request, WebSocket]
318
283
  ) -> Optional[RootValue]: # pragma: no cover
319
284
  raise ValueError("`get_root_value` is not used by FastAPI GraphQL Router")
320
285
 
@@ -350,5 +315,22 @@ class GraphQLRouter(
350
315
  },
351
316
  )
352
317
 
318
+ def is_websocket_request(
319
+ self, request: Union[Request, WebSocket]
320
+ ) -> TypeGuard[WebSocket]:
321
+ return request.scope["type"] == "websocket"
322
+
323
+ async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]:
324
+ protocols = request["subprotocols"]
325
+ intersection = set(protocols) & set(self.protocols)
326
+ sorted_intersection = sorted(intersection, key=protocols.index)
327
+ return next(iter(sorted_intersection), None)
328
+
329
+ async def create_websocket_response(
330
+ self, request: WebSocket, subprotocol: Optional[str]
331
+ ) -> WebSocket:
332
+ await request.accept(subprotocol=subprotocol)
333
+ return request
334
+
353
335
 
354
336
  __all__ = ["GraphQLRouter"]
strawberry/flask/views.py CHANGED
@@ -9,6 +9,7 @@ from typing import (
9
9
  Union,
10
10
  cast,
11
11
  )
12
+ from typing_extensions import TypeGuard
12
13
 
13
14
  from flask import Request, Response, render_template_string, request
14
15
  from flask.views import View
@@ -159,7 +160,9 @@ class AsyncFlaskHTTPRequestAdapter(AsyncHTTPRequestAdapter):
159
160
 
160
161
  class AsyncGraphQLView(
161
162
  BaseGraphQLView,
162
- AsyncBaseHTTPView[Request, Response, Response, Context, RootValue],
163
+ AsyncBaseHTTPView[
164
+ Request, Response, Response, Request, Response, Context, RootValue
165
+ ],
163
166
  View,
164
167
  ):
165
168
  methods = ["GET", "POST"]
@@ -187,6 +190,17 @@ class AsyncGraphQLView(
187
190
  async def render_graphql_ide(self, request: Request) -> Response:
188
191
  return render_template_string(self.graphql_ide_html) # type: ignore
189
192
 
193
+ def is_websocket_request(self, request: Request) -> TypeGuard[Request]:
194
+ return False
195
+
196
+ async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]:
197
+ raise NotImplementedError
198
+
199
+ async def create_websocket_response(
200
+ self, request: Request, subprotocol: Optional[str]
201
+ ) -> Response:
202
+ raise NotImplementedError
203
+
190
204
 
191
205
  __all__ = [
192
206
  "GraphQLView",
@@ -2,6 +2,7 @@ import abc
2
2
  import asyncio
3
3
  import contextlib
4
4
  import json
5
+ from datetime import timedelta
5
6
  from typing import (
6
7
  Any,
7
8
  AsyncGenerator,
@@ -13,8 +14,10 @@ from typing import (
13
14
  Optional,
14
15
  Tuple,
15
16
  Union,
17
+ cast,
18
+ overload,
16
19
  )
17
- from typing_extensions import Literal
20
+ from typing_extensions import Literal, TypeGuard
18
21
 
19
22
  from graphql import GraphQLError
20
23
 
@@ -29,6 +32,11 @@ from strawberry.http import (
29
32
  from strawberry.http.ides import GraphQL_IDE
30
33
  from strawberry.schema.base import BaseSchema
31
34
  from strawberry.schema.exceptions import InvalidOperationTypeError
35
+ from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
36
+ from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import (
37
+ BaseGraphQLTransportWSHandler,
38
+ )
39
+ from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler
32
40
  from strawberry.types import ExecutionResult, SubscriptionExecutionResult
33
41
  from strawberry.types.graphql import OperationType
34
42
 
@@ -36,7 +44,15 @@ from .base import BaseView
36
44
  from .exceptions import HTTPException
37
45
  from .parse_content_type import parse_content_type
38
46
  from .types import FormData, HTTPMethod, QueryParams
39
- from .typevars import Context, Request, Response, RootValue, SubResponse
47
+ from .typevars import (
48
+ Context,
49
+ Request,
50
+ Response,
51
+ RootValue,
52
+ SubResponse,
53
+ WebSocketRequest,
54
+ WebSocketResponse,
55
+ )
40
56
 
41
57
 
42
58
  class AsyncHTTPRequestAdapter(abc.ABC):
@@ -63,14 +79,42 @@ class AsyncHTTPRequestAdapter(abc.ABC):
63
79
  async def get_form_data(self) -> FormData: ...
64
80
 
65
81
 
82
+ class AsyncWebSocketAdapter(abc.ABC):
83
+ @abc.abstractmethod
84
+ def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: ...
85
+
86
+ @abc.abstractmethod
87
+ async def send_json(self, message: Mapping[str, object]) -> None: ...
88
+
89
+ @abc.abstractmethod
90
+ async def close(self, code: int, reason: str) -> None: ...
91
+
92
+
66
93
  class AsyncBaseHTTPView(
67
94
  abc.ABC,
68
95
  BaseView[Request],
69
- Generic[Request, Response, SubResponse, Context, RootValue],
96
+ Generic[
97
+ Request,
98
+ Response,
99
+ SubResponse,
100
+ WebSocketRequest,
101
+ WebSocketResponse,
102
+ Context,
103
+ RootValue,
104
+ ],
70
105
  ):
71
106
  schema: BaseSchema
72
107
  graphql_ide: Optional[GraphQL_IDE]
108
+ debug: bool
109
+ keep_alive = False
110
+ keep_alive_interval: Optional[float] = None
111
+ connection_init_wait_timeout: timedelta = timedelta(minutes=1)
73
112
  request_adapter_class: Callable[[Request], AsyncHTTPRequestAdapter]
113
+ websocket_adapter_class: Callable[
114
+ [WebSocketRequest, WebSocketResponse], AsyncWebSocketAdapter
115
+ ]
116
+ graphql_transport_ws_handler_class = BaseGraphQLTransportWSHandler
117
+ graphql_ws_handler_class = BaseGraphQLWSHandler
74
118
 
75
119
  @property
76
120
  @abc.abstractmethod
@@ -80,10 +124,16 @@ class AsyncBaseHTTPView(
80
124
  async def get_sub_response(self, request: Request) -> SubResponse: ...
81
125
 
82
126
  @abc.abstractmethod
83
- async def get_context(self, request: Request, response: SubResponse) -> Context: ...
127
+ async def get_context(
128
+ self,
129
+ request: Union[Request, WebSocketRequest],
130
+ response: Union[SubResponse, WebSocketResponse],
131
+ ) -> Context: ...
84
132
 
85
133
  @abc.abstractmethod
86
- async def get_root_value(self, request: Request) -> Optional[RootValue]: ...
134
+ async def get_root_value(
135
+ self, request: Union[Request, WebSocketRequest]
136
+ ) -> Optional[RootValue]: ...
87
137
 
88
138
  @abc.abstractmethod
89
139
  def create_response(
@@ -102,6 +152,21 @@ class AsyncBaseHTTPView(
102
152
  ) -> Response:
103
153
  raise ValueError("Multipart responses are not supported")
104
154
 
155
+ @abc.abstractmethod
156
+ def is_websocket_request(
157
+ self, request: Union[Request, WebSocketRequest]
158
+ ) -> TypeGuard[WebSocketRequest]: ...
159
+
160
+ @abc.abstractmethod
161
+ async def pick_websocket_subprotocol(
162
+ self, request: WebSocketRequest
163
+ ) -> Optional[str]: ...
164
+
165
+ @abc.abstractmethod
166
+ async def create_websocket_response(
167
+ self, request: WebSocketRequest, subprotocol: Optional[str]
168
+ ) -> WebSocketResponse: ...
169
+
105
170
  async def execute_operation(
106
171
  self, request: Request, context: Context, root_value: Optional[RootValue]
107
172
  ) -> Union[ExecutionResult, SubscriptionExecutionResult]:
@@ -167,35 +232,90 @@ class AsyncBaseHTTPView(
167
232
  ) -> None:
168
233
  """Hook to allow custom handling of errors, used by the Sentry Integration."""
169
234
 
235
+ @overload
170
236
  async def run(
171
237
  self,
172
238
  request: Request,
173
239
  context: Optional[Context] = UNSET,
174
240
  root_value: Optional[RootValue] = UNSET,
175
- ) -> Response:
176
- request_adapter = self.request_adapter_class(request)
241
+ ) -> Response: ...
177
242
 
178
- if not self.is_request_allowed(request_adapter):
179
- raise HTTPException(405, "GraphQL only supports GET and POST requests.")
243
+ @overload
244
+ async def run(
245
+ self,
246
+ request: WebSocketRequest,
247
+ context: Optional[Context] = UNSET,
248
+ root_value: Optional[RootValue] = UNSET,
249
+ ) -> WebSocketResponse: ...
180
250
 
181
- if self.should_render_graphql_ide(request_adapter):
182
- if self.graphql_ide:
183
- return await self.render_graphql_ide(request)
251
+ async def run(
252
+ self,
253
+ request: Union[Request, WebSocketRequest],
254
+ context: Optional[Context] = UNSET,
255
+ root_value: Optional[RootValue] = UNSET,
256
+ ) -> Union[Response, WebSocketResponse]:
257
+ root_value = (
258
+ await self.get_root_value(request) if root_value is UNSET else root_value
259
+ )
260
+
261
+ if self.is_websocket_request(request):
262
+ websocket_subprotocol = await self.pick_websocket_subprotocol(request)
263
+ websocket_response = await self.create_websocket_response(
264
+ request, websocket_subprotocol
265
+ )
266
+ websocket = self.websocket_adapter_class(request, websocket_response)
267
+
268
+ context = (
269
+ await self.get_context(request, response=websocket_response)
270
+ if context is UNSET
271
+ else context
272
+ )
273
+
274
+ if websocket_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
275
+ await self.graphql_transport_ws_handler_class(
276
+ websocket=websocket,
277
+ context=context,
278
+ root_value=root_value,
279
+ schema=self.schema,
280
+ debug=self.debug,
281
+ connection_init_wait_timeout=self.connection_init_wait_timeout,
282
+ ).handle()
283
+ elif websocket_subprotocol == GRAPHQL_WS_PROTOCOL:
284
+ await self.graphql_ws_handler_class(
285
+ websocket=websocket,
286
+ context=context,
287
+ root_value=root_value,
288
+ schema=self.schema,
289
+ debug=self.debug,
290
+ keep_alive=self.keep_alive,
291
+ keep_alive_interval=self.keep_alive_interval,
292
+ ).handle()
184
293
  else:
185
- raise HTTPException(404, "Not Found")
294
+ await websocket.close(4406, "Subprotocol not acceptable")
295
+
296
+ return websocket_response
297
+ else:
298
+ request = cast(Request, request)
186
299
 
300
+ request_adapter = self.request_adapter_class(request)
187
301
  sub_response = await self.get_sub_response(request)
188
302
  context = (
189
303
  await self.get_context(request, response=sub_response)
190
304
  if context is UNSET
191
305
  else context
192
306
  )
193
- root_value = (
194
- await self.get_root_value(request) if root_value is UNSET else root_value
195
- )
196
307
 
197
308
  assert context
198
309
 
310
+ if not self.is_request_allowed(request_adapter):
311
+ raise HTTPException(405, "GraphQL only supports GET and POST requests.")
312
+
313
+ if self.should_render_graphql_ide(request_adapter):
314
+ if self.graphql_ide:
315
+ return await self.render_graphql_ide(request)
316
+ else:
317
+ raise HTTPException(404, "Not Found")
318
+
199
319
  try:
200
320
  result = await self.execute_operation(
201
321
  request=request, context=context, root_value=root_value
@@ -4,4 +4,8 @@ class HTTPException(Exception):
4
4
  self.reason = reason
5
5
 
6
6
 
7
+ class NonJsonMessageReceived(Exception):
8
+ pass
9
+
10
+
7
11
  __all__ = ["HTTPException"]
@@ -3,8 +3,18 @@ from typing import TypeVar
3
3
  Request = TypeVar("Request", contravariant=True)
4
4
  Response = TypeVar("Response")
5
5
  SubResponse = TypeVar("SubResponse")
6
+ WebSocketRequest = TypeVar("WebSocketRequest")
7
+ WebSocketResponse = TypeVar("WebSocketResponse")
6
8
  Context = TypeVar("Context")
7
9
  RootValue = TypeVar("RootValue")
8
10
 
9
11
 
10
- __all__ = ["Request", "Response", "SubResponse", "Context", "RootValue"]
12
+ __all__ = [
13
+ "Request",
14
+ "Response",
15
+ "SubResponse",
16
+ "WebSocketRequest",
17
+ "WebSocketResponse",
18
+ "Context",
19
+ "RootValue",
20
+ ]
@@ -7,19 +7,19 @@ from datetime import timedelta
7
7
  from typing import (
8
8
  TYPE_CHECKING,
9
9
  Any,
10
+ AsyncGenerator,
10
11
  AsyncIterator,
11
12
  Callable,
12
13
  Dict,
13
14
  FrozenSet,
14
- List,
15
15
  Optional,
16
- Set,
17
16
  Tuple,
18
17
  Type,
19
18
  TypedDict,
20
19
  Union,
21
20
  cast,
22
21
  )
22
+ from typing_extensions import TypeGuard
23
23
 
24
24
  from msgspec import Struct
25
25
 
@@ -35,23 +35,24 @@ from litestar import (
35
35
  )
36
36
  from litestar.background_tasks import BackgroundTasks
37
37
  from litestar.di import Provide
38
- from litestar.exceptions import NotFoundException, ValidationException
38
+ from litestar.exceptions import (
39
+ NotFoundException,
40
+ SerializationException,
41
+ ValidationException,
42
+ WebSocketDisconnect,
43
+ )
39
44
  from litestar.response.streaming import Stream
40
45
  from litestar.status_codes import HTTP_200_OK
41
46
  from strawberry.exceptions import InvalidCustomContext
42
- from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter
43
- from strawberry.http.exceptions import HTTPException
47
+ from strawberry.http.async_base_view import (
48
+ AsyncBaseHTTPView,
49
+ AsyncHTTPRequestAdapter,
50
+ AsyncWebSocketAdapter,
51
+ )
52
+ from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived
44
53
  from strawberry.http.types import FormData, HTTPMethod, QueryParams
45
54
  from strawberry.http.typevars import Context, RootValue
46
55
  from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
47
- from strawberry.subscriptions.protocols.graphql_transport_ws import (
48
- WS_4406_PROTOCOL_NOT_ACCEPTABLE,
49
- )
50
-
51
- from .handlers.graphql_transport_ws_handler import (
52
- GraphQLTransportWSHandler as BaseGraphQLTransportWSHandler,
53
- )
54
- from .handlers.graphql_ws_handler import GraphQLWSHandler as BaseGraphQLWSHandler
55
56
 
56
57
  if TYPE_CHECKING:
57
58
  from collections.abc import Mapping
@@ -152,22 +153,6 @@ class GraphQLResource(Struct):
152
153
  extensions: Optional[dict[str, object]]
153
154
 
154
155
 
155
- class GraphQLWSHandler(BaseGraphQLWSHandler):
156
- async def get_context(self) -> Any:
157
- return await self._get_context()
158
-
159
- async def get_root_value(self) -> Any:
160
- return await self._get_root_value()
161
-
162
-
163
- class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler):
164
- async def get_context(self) -> Any:
165
- return await self._get_context()
166
-
167
- async def get_root_value(self) -> Any:
168
- return await self._get_root_value()
169
-
170
-
171
156
  class LitestarRequestAdapter(AsyncHTTPRequestAdapter):
172
157
  def __init__(self, request: Request[Any, Any, Any]) -> None:
173
158
  self.request = request
@@ -203,10 +188,37 @@ class LitestarRequestAdapter(AsyncHTTPRequestAdapter):
203
188
  return FormData(form=multipart_data, files=multipart_data)
204
189
 
205
190
 
191
+ class LitestarWebSocketAdapter(AsyncWebSocketAdapter):
192
+ def __init__(self, request: WebSocket, response: WebSocket) -> None:
193
+ self.ws = response
194
+
195
+ async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
196
+ try:
197
+ try:
198
+ while self.ws.connection_state != "disconnect":
199
+ yield await self.ws.receive_json()
200
+ except (SerializationException, ValueError):
201
+ raise NonJsonMessageReceived()
202
+ except WebSocketDisconnect:
203
+ pass
204
+
205
+ async def send_json(self, message: Mapping[str, object]) -> None:
206
+ await self.ws.send_json(message)
207
+
208
+ async def close(self, code: int, reason: str) -> None:
209
+ await self.ws.close(code=code, reason=reason)
210
+
211
+
206
212
  class GraphQLController(
207
213
  Controller,
208
214
  AsyncBaseHTTPView[
209
- Request[Any, Any, Any], Response[Any], Response[Any], Context, RootValue
215
+ Request[Any, Any, Any],
216
+ Response[Any],
217
+ Response[Any],
218
+ WebSocket,
219
+ WebSocket,
220
+ Context,
221
+ RootValue,
210
222
  ],
211
223
  ):
212
224
  path: str = ""
@@ -219,10 +231,7 @@ class GraphQLController(
219
231
  }
220
232
 
221
233
  request_adapter_class = LitestarRequestAdapter
222
- graphql_ws_handler_class: Type[GraphQLWSHandler] = GraphQLWSHandler
223
- graphql_transport_ws_handler_class: Type[GraphQLTransportWSHandler] = (
224
- GraphQLTransportWSHandler
225
- )
234
+ websocket_adapter_class = LitestarWebSocketAdapter
226
235
 
227
236
  allow_queries_via_get: bool = True
228
237
  graphiql_allowed_accept: FrozenSet[str] = frozenset({"text/html", "*/*"})
@@ -236,6 +245,23 @@ class GraphQLController(
236
245
  keep_alive: bool = False
237
246
  keep_alive_interval: float = 1
238
247
 
248
+ def is_websocket_request(
249
+ self, request: Union[Request, WebSocket]
250
+ ) -> TypeGuard[WebSocket]:
251
+ return isinstance(request, WebSocket)
252
+
253
+ async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]:
254
+ subprotocols = request.scope["subprotocols"]
255
+ intersection = set(subprotocols) & set(self.protocols)
256
+ sorted_intersection = sorted(intersection, key=subprotocols.index)
257
+ return next(iter(sorted_intersection), None)
258
+
259
+ async def create_websocket_response(
260
+ self, request: WebSocket, subprotocol: Optional[str]
261
+ ) -> WebSocket:
262
+ await request.accept(subprotocols=subprotocol)
263
+ return request
264
+
239
265
  async def execute_request(
240
266
  self,
241
267
  request: Request[Any, Any, Any],
@@ -245,8 +271,6 @@ class GraphQLController(
245
271
  try:
246
272
  return await self.run(
247
273
  request,
248
- # TODO: check the dependency, above, can we make it so that
249
- # we don't need to type ignore here?
250
274
  context=context,
251
275
  root_value=root_value,
252
276
  )
@@ -328,14 +352,29 @@ class GraphQLController(
328
352
  root_value=root_value,
329
353
  )
330
354
 
355
+ @websocket()
356
+ async def websocket_endpoint(
357
+ self,
358
+ socket: WebSocket,
359
+ context_ws: Any,
360
+ root_value: Any,
361
+ ) -> None:
362
+ await self.run(
363
+ request=socket,
364
+ context=context_ws,
365
+ root_value=root_value,
366
+ )
367
+
331
368
  async def get_context(
332
- self, request: Request[Any, Any, Any], response: Response
369
+ self,
370
+ request: Union[Request[Any, Any, Any], WebSocket],
371
+ response: Union[Response, WebSocket],
333
372
  ) -> Context: # pragma: no cover
334
373
  msg = "`get_context` is not used by Litestar's controller"
335
374
  raise ValueError(msg)
336
375
 
337
376
  async def get_root_value(
338
- self, request: Request[Any, Any, Any]
377
+ self, request: Union[Request[Any, Any, Any], WebSocket]
339
378
  ) -> RootValue | None: # pragma: no cover
340
379
  msg = "`get_root_value` is not used by Litestar's controller"
341
380
  raise ValueError(msg)
@@ -343,54 +382,6 @@ class GraphQLController(
343
382
  async def get_sub_response(self, request: Request[Any, Any, Any]) -> Response:
344
383
  return self.temporal_response
345
384
 
346
- @websocket()
347
- async def websocket_endpoint(
348
- self,
349
- socket: WebSocket,
350
- context_ws: Any,
351
- root_value: Any,
352
- ) -> None:
353
- async def _get_context() -> Any:
354
- return context_ws
355
-
356
- async def _get_root_value() -> Any:
357
- return root_value
358
-
359
- preferred_protocol = self.pick_preferred_protocol(socket)
360
- if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
361
- await self.graphql_transport_ws_handler_class(
362
- schema=self.schema,
363
- debug=self.debug,
364
- connection_init_wait_timeout=self.connection_init_wait_timeout,
365
- get_context=_get_context,
366
- get_root_value=_get_root_value,
367
- ws=socket,
368
- ).handle()
369
- elif preferred_protocol == GRAPHQL_WS_PROTOCOL:
370
- await self.graphql_ws_handler_class(
371
- schema=self.schema,
372
- debug=self.debug,
373
- keep_alive=self.keep_alive,
374
- keep_alive_interval=self.keep_alive_interval,
375
- get_context=_get_context,
376
- get_root_value=_get_root_value,
377
- ws=socket,
378
- ).handle()
379
- else:
380
- await socket.close(code=WS_4406_PROTOCOL_NOT_ACCEPTABLE)
381
-
382
- def pick_preferred_protocol(self, socket: WebSocket) -> str | None:
383
- protocols: List[str] = socket.scope["subprotocols"]
384
- intersection: Set[str] = set(protocols) & set(self.protocols)
385
- return (
386
- min(
387
- intersection,
388
- key=lambda i: protocols.index(i) if i else "",
389
- default=None,
390
- )
391
- or None
392
- )
393
-
394
385
 
395
386
  def make_graphql_controller(
396
387
  schema: BaseSchema,