strawberry-graphql 0.243.1__py3-none-any.whl → 0.244.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- strawberry/aiohttp/views.py +58 -44
- strawberry/asgi/__init__.py +62 -61
- strawberry/channels/__init__.py +1 -6
- strawberry/channels/handlers/base.py +2 -2
- strawberry/channels/handlers/http_handler.py +18 -1
- strawberry/channels/handlers/ws_handler.py +113 -58
- strawberry/django/views.py +19 -1
- strawberry/fastapi/router.py +27 -45
- strawberry/flask/views.py +15 -1
- strawberry/http/async_base_view.py +136 -16
- strawberry/http/exceptions.py +4 -0
- strawberry/http/typevars.py +11 -1
- strawberry/litestar/controller.py +77 -86
- strawberry/quart/views.py +15 -1
- strawberry/sanic/views.py +21 -1
- strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +41 -40
- strawberry/subscriptions/protocols/graphql_ws/handlers.py +46 -45
- {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.0.dist-info}/METADATA +1 -1
- {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.0.dist-info}/RECORD +22 -36
- strawberry/aiohttp/handlers/__init__.py +0 -6
- strawberry/aiohttp/handlers/graphql_transport_ws_handler.py +0 -62
- strawberry/aiohttp/handlers/graphql_ws_handler.py +0 -69
- strawberry/asgi/handlers/__init__.py +0 -6
- strawberry/asgi/handlers/graphql_transport_ws_handler.py +0 -66
- strawberry/asgi/handlers/graphql_ws_handler.py +0 -71
- strawberry/channels/handlers/graphql_transport_ws_handler.py +0 -62
- strawberry/channels/handlers/graphql_ws_handler.py +0 -72
- strawberry/fastapi/handlers/__init__.py +0 -6
- strawberry/fastapi/handlers/graphql_transport_ws_handler.py +0 -20
- strawberry/fastapi/handlers/graphql_ws_handler.py +0 -18
- strawberry/litestar/handlers/__init__.py +0 -0
- strawberry/litestar/handlers/graphql_transport_ws_handler.py +0 -60
- strawberry/litestar/handlers/graphql_ws_handler.py +0 -66
- {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.0.dist-info}/LICENSE +0 -0
- {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.0.dist-info}/WHEEL +0 -0
- {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.0.dist-info}/entry_points.txt +0 -0
strawberry/fastapi/router.py
CHANGED
@@ -17,6 +17,7 @@ from typing import (
|
|
17
17
|
Union,
|
18
18
|
cast,
|
19
19
|
)
|
20
|
+
from typing_extensions import TypeGuard
|
20
21
|
|
21
22
|
from starlette import status
|
22
23
|
from starlette.background import BackgroundTasks # noqa: TCH002
|
@@ -34,10 +35,9 @@ from fastapi import APIRouter, Depends, params
|
|
34
35
|
from fastapi.datastructures import Default
|
35
36
|
from fastapi.routing import APIRoute
|
36
37
|
from fastapi.utils import generate_unique_id
|
37
|
-
from strawberry.asgi import ASGIRequestAdapter
|
38
|
+
from strawberry.asgi import ASGIRequestAdapter, ASGIWebSocketAdapter
|
38
39
|
from strawberry.exceptions import InvalidCustomContext
|
39
40
|
from strawberry.fastapi.context import BaseContext, CustomContext
|
40
|
-
from strawberry.fastapi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler
|
41
41
|
from strawberry.http import process_result
|
42
42
|
from strawberry.http.async_base_view import AsyncBaseHTTPView
|
43
43
|
from strawberry.http.exceptions import HTTPException
|
@@ -58,12 +58,14 @@ if TYPE_CHECKING:
|
|
58
58
|
|
59
59
|
|
60
60
|
class GraphQLRouter(
|
61
|
-
AsyncBaseHTTPView[
|
61
|
+
AsyncBaseHTTPView[
|
62
|
+
Request, Response, Response, WebSocket, WebSocket, Context, RootValue
|
63
|
+
],
|
64
|
+
APIRouter,
|
62
65
|
):
|
63
|
-
graphql_ws_handler_class = GraphQLWSHandler
|
64
|
-
graphql_transport_ws_handler_class = GraphQLTransportWSHandler
|
65
66
|
allow_queries_via_get = True
|
66
67
|
request_adapter_class = ASGIRequestAdapter
|
68
|
+
websocket_adapter_class = ASGIWebSocketAdapter
|
67
69
|
|
68
70
|
@staticmethod
|
69
71
|
async def __get_root_value() -> None:
|
@@ -261,44 +263,7 @@ class GraphQLRouter(
|
|
261
263
|
context: Context = Depends(self.context_getter),
|
262
264
|
root_value: RootValue = Depends(self.root_value_getter),
|
263
265
|
) -> None:
|
264
|
-
|
265
|
-
return context
|
266
|
-
|
267
|
-
async def _get_root_value() -> RootValue:
|
268
|
-
return root_value
|
269
|
-
|
270
|
-
preferred_protocol = self.pick_preferred_protocol(websocket)
|
271
|
-
if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
|
272
|
-
await self.graphql_transport_ws_handler_class(
|
273
|
-
schema=self.schema,
|
274
|
-
debug=self.debug,
|
275
|
-
connection_init_wait_timeout=self.connection_init_wait_timeout,
|
276
|
-
get_context=_get_context,
|
277
|
-
get_root_value=_get_root_value,
|
278
|
-
ws=websocket,
|
279
|
-
).handle()
|
280
|
-
elif preferred_protocol == GRAPHQL_WS_PROTOCOL:
|
281
|
-
await self.graphql_ws_handler_class(
|
282
|
-
schema=self.schema,
|
283
|
-
debug=self.debug,
|
284
|
-
keep_alive=self.keep_alive,
|
285
|
-
keep_alive_interval=self.keep_alive_interval,
|
286
|
-
get_context=_get_context,
|
287
|
-
get_root_value=_get_root_value,
|
288
|
-
ws=websocket,
|
289
|
-
).handle()
|
290
|
-
else:
|
291
|
-
# Code 4406 is "Subprotocol not acceptable"
|
292
|
-
await websocket.close(code=4406)
|
293
|
-
|
294
|
-
def pick_preferred_protocol(self, ws: WebSocket) -> Optional[str]:
|
295
|
-
protocols = ws["subprotocols"]
|
296
|
-
intersection = set(protocols) & set(self.protocols)
|
297
|
-
return min(
|
298
|
-
intersection,
|
299
|
-
key=lambda i: protocols.index(i),
|
300
|
-
default=None,
|
301
|
-
)
|
266
|
+
await self.run(request=websocket, context=context, root_value=root_value)
|
302
267
|
|
303
268
|
async def render_graphql_ide(self, request: Request) -> HTMLResponse:
|
304
269
|
return HTMLResponse(self.graphql_ide_html)
|
@@ -309,12 +274,12 @@ class GraphQLRouter(
|
|
309
274
|
return process_result(result)
|
310
275
|
|
311
276
|
async def get_context(
|
312
|
-
self, request: Request, response: Response
|
277
|
+
self, request: Union[Request, WebSocket], response: Union[Response, WebSocket]
|
313
278
|
) -> Context: # pragma: no cover
|
314
279
|
raise ValueError("`get_context` is not used by FastAPI GraphQL Router")
|
315
280
|
|
316
281
|
async def get_root_value(
|
317
|
-
self, request: Request
|
282
|
+
self, request: Union[Request, WebSocket]
|
318
283
|
) -> Optional[RootValue]: # pragma: no cover
|
319
284
|
raise ValueError("`get_root_value` is not used by FastAPI GraphQL Router")
|
320
285
|
|
@@ -350,5 +315,22 @@ class GraphQLRouter(
|
|
350
315
|
},
|
351
316
|
)
|
352
317
|
|
318
|
+
def is_websocket_request(
|
319
|
+
self, request: Union[Request, WebSocket]
|
320
|
+
) -> TypeGuard[WebSocket]:
|
321
|
+
return request.scope["type"] == "websocket"
|
322
|
+
|
323
|
+
async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]:
|
324
|
+
protocols = request["subprotocols"]
|
325
|
+
intersection = set(protocols) & set(self.protocols)
|
326
|
+
sorted_intersection = sorted(intersection, key=protocols.index)
|
327
|
+
return next(iter(sorted_intersection), None)
|
328
|
+
|
329
|
+
async def create_websocket_response(
|
330
|
+
self, request: WebSocket, subprotocol: Optional[str]
|
331
|
+
) -> WebSocket:
|
332
|
+
await request.accept(subprotocol=subprotocol)
|
333
|
+
return request
|
334
|
+
|
353
335
|
|
354
336
|
__all__ = ["GraphQLRouter"]
|
strawberry/flask/views.py
CHANGED
@@ -9,6 +9,7 @@ from typing import (
|
|
9
9
|
Union,
|
10
10
|
cast,
|
11
11
|
)
|
12
|
+
from typing_extensions import TypeGuard
|
12
13
|
|
13
14
|
from flask import Request, Response, render_template_string, request
|
14
15
|
from flask.views import View
|
@@ -159,7 +160,9 @@ class AsyncFlaskHTTPRequestAdapter(AsyncHTTPRequestAdapter):
|
|
159
160
|
|
160
161
|
class AsyncGraphQLView(
|
161
162
|
BaseGraphQLView,
|
162
|
-
AsyncBaseHTTPView[
|
163
|
+
AsyncBaseHTTPView[
|
164
|
+
Request, Response, Response, Request, Response, Context, RootValue
|
165
|
+
],
|
163
166
|
View,
|
164
167
|
):
|
165
168
|
methods = ["GET", "POST"]
|
@@ -187,6 +190,17 @@ class AsyncGraphQLView(
|
|
187
190
|
async def render_graphql_ide(self, request: Request) -> Response:
|
188
191
|
return render_template_string(self.graphql_ide_html) # type: ignore
|
189
192
|
|
193
|
+
def is_websocket_request(self, request: Request) -> TypeGuard[Request]:
|
194
|
+
return False
|
195
|
+
|
196
|
+
async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]:
|
197
|
+
raise NotImplementedError
|
198
|
+
|
199
|
+
async def create_websocket_response(
|
200
|
+
self, request: Request, subprotocol: Optional[str]
|
201
|
+
) -> Response:
|
202
|
+
raise NotImplementedError
|
203
|
+
|
190
204
|
|
191
205
|
__all__ = [
|
192
206
|
"GraphQLView",
|
@@ -2,6 +2,7 @@ import abc
|
|
2
2
|
import asyncio
|
3
3
|
import contextlib
|
4
4
|
import json
|
5
|
+
from datetime import timedelta
|
5
6
|
from typing import (
|
6
7
|
Any,
|
7
8
|
AsyncGenerator,
|
@@ -13,8 +14,10 @@ from typing import (
|
|
13
14
|
Optional,
|
14
15
|
Tuple,
|
15
16
|
Union,
|
17
|
+
cast,
|
18
|
+
overload,
|
16
19
|
)
|
17
|
-
from typing_extensions import Literal
|
20
|
+
from typing_extensions import Literal, TypeGuard
|
18
21
|
|
19
22
|
from graphql import GraphQLError
|
20
23
|
|
@@ -29,6 +32,11 @@ from strawberry.http import (
|
|
29
32
|
from strawberry.http.ides import GraphQL_IDE
|
30
33
|
from strawberry.schema.base import BaseSchema
|
31
34
|
from strawberry.schema.exceptions import InvalidOperationTypeError
|
35
|
+
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
|
36
|
+
from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import (
|
37
|
+
BaseGraphQLTransportWSHandler,
|
38
|
+
)
|
39
|
+
from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler
|
32
40
|
from strawberry.types import ExecutionResult, SubscriptionExecutionResult
|
33
41
|
from strawberry.types.graphql import OperationType
|
34
42
|
|
@@ -36,7 +44,15 @@ from .base import BaseView
|
|
36
44
|
from .exceptions import HTTPException
|
37
45
|
from .parse_content_type import parse_content_type
|
38
46
|
from .types import FormData, HTTPMethod, QueryParams
|
39
|
-
from .typevars import
|
47
|
+
from .typevars import (
|
48
|
+
Context,
|
49
|
+
Request,
|
50
|
+
Response,
|
51
|
+
RootValue,
|
52
|
+
SubResponse,
|
53
|
+
WebSocketRequest,
|
54
|
+
WebSocketResponse,
|
55
|
+
)
|
40
56
|
|
41
57
|
|
42
58
|
class AsyncHTTPRequestAdapter(abc.ABC):
|
@@ -63,14 +79,42 @@ class AsyncHTTPRequestAdapter(abc.ABC):
|
|
63
79
|
async def get_form_data(self) -> FormData: ...
|
64
80
|
|
65
81
|
|
82
|
+
class AsyncWebSocketAdapter(abc.ABC):
|
83
|
+
@abc.abstractmethod
|
84
|
+
def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: ...
|
85
|
+
|
86
|
+
@abc.abstractmethod
|
87
|
+
async def send_json(self, message: Mapping[str, object]) -> None: ...
|
88
|
+
|
89
|
+
@abc.abstractmethod
|
90
|
+
async def close(self, code: int, reason: str) -> None: ...
|
91
|
+
|
92
|
+
|
66
93
|
class AsyncBaseHTTPView(
|
67
94
|
abc.ABC,
|
68
95
|
BaseView[Request],
|
69
|
-
Generic[
|
96
|
+
Generic[
|
97
|
+
Request,
|
98
|
+
Response,
|
99
|
+
SubResponse,
|
100
|
+
WebSocketRequest,
|
101
|
+
WebSocketResponse,
|
102
|
+
Context,
|
103
|
+
RootValue,
|
104
|
+
],
|
70
105
|
):
|
71
106
|
schema: BaseSchema
|
72
107
|
graphql_ide: Optional[GraphQL_IDE]
|
108
|
+
debug: bool
|
109
|
+
keep_alive = False
|
110
|
+
keep_alive_interval: Optional[float] = None
|
111
|
+
connection_init_wait_timeout: timedelta = timedelta(minutes=1)
|
73
112
|
request_adapter_class: Callable[[Request], AsyncHTTPRequestAdapter]
|
113
|
+
websocket_adapter_class: Callable[
|
114
|
+
[WebSocketRequest, WebSocketResponse], AsyncWebSocketAdapter
|
115
|
+
]
|
116
|
+
graphql_transport_ws_handler_class = BaseGraphQLTransportWSHandler
|
117
|
+
graphql_ws_handler_class = BaseGraphQLWSHandler
|
74
118
|
|
75
119
|
@property
|
76
120
|
@abc.abstractmethod
|
@@ -80,10 +124,16 @@ class AsyncBaseHTTPView(
|
|
80
124
|
async def get_sub_response(self, request: Request) -> SubResponse: ...
|
81
125
|
|
82
126
|
@abc.abstractmethod
|
83
|
-
async def get_context(
|
127
|
+
async def get_context(
|
128
|
+
self,
|
129
|
+
request: Union[Request, WebSocketRequest],
|
130
|
+
response: Union[SubResponse, WebSocketResponse],
|
131
|
+
) -> Context: ...
|
84
132
|
|
85
133
|
@abc.abstractmethod
|
86
|
-
async def get_root_value(
|
134
|
+
async def get_root_value(
|
135
|
+
self, request: Union[Request, WebSocketRequest]
|
136
|
+
) -> Optional[RootValue]: ...
|
87
137
|
|
88
138
|
@abc.abstractmethod
|
89
139
|
def create_response(
|
@@ -102,6 +152,21 @@ class AsyncBaseHTTPView(
|
|
102
152
|
) -> Response:
|
103
153
|
raise ValueError("Multipart responses are not supported")
|
104
154
|
|
155
|
+
@abc.abstractmethod
|
156
|
+
def is_websocket_request(
|
157
|
+
self, request: Union[Request, WebSocketRequest]
|
158
|
+
) -> TypeGuard[WebSocketRequest]: ...
|
159
|
+
|
160
|
+
@abc.abstractmethod
|
161
|
+
async def pick_websocket_subprotocol(
|
162
|
+
self, request: WebSocketRequest
|
163
|
+
) -> Optional[str]: ...
|
164
|
+
|
165
|
+
@abc.abstractmethod
|
166
|
+
async def create_websocket_response(
|
167
|
+
self, request: WebSocketRequest, subprotocol: Optional[str]
|
168
|
+
) -> WebSocketResponse: ...
|
169
|
+
|
105
170
|
async def execute_operation(
|
106
171
|
self, request: Request, context: Context, root_value: Optional[RootValue]
|
107
172
|
) -> Union[ExecutionResult, SubscriptionExecutionResult]:
|
@@ -167,35 +232,90 @@ class AsyncBaseHTTPView(
|
|
167
232
|
) -> None:
|
168
233
|
"""Hook to allow custom handling of errors, used by the Sentry Integration."""
|
169
234
|
|
235
|
+
@overload
|
170
236
|
async def run(
|
171
237
|
self,
|
172
238
|
request: Request,
|
173
239
|
context: Optional[Context] = UNSET,
|
174
240
|
root_value: Optional[RootValue] = UNSET,
|
175
|
-
) -> Response:
|
176
|
-
request_adapter = self.request_adapter_class(request)
|
241
|
+
) -> Response: ...
|
177
242
|
|
178
|
-
|
179
|
-
|
243
|
+
@overload
|
244
|
+
async def run(
|
245
|
+
self,
|
246
|
+
request: WebSocketRequest,
|
247
|
+
context: Optional[Context] = UNSET,
|
248
|
+
root_value: Optional[RootValue] = UNSET,
|
249
|
+
) -> WebSocketResponse: ...
|
180
250
|
|
181
|
-
|
182
|
-
|
183
|
-
|
251
|
+
async def run(
|
252
|
+
self,
|
253
|
+
request: Union[Request, WebSocketRequest],
|
254
|
+
context: Optional[Context] = UNSET,
|
255
|
+
root_value: Optional[RootValue] = UNSET,
|
256
|
+
) -> Union[Response, WebSocketResponse]:
|
257
|
+
root_value = (
|
258
|
+
await self.get_root_value(request) if root_value is UNSET else root_value
|
259
|
+
)
|
260
|
+
|
261
|
+
if self.is_websocket_request(request):
|
262
|
+
websocket_subprotocol = await self.pick_websocket_subprotocol(request)
|
263
|
+
websocket_response = await self.create_websocket_response(
|
264
|
+
request, websocket_subprotocol
|
265
|
+
)
|
266
|
+
websocket = self.websocket_adapter_class(request, websocket_response)
|
267
|
+
|
268
|
+
context = (
|
269
|
+
await self.get_context(request, response=websocket_response)
|
270
|
+
if context is UNSET
|
271
|
+
else context
|
272
|
+
)
|
273
|
+
|
274
|
+
if websocket_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
|
275
|
+
await self.graphql_transport_ws_handler_class(
|
276
|
+
websocket=websocket,
|
277
|
+
context=context,
|
278
|
+
root_value=root_value,
|
279
|
+
schema=self.schema,
|
280
|
+
debug=self.debug,
|
281
|
+
connection_init_wait_timeout=self.connection_init_wait_timeout,
|
282
|
+
).handle()
|
283
|
+
elif websocket_subprotocol == GRAPHQL_WS_PROTOCOL:
|
284
|
+
await self.graphql_ws_handler_class(
|
285
|
+
websocket=websocket,
|
286
|
+
context=context,
|
287
|
+
root_value=root_value,
|
288
|
+
schema=self.schema,
|
289
|
+
debug=self.debug,
|
290
|
+
keep_alive=self.keep_alive,
|
291
|
+
keep_alive_interval=self.keep_alive_interval,
|
292
|
+
).handle()
|
184
293
|
else:
|
185
|
-
|
294
|
+
await websocket.close(4406, "Subprotocol not acceptable")
|
295
|
+
|
296
|
+
return websocket_response
|
297
|
+
else:
|
298
|
+
request = cast(Request, request)
|
186
299
|
|
300
|
+
request_adapter = self.request_adapter_class(request)
|
187
301
|
sub_response = await self.get_sub_response(request)
|
188
302
|
context = (
|
189
303
|
await self.get_context(request, response=sub_response)
|
190
304
|
if context is UNSET
|
191
305
|
else context
|
192
306
|
)
|
193
|
-
root_value = (
|
194
|
-
await self.get_root_value(request) if root_value is UNSET else root_value
|
195
|
-
)
|
196
307
|
|
197
308
|
assert context
|
198
309
|
|
310
|
+
if not self.is_request_allowed(request_adapter):
|
311
|
+
raise HTTPException(405, "GraphQL only supports GET and POST requests.")
|
312
|
+
|
313
|
+
if self.should_render_graphql_ide(request_adapter):
|
314
|
+
if self.graphql_ide:
|
315
|
+
return await self.render_graphql_ide(request)
|
316
|
+
else:
|
317
|
+
raise HTTPException(404, "Not Found")
|
318
|
+
|
199
319
|
try:
|
200
320
|
result = await self.execute_operation(
|
201
321
|
request=request, context=context, root_value=root_value
|
strawberry/http/exceptions.py
CHANGED
strawberry/http/typevars.py
CHANGED
@@ -3,8 +3,18 @@ from typing import TypeVar
|
|
3
3
|
Request = TypeVar("Request", contravariant=True)
|
4
4
|
Response = TypeVar("Response")
|
5
5
|
SubResponse = TypeVar("SubResponse")
|
6
|
+
WebSocketRequest = TypeVar("WebSocketRequest")
|
7
|
+
WebSocketResponse = TypeVar("WebSocketResponse")
|
6
8
|
Context = TypeVar("Context")
|
7
9
|
RootValue = TypeVar("RootValue")
|
8
10
|
|
9
11
|
|
10
|
-
__all__ = [
|
12
|
+
__all__ = [
|
13
|
+
"Request",
|
14
|
+
"Response",
|
15
|
+
"SubResponse",
|
16
|
+
"WebSocketRequest",
|
17
|
+
"WebSocketResponse",
|
18
|
+
"Context",
|
19
|
+
"RootValue",
|
20
|
+
]
|
@@ -7,19 +7,19 @@ from datetime import timedelta
|
|
7
7
|
from typing import (
|
8
8
|
TYPE_CHECKING,
|
9
9
|
Any,
|
10
|
+
AsyncGenerator,
|
10
11
|
AsyncIterator,
|
11
12
|
Callable,
|
12
13
|
Dict,
|
13
14
|
FrozenSet,
|
14
|
-
List,
|
15
15
|
Optional,
|
16
|
-
Set,
|
17
16
|
Tuple,
|
18
17
|
Type,
|
19
18
|
TypedDict,
|
20
19
|
Union,
|
21
20
|
cast,
|
22
21
|
)
|
22
|
+
from typing_extensions import TypeGuard
|
23
23
|
|
24
24
|
from msgspec import Struct
|
25
25
|
|
@@ -35,23 +35,24 @@ from litestar import (
|
|
35
35
|
)
|
36
36
|
from litestar.background_tasks import BackgroundTasks
|
37
37
|
from litestar.di import Provide
|
38
|
-
from litestar.exceptions import
|
38
|
+
from litestar.exceptions import (
|
39
|
+
NotFoundException,
|
40
|
+
SerializationException,
|
41
|
+
ValidationException,
|
42
|
+
WebSocketDisconnect,
|
43
|
+
)
|
39
44
|
from litestar.response.streaming import Stream
|
40
45
|
from litestar.status_codes import HTTP_200_OK
|
41
46
|
from strawberry.exceptions import InvalidCustomContext
|
42
|
-
from strawberry.http.async_base_view import
|
43
|
-
|
47
|
+
from strawberry.http.async_base_view import (
|
48
|
+
AsyncBaseHTTPView,
|
49
|
+
AsyncHTTPRequestAdapter,
|
50
|
+
AsyncWebSocketAdapter,
|
51
|
+
)
|
52
|
+
from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived
|
44
53
|
from strawberry.http.types import FormData, HTTPMethod, QueryParams
|
45
54
|
from strawberry.http.typevars import Context, RootValue
|
46
55
|
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
|
47
|
-
from strawberry.subscriptions.protocols.graphql_transport_ws import (
|
48
|
-
WS_4406_PROTOCOL_NOT_ACCEPTABLE,
|
49
|
-
)
|
50
|
-
|
51
|
-
from .handlers.graphql_transport_ws_handler import (
|
52
|
-
GraphQLTransportWSHandler as BaseGraphQLTransportWSHandler,
|
53
|
-
)
|
54
|
-
from .handlers.graphql_ws_handler import GraphQLWSHandler as BaseGraphQLWSHandler
|
55
56
|
|
56
57
|
if TYPE_CHECKING:
|
57
58
|
from collections.abc import Mapping
|
@@ -152,22 +153,6 @@ class GraphQLResource(Struct):
|
|
152
153
|
extensions: Optional[dict[str, object]]
|
153
154
|
|
154
155
|
|
155
|
-
class GraphQLWSHandler(BaseGraphQLWSHandler):
|
156
|
-
async def get_context(self) -> Any:
|
157
|
-
return await self._get_context()
|
158
|
-
|
159
|
-
async def get_root_value(self) -> Any:
|
160
|
-
return await self._get_root_value()
|
161
|
-
|
162
|
-
|
163
|
-
class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler):
|
164
|
-
async def get_context(self) -> Any:
|
165
|
-
return await self._get_context()
|
166
|
-
|
167
|
-
async def get_root_value(self) -> Any:
|
168
|
-
return await self._get_root_value()
|
169
|
-
|
170
|
-
|
171
156
|
class LitestarRequestAdapter(AsyncHTTPRequestAdapter):
|
172
157
|
def __init__(self, request: Request[Any, Any, Any]) -> None:
|
173
158
|
self.request = request
|
@@ -203,10 +188,37 @@ class LitestarRequestAdapter(AsyncHTTPRequestAdapter):
|
|
203
188
|
return FormData(form=multipart_data, files=multipart_data)
|
204
189
|
|
205
190
|
|
191
|
+
class LitestarWebSocketAdapter(AsyncWebSocketAdapter):
|
192
|
+
def __init__(self, request: WebSocket, response: WebSocket) -> None:
|
193
|
+
self.ws = response
|
194
|
+
|
195
|
+
async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
|
196
|
+
try:
|
197
|
+
try:
|
198
|
+
while self.ws.connection_state != "disconnect":
|
199
|
+
yield await self.ws.receive_json()
|
200
|
+
except (SerializationException, ValueError):
|
201
|
+
raise NonJsonMessageReceived()
|
202
|
+
except WebSocketDisconnect:
|
203
|
+
pass
|
204
|
+
|
205
|
+
async def send_json(self, message: Mapping[str, object]) -> None:
|
206
|
+
await self.ws.send_json(message)
|
207
|
+
|
208
|
+
async def close(self, code: int, reason: str) -> None:
|
209
|
+
await self.ws.close(code=code, reason=reason)
|
210
|
+
|
211
|
+
|
206
212
|
class GraphQLController(
|
207
213
|
Controller,
|
208
214
|
AsyncBaseHTTPView[
|
209
|
-
Request[Any, Any, Any],
|
215
|
+
Request[Any, Any, Any],
|
216
|
+
Response[Any],
|
217
|
+
Response[Any],
|
218
|
+
WebSocket,
|
219
|
+
WebSocket,
|
220
|
+
Context,
|
221
|
+
RootValue,
|
210
222
|
],
|
211
223
|
):
|
212
224
|
path: str = ""
|
@@ -219,10 +231,7 @@ class GraphQLController(
|
|
219
231
|
}
|
220
232
|
|
221
233
|
request_adapter_class = LitestarRequestAdapter
|
222
|
-
|
223
|
-
graphql_transport_ws_handler_class: Type[GraphQLTransportWSHandler] = (
|
224
|
-
GraphQLTransportWSHandler
|
225
|
-
)
|
234
|
+
websocket_adapter_class = LitestarWebSocketAdapter
|
226
235
|
|
227
236
|
allow_queries_via_get: bool = True
|
228
237
|
graphiql_allowed_accept: FrozenSet[str] = frozenset({"text/html", "*/*"})
|
@@ -236,6 +245,23 @@ class GraphQLController(
|
|
236
245
|
keep_alive: bool = False
|
237
246
|
keep_alive_interval: float = 1
|
238
247
|
|
248
|
+
def is_websocket_request(
|
249
|
+
self, request: Union[Request, WebSocket]
|
250
|
+
) -> TypeGuard[WebSocket]:
|
251
|
+
return isinstance(request, WebSocket)
|
252
|
+
|
253
|
+
async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]:
|
254
|
+
subprotocols = request.scope["subprotocols"]
|
255
|
+
intersection = set(subprotocols) & set(self.protocols)
|
256
|
+
sorted_intersection = sorted(intersection, key=subprotocols.index)
|
257
|
+
return next(iter(sorted_intersection), None)
|
258
|
+
|
259
|
+
async def create_websocket_response(
|
260
|
+
self, request: WebSocket, subprotocol: Optional[str]
|
261
|
+
) -> WebSocket:
|
262
|
+
await request.accept(subprotocols=subprotocol)
|
263
|
+
return request
|
264
|
+
|
239
265
|
async def execute_request(
|
240
266
|
self,
|
241
267
|
request: Request[Any, Any, Any],
|
@@ -245,8 +271,6 @@ class GraphQLController(
|
|
245
271
|
try:
|
246
272
|
return await self.run(
|
247
273
|
request,
|
248
|
-
# TODO: check the dependency, above, can we make it so that
|
249
|
-
# we don't need to type ignore here?
|
250
274
|
context=context,
|
251
275
|
root_value=root_value,
|
252
276
|
)
|
@@ -328,14 +352,29 @@ class GraphQLController(
|
|
328
352
|
root_value=root_value,
|
329
353
|
)
|
330
354
|
|
355
|
+
@websocket()
|
356
|
+
async def websocket_endpoint(
|
357
|
+
self,
|
358
|
+
socket: WebSocket,
|
359
|
+
context_ws: Any,
|
360
|
+
root_value: Any,
|
361
|
+
) -> None:
|
362
|
+
await self.run(
|
363
|
+
request=socket,
|
364
|
+
context=context_ws,
|
365
|
+
root_value=root_value,
|
366
|
+
)
|
367
|
+
|
331
368
|
async def get_context(
|
332
|
-
self,
|
369
|
+
self,
|
370
|
+
request: Union[Request[Any, Any, Any], WebSocket],
|
371
|
+
response: Union[Response, WebSocket],
|
333
372
|
) -> Context: # pragma: no cover
|
334
373
|
msg = "`get_context` is not used by Litestar's controller"
|
335
374
|
raise ValueError(msg)
|
336
375
|
|
337
376
|
async def get_root_value(
|
338
|
-
self, request: Request[Any, Any, Any]
|
377
|
+
self, request: Union[Request[Any, Any, Any], WebSocket]
|
339
378
|
) -> RootValue | None: # pragma: no cover
|
340
379
|
msg = "`get_root_value` is not used by Litestar's controller"
|
341
380
|
raise ValueError(msg)
|
@@ -343,54 +382,6 @@ class GraphQLController(
|
|
343
382
|
async def get_sub_response(self, request: Request[Any, Any, Any]) -> Response:
|
344
383
|
return self.temporal_response
|
345
384
|
|
346
|
-
@websocket()
|
347
|
-
async def websocket_endpoint(
|
348
|
-
self,
|
349
|
-
socket: WebSocket,
|
350
|
-
context_ws: Any,
|
351
|
-
root_value: Any,
|
352
|
-
) -> None:
|
353
|
-
async def _get_context() -> Any:
|
354
|
-
return context_ws
|
355
|
-
|
356
|
-
async def _get_root_value() -> Any:
|
357
|
-
return root_value
|
358
|
-
|
359
|
-
preferred_protocol = self.pick_preferred_protocol(socket)
|
360
|
-
if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
|
361
|
-
await self.graphql_transport_ws_handler_class(
|
362
|
-
schema=self.schema,
|
363
|
-
debug=self.debug,
|
364
|
-
connection_init_wait_timeout=self.connection_init_wait_timeout,
|
365
|
-
get_context=_get_context,
|
366
|
-
get_root_value=_get_root_value,
|
367
|
-
ws=socket,
|
368
|
-
).handle()
|
369
|
-
elif preferred_protocol == GRAPHQL_WS_PROTOCOL:
|
370
|
-
await self.graphql_ws_handler_class(
|
371
|
-
schema=self.schema,
|
372
|
-
debug=self.debug,
|
373
|
-
keep_alive=self.keep_alive,
|
374
|
-
keep_alive_interval=self.keep_alive_interval,
|
375
|
-
get_context=_get_context,
|
376
|
-
get_root_value=_get_root_value,
|
377
|
-
ws=socket,
|
378
|
-
).handle()
|
379
|
-
else:
|
380
|
-
await socket.close(code=WS_4406_PROTOCOL_NOT_ACCEPTABLE)
|
381
|
-
|
382
|
-
def pick_preferred_protocol(self, socket: WebSocket) -> str | None:
|
383
|
-
protocols: List[str] = socket.scope["subprotocols"]
|
384
|
-
intersection: Set[str] = set(protocols) & set(self.protocols)
|
385
|
-
return (
|
386
|
-
min(
|
387
|
-
intersection,
|
388
|
-
key=lambda i: protocols.index(i) if i else "",
|
389
|
-
default=None,
|
390
|
-
)
|
391
|
-
or None
|
392
|
-
)
|
393
|
-
|
394
385
|
|
395
386
|
def make_graphql_controller(
|
396
387
|
schema: BaseSchema,
|