strawberry-graphql 0.243.1__py3-none-any.whl → 0.244.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. strawberry/aiohttp/views.py +58 -44
  2. strawberry/asgi/__init__.py +62 -61
  3. strawberry/channels/__init__.py +1 -6
  4. strawberry/channels/handlers/base.py +2 -2
  5. strawberry/channels/handlers/http_handler.py +18 -1
  6. strawberry/channels/handlers/ws_handler.py +113 -58
  7. strawberry/codegen/query_codegen.py +8 -6
  8. strawberry/django/views.py +19 -1
  9. strawberry/fastapi/router.py +27 -45
  10. strawberry/flask/views.py +15 -1
  11. strawberry/http/async_base_view.py +136 -16
  12. strawberry/http/exceptions.py +4 -0
  13. strawberry/http/typevars.py +11 -1
  14. strawberry/litestar/controller.py +77 -86
  15. strawberry/quart/views.py +15 -1
  16. strawberry/sanic/views.py +21 -1
  17. strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +41 -40
  18. strawberry/subscriptions/protocols/graphql_ws/handlers.py +46 -45
  19. {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.1.dist-info}/METADATA +1 -1
  20. {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.1.dist-info}/RECORD +23 -37
  21. strawberry/aiohttp/handlers/__init__.py +0 -6
  22. strawberry/aiohttp/handlers/graphql_transport_ws_handler.py +0 -62
  23. strawberry/aiohttp/handlers/graphql_ws_handler.py +0 -69
  24. strawberry/asgi/handlers/__init__.py +0 -6
  25. strawberry/asgi/handlers/graphql_transport_ws_handler.py +0 -66
  26. strawberry/asgi/handlers/graphql_ws_handler.py +0 -71
  27. strawberry/channels/handlers/graphql_transport_ws_handler.py +0 -62
  28. strawberry/channels/handlers/graphql_ws_handler.py +0 -72
  29. strawberry/fastapi/handlers/__init__.py +0 -6
  30. strawberry/fastapi/handlers/graphql_transport_ws_handler.py +0 -20
  31. strawberry/fastapi/handlers/graphql_ws_handler.py +0 -18
  32. strawberry/litestar/handlers/__init__.py +0 -0
  33. strawberry/litestar/handlers/graphql_transport_ws_handler.py +0 -60
  34. strawberry/litestar/handlers/graphql_ws_handler.py +0 -66
  35. {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.1.dist-info}/LICENSE +0 -0
  36. {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.1.dist-info}/WHEEL +0 -0
  37. {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.1.dist-info}/entry_points.txt +0 -0
strawberry/django/views.py CHANGED
@@ -13,6 +13,7 @@ from typing import (
13
13
  Union,
14
14
  cast,
15
15
  )
16
+ from typing_extensions import TypeGuard
16
17
 
17
18
  from asgiref.sync import markcoroutinefunction
18
19
  from django.core.serializers.json import DjangoJSONEncoder
@@ -258,7 +259,13 @@ class GraphQLView(
258
259
  class AsyncGraphQLView(
259
260
  BaseView,
260
261
  AsyncBaseHTTPView[
261
- HttpRequest, HttpResponseBase, TemporalHttpResponse, Context, RootValue
262
+ HttpRequest,
263
+ HttpResponseBase,
264
+ TemporalHttpResponse,
265
+ HttpRequest,
266
+ TemporalHttpResponse,
267
+ Context,
268
+ RootValue,
262
269
  ],
263
270
  View,
264
271
  ):
@@ -312,5 +319,16 @@ class AsyncGraphQLView(
312
319
 
313
320
  return response
314
321
 
322
+ def is_websocket_request(self, request: HttpRequest) -> TypeGuard[HttpRequest]:
323
+ return False
324
+
325
+ async def pick_websocket_subprotocol(self, request: HttpRequest) -> Optional[str]:
326
+ raise NotImplementedError
327
+
328
+ async def create_websocket_response(
329
+ self, request: HttpRequest, subprotocol: Optional[str]
330
+ ) -> TemporalHttpResponse:
331
+ raise NotImplementedError
332
+
315
333
 
316
334
  __all__ = ["GraphQLView", "AsyncGraphQLView"]
strawberry/fastapi/router.py CHANGED
@@ -17,6 +17,7 @@ from typing import (
17
17
  Union,
18
18
  cast,
19
19
  )
20
+ from typing_extensions import TypeGuard
20
21
 
21
22
  from starlette import status
22
23
  from starlette.background import BackgroundTasks # noqa: TCH002
@@ -34,10 +35,9 @@ from fastapi import APIRouter, Depends, params
34
35
  from fastapi.datastructures import Default
35
36
  from fastapi.routing import APIRoute
36
37
  from fastapi.utils import generate_unique_id
37
- from strawberry.asgi import ASGIRequestAdapter
38
+ from strawberry.asgi import ASGIRequestAdapter, ASGIWebSocketAdapter
38
39
  from strawberry.exceptions import InvalidCustomContext
39
40
  from strawberry.fastapi.context import BaseContext, CustomContext
40
- from strawberry.fastapi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler
41
41
  from strawberry.http import process_result
42
42
  from strawberry.http.async_base_view import AsyncBaseHTTPView
43
43
  from strawberry.http.exceptions import HTTPException
@@ -58,12 +58,14 @@ if TYPE_CHECKING:
58
58
 
59
59
 
60
60
  class GraphQLRouter(
61
- AsyncBaseHTTPView[Request, Response, Response, Context, RootValue], APIRouter
61
+ AsyncBaseHTTPView[
62
+ Request, Response, Response, WebSocket, WebSocket, Context, RootValue
63
+ ],
64
+ APIRouter,
62
65
  ):
63
- graphql_ws_handler_class = GraphQLWSHandler
64
- graphql_transport_ws_handler_class = GraphQLTransportWSHandler
65
66
  allow_queries_via_get = True
66
67
  request_adapter_class = ASGIRequestAdapter
68
+ websocket_adapter_class = ASGIWebSocketAdapter
67
69
 
68
70
  @staticmethod
69
71
  async def __get_root_value() -> None:
@@ -261,44 +263,7 @@ class GraphQLRouter(
261
263
  context: Context = Depends(self.context_getter),
262
264
  root_value: RootValue = Depends(self.root_value_getter),
263
265
  ) -> None:
264
- async def _get_context() -> Context:
265
- return context
266
-
267
- async def _get_root_value() -> RootValue:
268
- return root_value
269
-
270
- preferred_protocol = self.pick_preferred_protocol(websocket)
271
- if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
272
- await self.graphql_transport_ws_handler_class(
273
- schema=self.schema,
274
- debug=self.debug,
275
- connection_init_wait_timeout=self.connection_init_wait_timeout,
276
- get_context=_get_context,
277
- get_root_value=_get_root_value,
278
- ws=websocket,
279
- ).handle()
280
- elif preferred_protocol == GRAPHQL_WS_PROTOCOL:
281
- await self.graphql_ws_handler_class(
282
- schema=self.schema,
283
- debug=self.debug,
284
- keep_alive=self.keep_alive,
285
- keep_alive_interval=self.keep_alive_interval,
286
- get_context=_get_context,
287
- get_root_value=_get_root_value,
288
- ws=websocket,
289
- ).handle()
290
- else:
291
- # Code 4406 is "Subprotocol not acceptable"
292
- await websocket.close(code=4406)
293
-
294
- def pick_preferred_protocol(self, ws: WebSocket) -> Optional[str]:
295
- protocols = ws["subprotocols"]
296
- intersection = set(protocols) & set(self.protocols)
297
- return min(
298
- intersection,
299
- key=lambda i: protocols.index(i),
300
- default=None,
301
- )
266
+ await self.run(request=websocket, context=context, root_value=root_value)
302
267
 
303
268
  async def render_graphql_ide(self, request: Request) -> HTMLResponse:
304
269
  return HTMLResponse(self.graphql_ide_html)
@@ -309,12 +274,12 @@ class GraphQLRouter(
309
274
  return process_result(result)
310
275
 
311
276
  async def get_context(
312
- self, request: Request, response: Response
277
+ self, request: Union[Request, WebSocket], response: Union[Response, WebSocket]
313
278
  ) -> Context: # pragma: no cover
314
279
  raise ValueError("`get_context` is not used by FastAPI GraphQL Router")
315
280
 
316
281
  async def get_root_value(
317
- self, request: Request
282
+ self, request: Union[Request, WebSocket]
318
283
  ) -> Optional[RootValue]: # pragma: no cover
319
284
  raise ValueError("`get_root_value` is not used by FastAPI GraphQL Router")
320
285
 
@@ -350,5 +315,22 @@ class GraphQLRouter(
350
315
  },
351
316
  )
352
317
 
318
+ def is_websocket_request(
319
+ self, request: Union[Request, WebSocket]
320
+ ) -> TypeGuard[WebSocket]:
321
+ return request.scope["type"] == "websocket"
322
+
323
+ async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]:
324
+ protocols = request["subprotocols"]
325
+ intersection = set(protocols) & set(self.protocols)
326
+ sorted_intersection = sorted(intersection, key=protocols.index)
327
+ return next(iter(sorted_intersection), None)
328
+
329
+ async def create_websocket_response(
330
+ self, request: WebSocket, subprotocol: Optional[str]
331
+ ) -> WebSocket:
332
+ await request.accept(subprotocol=subprotocol)
333
+ return request
334
+
353
335
 
354
336
  __all__ = ["GraphQLRouter"]
strawberry/flask/views.py CHANGED
@@ -9,6 +9,7 @@ from typing import (
9
9
  Union,
10
10
  cast,
11
11
  )
12
+ from typing_extensions import TypeGuard
12
13
 
13
14
  from flask import Request, Response, render_template_string, request
14
15
  from flask.views import View
@@ -159,7 +160,9 @@ class AsyncFlaskHTTPRequestAdapter(AsyncHTTPRequestAdapter):
159
160
 
160
161
  class AsyncGraphQLView(
161
162
  BaseGraphQLView,
162
- AsyncBaseHTTPView[Request, Response, Response, Context, RootValue],
163
+ AsyncBaseHTTPView[
164
+ Request, Response, Response, Request, Response, Context, RootValue
165
+ ],
163
166
  View,
164
167
  ):
165
168
  methods = ["GET", "POST"]
@@ -187,6 +190,17 @@ class AsyncGraphQLView(
187
190
  async def render_graphql_ide(self, request: Request) -> Response:
188
191
  return render_template_string(self.graphql_ide_html) # type: ignore
189
192
 
193
+ def is_websocket_request(self, request: Request) -> TypeGuard[Request]:
194
+ return False
195
+
196
+ async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]:
197
+ raise NotImplementedError
198
+
199
+ async def create_websocket_response(
200
+ self, request: Request, subprotocol: Optional[str]
201
+ ) -> Response:
202
+ raise NotImplementedError
203
+
190
204
 
191
205
  __all__ = [
192
206
  "GraphQLView",
strawberry/http/async_base_view.py CHANGED
@@ -2,6 +2,7 @@ import abc
2
2
  import asyncio
3
3
  import contextlib
4
4
  import json
5
+ from datetime import timedelta
5
6
  from typing import (
6
7
  Any,
7
8
  AsyncGenerator,
@@ -13,8 +14,10 @@ from typing import (
13
14
  Optional,
14
15
  Tuple,
15
16
  Union,
17
+ cast,
18
+ overload,
16
19
  )
17
- from typing_extensions import Literal
20
+ from typing_extensions import Literal, TypeGuard
18
21
 
19
22
  from graphql import GraphQLError
20
23
 
@@ -29,6 +32,11 @@ from strawberry.http import (
29
32
  from strawberry.http.ides import GraphQL_IDE
30
33
  from strawberry.schema.base import BaseSchema
31
34
  from strawberry.schema.exceptions import InvalidOperationTypeError
35
+ from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
36
+ from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import (
37
+ BaseGraphQLTransportWSHandler,
38
+ )
39
+ from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler
32
40
  from strawberry.types import ExecutionResult, SubscriptionExecutionResult
33
41
  from strawberry.types.graphql import OperationType
34
42
 
@@ -36,7 +44,15 @@ from .base import BaseView
36
44
  from .exceptions import HTTPException
37
45
  from .parse_content_type import parse_content_type
38
46
  from .types import FormData, HTTPMethod, QueryParams
39
- from .typevars import Context, Request, Response, RootValue, SubResponse
47
+ from .typevars import (
48
+ Context,
49
+ Request,
50
+ Response,
51
+ RootValue,
52
+ SubResponse,
53
+ WebSocketRequest,
54
+ WebSocketResponse,
55
+ )
40
56
 
41
57
 
42
58
  class AsyncHTTPRequestAdapter(abc.ABC):
@@ -63,14 +79,42 @@ class AsyncHTTPRequestAdapter(abc.ABC):
63
79
  async def get_form_data(self) -> FormData: ...
64
80
 
65
81
 
82
+ class AsyncWebSocketAdapter(abc.ABC):
83
+ @abc.abstractmethod
84
+ def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: ...
85
+
86
+ @abc.abstractmethod
87
+ async def send_json(self, message: Mapping[str, object]) -> None: ...
88
+
89
+ @abc.abstractmethod
90
+ async def close(self, code: int, reason: str) -> None: ...
91
+
92
+
66
93
  class AsyncBaseHTTPView(
67
94
  abc.ABC,
68
95
  BaseView[Request],
69
- Generic[Request, Response, SubResponse, Context, RootValue],
96
+ Generic[
97
+ Request,
98
+ Response,
99
+ SubResponse,
100
+ WebSocketRequest,
101
+ WebSocketResponse,
102
+ Context,
103
+ RootValue,
104
+ ],
70
105
  ):
71
106
  schema: BaseSchema
72
107
  graphql_ide: Optional[GraphQL_IDE]
108
+ debug: bool
109
+ keep_alive = False
110
+ keep_alive_interval: Optional[float] = None
111
+ connection_init_wait_timeout: timedelta = timedelta(minutes=1)
73
112
  request_adapter_class: Callable[[Request], AsyncHTTPRequestAdapter]
113
+ websocket_adapter_class: Callable[
114
+ [WebSocketRequest, WebSocketResponse], AsyncWebSocketAdapter
115
+ ]
116
+ graphql_transport_ws_handler_class = BaseGraphQLTransportWSHandler
117
+ graphql_ws_handler_class = BaseGraphQLWSHandler
74
118
 
75
119
  @property
76
120
  @abc.abstractmethod
@@ -80,10 +124,16 @@ class AsyncBaseHTTPView(
80
124
  async def get_sub_response(self, request: Request) -> SubResponse: ...
81
125
 
82
126
  @abc.abstractmethod
83
- async def get_context(self, request: Request, response: SubResponse) -> Context: ...
127
+ async def get_context(
128
+ self,
129
+ request: Union[Request, WebSocketRequest],
130
+ response: Union[SubResponse, WebSocketResponse],
131
+ ) -> Context: ...
84
132
 
85
133
  @abc.abstractmethod
86
- async def get_root_value(self, request: Request) -> Optional[RootValue]: ...
134
+ async def get_root_value(
135
+ self, request: Union[Request, WebSocketRequest]
136
+ ) -> Optional[RootValue]: ...
87
137
 
88
138
  @abc.abstractmethod
89
139
  def create_response(
@@ -102,6 +152,21 @@ class AsyncBaseHTTPView(
102
152
  ) -> Response:
103
153
  raise ValueError("Multipart responses are not supported")
104
154
 
155
+ @abc.abstractmethod
156
+ def is_websocket_request(
157
+ self, request: Union[Request, WebSocketRequest]
158
+ ) -> TypeGuard[WebSocketRequest]: ...
159
+
160
+ @abc.abstractmethod
161
+ async def pick_websocket_subprotocol(
162
+ self, request: WebSocketRequest
163
+ ) -> Optional[str]: ...
164
+
165
+ @abc.abstractmethod
166
+ async def create_websocket_response(
167
+ self, request: WebSocketRequest, subprotocol: Optional[str]
168
+ ) -> WebSocketResponse: ...
169
+
105
170
  async def execute_operation(
106
171
  self, request: Request, context: Context, root_value: Optional[RootValue]
107
172
  ) -> Union[ExecutionResult, SubscriptionExecutionResult]:
@@ -167,35 +232,90 @@ class AsyncBaseHTTPView(
167
232
  ) -> None:
168
233
  """Hook to allow custom handling of errors, used by the Sentry Integration."""
169
234
 
235
+ @overload
170
236
  async def run(
171
237
  self,
172
238
  request: Request,
173
239
  context: Optional[Context] = UNSET,
174
240
  root_value: Optional[RootValue] = UNSET,
175
- ) -> Response:
176
- request_adapter = self.request_adapter_class(request)
241
+ ) -> Response: ...
177
242
 
178
- if not self.is_request_allowed(request_adapter):
179
- raise HTTPException(405, "GraphQL only supports GET and POST requests.")
243
+ @overload
244
+ async def run(
245
+ self,
246
+ request: WebSocketRequest,
247
+ context: Optional[Context] = UNSET,
248
+ root_value: Optional[RootValue] = UNSET,
249
+ ) -> WebSocketResponse: ...
180
250
 
181
- if self.should_render_graphql_ide(request_adapter):
182
- if self.graphql_ide:
183
- return await self.render_graphql_ide(request)
251
+ async def run(
252
+ self,
253
+ request: Union[Request, WebSocketRequest],
254
+ context: Optional[Context] = UNSET,
255
+ root_value: Optional[RootValue] = UNSET,
256
+ ) -> Union[Response, WebSocketResponse]:
257
+ root_value = (
258
+ await self.get_root_value(request) if root_value is UNSET else root_value
259
+ )
260
+
261
+ if self.is_websocket_request(request):
262
+ websocket_subprotocol = await self.pick_websocket_subprotocol(request)
263
+ websocket_response = await self.create_websocket_response(
264
+ request, websocket_subprotocol
265
+ )
266
+ websocket = self.websocket_adapter_class(request, websocket_response)
267
+
268
+ context = (
269
+ await self.get_context(request, response=websocket_response)
270
+ if context is UNSET
271
+ else context
272
+ )
273
+
274
+ if websocket_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
275
+ await self.graphql_transport_ws_handler_class(
276
+ websocket=websocket,
277
+ context=context,
278
+ root_value=root_value,
279
+ schema=self.schema,
280
+ debug=self.debug,
281
+ connection_init_wait_timeout=self.connection_init_wait_timeout,
282
+ ).handle()
283
+ elif websocket_subprotocol == GRAPHQL_WS_PROTOCOL:
284
+ await self.graphql_ws_handler_class(
285
+ websocket=websocket,
286
+ context=context,
287
+ root_value=root_value,
288
+ schema=self.schema,
289
+ debug=self.debug,
290
+ keep_alive=self.keep_alive,
291
+ keep_alive_interval=self.keep_alive_interval,
292
+ ).handle()
184
293
  else:
185
- raise HTTPException(404, "Not Found")
294
+ await websocket.close(4406, "Subprotocol not acceptable")
295
+
296
+ return websocket_response
297
+ else:
298
+ request = cast(Request, request)
186
299
 
300
+ request_adapter = self.request_adapter_class(request)
187
301
  sub_response = await self.get_sub_response(request)
188
302
  context = (
189
303
  await self.get_context(request, response=sub_response)
190
304
  if context is UNSET
191
305
  else context
192
306
  )
193
- root_value = (
194
- await self.get_root_value(request) if root_value is UNSET else root_value
195
- )
196
307
 
197
308
  assert context
198
309
 
310
+ if not self.is_request_allowed(request_adapter):
311
+ raise HTTPException(405, "GraphQL only supports GET and POST requests.")
312
+
313
+ if self.should_render_graphql_ide(request_adapter):
314
+ if self.graphql_ide:
315
+ return await self.render_graphql_ide(request)
316
+ else:
317
+ raise HTTPException(404, "Not Found")
318
+
199
319
  try:
200
320
  result = await self.execute_operation(
201
321
  request=request, context=context, root_value=root_value
strawberry/http/exceptions.py CHANGED
@@ -4,4 +4,8 @@ class HTTPException(Exception):
4
4
  self.reason = reason
5
5
 
6
6
 
7
+ class NonJsonMessageReceived(Exception):
8
+ pass
9
+
10
+
7
11
  __all__ = ["HTTPException"]
strawberry/http/typevars.py CHANGED
@@ -3,8 +3,18 @@ from typing import TypeVar
3
3
  Request = TypeVar("Request", contravariant=True)
4
4
  Response = TypeVar("Response")
5
5
  SubResponse = TypeVar("SubResponse")
6
+ WebSocketRequest = TypeVar("WebSocketRequest")
7
+ WebSocketResponse = TypeVar("WebSocketResponse")
6
8
  Context = TypeVar("Context")
7
9
  RootValue = TypeVar("RootValue")
8
10
 
9
11
 
10
- __all__ = ["Request", "Response", "SubResponse", "Context", "RootValue"]
12
+ __all__ = [
13
+ "Request",
14
+ "Response",
15
+ "SubResponse",
16
+ "WebSocketRequest",
17
+ "WebSocketResponse",
18
+ "Context",
19
+ "RootValue",
20
+ ]