strawberry-graphql 0.243.1__py3-none-any.whl → 0.244.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- strawberry/aiohttp/views.py +58 -44
- strawberry/asgi/__init__.py +62 -61
- strawberry/channels/__init__.py +1 -6
- strawberry/channels/handlers/base.py +2 -2
- strawberry/channels/handlers/http_handler.py +18 -1
- strawberry/channels/handlers/ws_handler.py +113 -58
- strawberry/codegen/query_codegen.py +8 -6
- strawberry/django/views.py +19 -1
- strawberry/fastapi/router.py +27 -45
- strawberry/flask/views.py +15 -1
- strawberry/http/async_base_view.py +136 -16
- strawberry/http/exceptions.py +4 -0
- strawberry/http/typevars.py +11 -1
- strawberry/litestar/controller.py +77 -86
- strawberry/quart/views.py +15 -1
- strawberry/sanic/views.py +21 -1
- strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +41 -40
- strawberry/subscriptions/protocols/graphql_ws/handlers.py +46 -45
- {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.1.dist-info}/METADATA +1 -1
- {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.1.dist-info}/RECORD +23 -37
- strawberry/aiohttp/handlers/__init__.py +0 -6
- strawberry/aiohttp/handlers/graphql_transport_ws_handler.py +0 -62
- strawberry/aiohttp/handlers/graphql_ws_handler.py +0 -69
- strawberry/asgi/handlers/__init__.py +0 -6
- strawberry/asgi/handlers/graphql_transport_ws_handler.py +0 -66
- strawberry/asgi/handlers/graphql_ws_handler.py +0 -71
- strawberry/channels/handlers/graphql_transport_ws_handler.py +0 -62
- strawberry/channels/handlers/graphql_ws_handler.py +0 -72
- strawberry/fastapi/handlers/__init__.py +0 -6
- strawberry/fastapi/handlers/graphql_transport_ws_handler.py +0 -20
- strawberry/fastapi/handlers/graphql_ws_handler.py +0 -18
- strawberry/litestar/handlers/__init__.py +0 -0
- strawberry/litestar/handlers/graphql_transport_ws_handler.py +0 -60
- strawberry/litestar/handlers/graphql_ws_handler.py +0 -66
- {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.1.dist-info}/LICENSE +0 -0
- {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.1.dist-info}/WHEEL +0 -0
- {strawberry_graphql-0.243.1.dist-info → strawberry_graphql-0.244.1.dist-info}/entry_points.txt +0 -0
@@ -7,19 +7,19 @@ from datetime import timedelta
|
|
7
7
|
from typing import (
|
8
8
|
TYPE_CHECKING,
|
9
9
|
Any,
|
10
|
+
AsyncGenerator,
|
10
11
|
AsyncIterator,
|
11
12
|
Callable,
|
12
13
|
Dict,
|
13
14
|
FrozenSet,
|
14
|
-
List,
|
15
15
|
Optional,
|
16
|
-
Set,
|
17
16
|
Tuple,
|
18
17
|
Type,
|
19
18
|
TypedDict,
|
20
19
|
Union,
|
21
20
|
cast,
|
22
21
|
)
|
22
|
+
from typing_extensions import TypeGuard
|
23
23
|
|
24
24
|
from msgspec import Struct
|
25
25
|
|
@@ -35,23 +35,24 @@ from litestar import (
|
|
35
35
|
)
|
36
36
|
from litestar.background_tasks import BackgroundTasks
|
37
37
|
from litestar.di import Provide
|
38
|
-
from litestar.exceptions import
|
38
|
+
from litestar.exceptions import (
|
39
|
+
NotFoundException,
|
40
|
+
SerializationException,
|
41
|
+
ValidationException,
|
42
|
+
WebSocketDisconnect,
|
43
|
+
)
|
39
44
|
from litestar.response.streaming import Stream
|
40
45
|
from litestar.status_codes import HTTP_200_OK
|
41
46
|
from strawberry.exceptions import InvalidCustomContext
|
42
|
-
from strawberry.http.async_base_view import
|
43
|
-
|
47
|
+
from strawberry.http.async_base_view import (
|
48
|
+
AsyncBaseHTTPView,
|
49
|
+
AsyncHTTPRequestAdapter,
|
50
|
+
AsyncWebSocketAdapter,
|
51
|
+
)
|
52
|
+
from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived
|
44
53
|
from strawberry.http.types import FormData, HTTPMethod, QueryParams
|
45
54
|
from strawberry.http.typevars import Context, RootValue
|
46
55
|
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
|
47
|
-
from strawberry.subscriptions.protocols.graphql_transport_ws import (
|
48
|
-
WS_4406_PROTOCOL_NOT_ACCEPTABLE,
|
49
|
-
)
|
50
|
-
|
51
|
-
from .handlers.graphql_transport_ws_handler import (
|
52
|
-
GraphQLTransportWSHandler as BaseGraphQLTransportWSHandler,
|
53
|
-
)
|
54
|
-
from .handlers.graphql_ws_handler import GraphQLWSHandler as BaseGraphQLWSHandler
|
55
56
|
|
56
57
|
if TYPE_CHECKING:
|
57
58
|
from collections.abc import Mapping
|
@@ -152,22 +153,6 @@ class GraphQLResource(Struct):
|
|
152
153
|
extensions: Optional[dict[str, object]]
|
153
154
|
|
154
155
|
|
155
|
-
class GraphQLWSHandler(BaseGraphQLWSHandler):
|
156
|
-
async def get_context(self) -> Any:
|
157
|
-
return await self._get_context()
|
158
|
-
|
159
|
-
async def get_root_value(self) -> Any:
|
160
|
-
return await self._get_root_value()
|
161
|
-
|
162
|
-
|
163
|
-
class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler):
|
164
|
-
async def get_context(self) -> Any:
|
165
|
-
return await self._get_context()
|
166
|
-
|
167
|
-
async def get_root_value(self) -> Any:
|
168
|
-
return await self._get_root_value()
|
169
|
-
|
170
|
-
|
171
156
|
class LitestarRequestAdapter(AsyncHTTPRequestAdapter):
|
172
157
|
def __init__(self, request: Request[Any, Any, Any]) -> None:
|
173
158
|
self.request = request
|
@@ -203,10 +188,37 @@ class LitestarRequestAdapter(AsyncHTTPRequestAdapter):
|
|
203
188
|
return FormData(form=multipart_data, files=multipart_data)
|
204
189
|
|
205
190
|
|
191
|
+
class LitestarWebSocketAdapter(AsyncWebSocketAdapter):
|
192
|
+
def __init__(self, request: WebSocket, response: WebSocket) -> None:
|
193
|
+
self.ws = response
|
194
|
+
|
195
|
+
async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
|
196
|
+
try:
|
197
|
+
try:
|
198
|
+
while self.ws.connection_state != "disconnect":
|
199
|
+
yield await self.ws.receive_json()
|
200
|
+
except (SerializationException, ValueError):
|
201
|
+
raise NonJsonMessageReceived()
|
202
|
+
except WebSocketDisconnect:
|
203
|
+
pass
|
204
|
+
|
205
|
+
async def send_json(self, message: Mapping[str, object]) -> None:
|
206
|
+
await self.ws.send_json(message)
|
207
|
+
|
208
|
+
async def close(self, code: int, reason: str) -> None:
|
209
|
+
await self.ws.close(code=code, reason=reason)
|
210
|
+
|
211
|
+
|
206
212
|
class GraphQLController(
|
207
213
|
Controller,
|
208
214
|
AsyncBaseHTTPView[
|
209
|
-
Request[Any, Any, Any],
|
215
|
+
Request[Any, Any, Any],
|
216
|
+
Response[Any],
|
217
|
+
Response[Any],
|
218
|
+
WebSocket,
|
219
|
+
WebSocket,
|
220
|
+
Context,
|
221
|
+
RootValue,
|
210
222
|
],
|
211
223
|
):
|
212
224
|
path: str = ""
|
@@ -219,10 +231,7 @@ class GraphQLController(
|
|
219
231
|
}
|
220
232
|
|
221
233
|
request_adapter_class = LitestarRequestAdapter
|
222
|
-
|
223
|
-
graphql_transport_ws_handler_class: Type[GraphQLTransportWSHandler] = (
|
224
|
-
GraphQLTransportWSHandler
|
225
|
-
)
|
234
|
+
websocket_adapter_class = LitestarWebSocketAdapter
|
226
235
|
|
227
236
|
allow_queries_via_get: bool = True
|
228
237
|
graphiql_allowed_accept: FrozenSet[str] = frozenset({"text/html", "*/*"})
|
@@ -236,6 +245,23 @@ class GraphQLController(
|
|
236
245
|
keep_alive: bool = False
|
237
246
|
keep_alive_interval: float = 1
|
238
247
|
|
248
|
+
def is_websocket_request(
|
249
|
+
self, request: Union[Request, WebSocket]
|
250
|
+
) -> TypeGuard[WebSocket]:
|
251
|
+
return isinstance(request, WebSocket)
|
252
|
+
|
253
|
+
async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]:
|
254
|
+
subprotocols = request.scope["subprotocols"]
|
255
|
+
intersection = set(subprotocols) & set(self.protocols)
|
256
|
+
sorted_intersection = sorted(intersection, key=subprotocols.index)
|
257
|
+
return next(iter(sorted_intersection), None)
|
258
|
+
|
259
|
+
async def create_websocket_response(
|
260
|
+
self, request: WebSocket, subprotocol: Optional[str]
|
261
|
+
) -> WebSocket:
|
262
|
+
await request.accept(subprotocols=subprotocol)
|
263
|
+
return request
|
264
|
+
|
239
265
|
async def execute_request(
|
240
266
|
self,
|
241
267
|
request: Request[Any, Any, Any],
|
@@ -245,8 +271,6 @@ class GraphQLController(
|
|
245
271
|
try:
|
246
272
|
return await self.run(
|
247
273
|
request,
|
248
|
-
# TODO: check the dependency, above, can we make it so that
|
249
|
-
# we don't need to type ignore here?
|
250
274
|
context=context,
|
251
275
|
root_value=root_value,
|
252
276
|
)
|
@@ -328,14 +352,29 @@ class GraphQLController(
|
|
328
352
|
root_value=root_value,
|
329
353
|
)
|
330
354
|
|
355
|
+
@websocket()
|
356
|
+
async def websocket_endpoint(
|
357
|
+
self,
|
358
|
+
socket: WebSocket,
|
359
|
+
context_ws: Any,
|
360
|
+
root_value: Any,
|
361
|
+
) -> None:
|
362
|
+
await self.run(
|
363
|
+
request=socket,
|
364
|
+
context=context_ws,
|
365
|
+
root_value=root_value,
|
366
|
+
)
|
367
|
+
|
331
368
|
async def get_context(
|
332
|
-
self,
|
369
|
+
self,
|
370
|
+
request: Union[Request[Any, Any, Any], WebSocket],
|
371
|
+
response: Union[Response, WebSocket],
|
333
372
|
) -> Context: # pragma: no cover
|
334
373
|
msg = "`get_context` is not used by Litestar's controller"
|
335
374
|
raise ValueError(msg)
|
336
375
|
|
337
376
|
async def get_root_value(
|
338
|
-
self, request: Request[Any, Any, Any]
|
377
|
+
self, request: Union[Request[Any, Any, Any], WebSocket]
|
339
378
|
) -> RootValue | None: # pragma: no cover
|
340
379
|
msg = "`get_root_value` is not used by Litestar's controller"
|
341
380
|
raise ValueError(msg)
|
@@ -343,54 +382,6 @@ class GraphQLController(
|
|
343
382
|
async def get_sub_response(self, request: Request[Any, Any, Any]) -> Response:
|
344
383
|
return self.temporal_response
|
345
384
|
|
346
|
-
@websocket()
|
347
|
-
async def websocket_endpoint(
|
348
|
-
self,
|
349
|
-
socket: WebSocket,
|
350
|
-
context_ws: Any,
|
351
|
-
root_value: Any,
|
352
|
-
) -> None:
|
353
|
-
async def _get_context() -> Any:
|
354
|
-
return context_ws
|
355
|
-
|
356
|
-
async def _get_root_value() -> Any:
|
357
|
-
return root_value
|
358
|
-
|
359
|
-
preferred_protocol = self.pick_preferred_protocol(socket)
|
360
|
-
if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
|
361
|
-
await self.graphql_transport_ws_handler_class(
|
362
|
-
schema=self.schema,
|
363
|
-
debug=self.debug,
|
364
|
-
connection_init_wait_timeout=self.connection_init_wait_timeout,
|
365
|
-
get_context=_get_context,
|
366
|
-
get_root_value=_get_root_value,
|
367
|
-
ws=socket,
|
368
|
-
).handle()
|
369
|
-
elif preferred_protocol == GRAPHQL_WS_PROTOCOL:
|
370
|
-
await self.graphql_ws_handler_class(
|
371
|
-
schema=self.schema,
|
372
|
-
debug=self.debug,
|
373
|
-
keep_alive=self.keep_alive,
|
374
|
-
keep_alive_interval=self.keep_alive_interval,
|
375
|
-
get_context=_get_context,
|
376
|
-
get_root_value=_get_root_value,
|
377
|
-
ws=socket,
|
378
|
-
).handle()
|
379
|
-
else:
|
380
|
-
await socket.close(code=WS_4406_PROTOCOL_NOT_ACCEPTABLE)
|
381
|
-
|
382
|
-
def pick_preferred_protocol(self, socket: WebSocket) -> str | None:
|
383
|
-
protocols: List[str] = socket.scope["subprotocols"]
|
384
|
-
intersection: Set[str] = set(protocols) & set(self.protocols)
|
385
|
-
return (
|
386
|
-
min(
|
387
|
-
intersection,
|
388
|
-
key=lambda i: protocols.index(i) if i else "",
|
389
|
-
default=None,
|
390
|
-
)
|
391
|
-
or None
|
392
|
-
)
|
393
|
-
|
394
385
|
|
395
386
|
def make_graphql_controller(
|
396
387
|
schema: BaseSchema,
|
strawberry/quart/views.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import warnings
|
2
2
|
from collections.abc import Mapping
|
3
3
|
from typing import TYPE_CHECKING, AsyncGenerator, Callable, Dict, Optional, cast
|
4
|
+
from typing_extensions import TypeGuard
|
4
5
|
|
5
6
|
from quart import Request, Response, request
|
6
7
|
from quart.views import View
|
@@ -46,7 +47,9 @@ class QuartHTTPRequestAdapter(AsyncHTTPRequestAdapter):
|
|
46
47
|
|
47
48
|
|
48
49
|
class GraphQLView(
|
49
|
-
AsyncBaseHTTPView[
|
50
|
+
AsyncBaseHTTPView[
|
51
|
+
Request, Response, Response, Request, Response, Context, RootValue
|
52
|
+
],
|
50
53
|
View,
|
51
54
|
):
|
52
55
|
_ide_subscription_enabled = False
|
@@ -121,5 +124,16 @@ class GraphQLView(
|
|
121
124
|
},
|
122
125
|
)
|
123
126
|
|
127
|
+
def is_websocket_request(self, request: Request) -> TypeGuard[Request]:
|
128
|
+
return False
|
129
|
+
|
130
|
+
async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]:
|
131
|
+
raise NotImplementedError
|
132
|
+
|
133
|
+
async def create_websocket_response(
|
134
|
+
self, request: Request, subprotocol: Optional[str]
|
135
|
+
) -> Response:
|
136
|
+
raise NotImplementedError
|
137
|
+
|
124
138
|
|
125
139
|
__all__ = ["GraphQLView"]
|
strawberry/sanic/views.py
CHANGED
@@ -13,6 +13,7 @@ from typing import (
|
|
13
13
|
Type,
|
14
14
|
cast,
|
15
15
|
)
|
16
|
+
from typing_extensions import TypeGuard
|
16
17
|
|
17
18
|
from sanic.request import Request
|
18
19
|
from sanic.response import HTTPResponse, html
|
@@ -71,7 +72,15 @@ class SanicHTTPRequestAdapter(AsyncHTTPRequestAdapter):
|
|
71
72
|
|
72
73
|
|
73
74
|
class GraphQLView(
|
74
|
-
AsyncBaseHTTPView[
|
75
|
+
AsyncBaseHTTPView[
|
76
|
+
Request,
|
77
|
+
HTTPResponse,
|
78
|
+
TemporalResponse,
|
79
|
+
Request,
|
80
|
+
TemporalResponse,
|
81
|
+
Context,
|
82
|
+
RootValue,
|
83
|
+
],
|
75
84
|
HTTPMethodView,
|
76
85
|
):
|
77
86
|
"""Class based view to handle GraphQL HTTP Requests.
|
@@ -206,5 +215,16 @@ class GraphQLView(
|
|
206
215
|
# corner case
|
207
216
|
return None # type: ignore
|
208
217
|
|
218
|
+
def is_websocket_request(self, request: Request) -> TypeGuard[Request]:
|
219
|
+
return False
|
220
|
+
|
221
|
+
async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]:
|
222
|
+
raise NotImplementedError
|
223
|
+
|
224
|
+
async def create_websocket_response(
|
225
|
+
self, request: Request, subprotocol: Optional[str]
|
226
|
+
) -> TemporalResponse:
|
227
|
+
raise NotImplementedError
|
228
|
+
|
209
229
|
|
210
230
|
__all__ = ["GraphQLView"]
|
@@ -2,7 +2,6 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import asyncio
|
4
4
|
import logging
|
5
|
-
from abc import ABC, abstractmethod
|
6
5
|
from contextlib import suppress
|
7
6
|
from typing import (
|
8
7
|
TYPE_CHECKING,
|
@@ -16,6 +15,7 @@ from typing import (
|
|
16
15
|
|
17
16
|
from graphql import GraphQLError, GraphQLSyntaxError, parse
|
18
17
|
|
18
|
+
from strawberry.http.exceptions import NonJsonMessageReceived
|
19
19
|
from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
|
20
20
|
CompleteMessage,
|
21
21
|
ConnectionAckMessage,
|
@@ -38,6 +38,7 @@ from strawberry.utils.operation import get_operation_type
|
|
38
38
|
if TYPE_CHECKING:
|
39
39
|
from datetime import timedelta
|
40
40
|
|
41
|
+
from strawberry.http.async_base_view import AsyncWebSocketAdapter
|
41
42
|
from strawberry.schema import BaseSchema
|
42
43
|
from strawberry.schema.subscribe import SubscriptionResult
|
43
44
|
from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
|
@@ -45,15 +46,21 @@ if TYPE_CHECKING:
|
|
45
46
|
)
|
46
47
|
|
47
48
|
|
48
|
-
class BaseGraphQLTransportWSHandler
|
49
|
+
class BaseGraphQLTransportWSHandler:
|
49
50
|
task_logger: logging.Logger = logging.getLogger("strawberry.ws.task")
|
50
51
|
|
51
52
|
def __init__(
|
52
53
|
self,
|
54
|
+
websocket: AsyncWebSocketAdapter,
|
55
|
+
context: object,
|
56
|
+
root_value: object,
|
53
57
|
schema: BaseSchema,
|
54
58
|
debug: bool,
|
55
59
|
connection_init_wait_timeout: timedelta,
|
56
60
|
) -> None:
|
61
|
+
self.websocket = websocket
|
62
|
+
self.context = context
|
63
|
+
self.root_value = root_value
|
57
64
|
self.schema = schema
|
58
65
|
self.debug = debug
|
59
66
|
self.connection_init_wait_timeout = connection_init_wait_timeout
|
@@ -65,28 +72,16 @@ class BaseGraphQLTransportWSHandler(ABC):
|
|
65
72
|
self.completed_tasks: List[asyncio.Task] = []
|
66
73
|
self.connection_params: Optional[Dict[str, Any]] = None
|
67
74
|
|
68
|
-
@abstractmethod
|
69
|
-
async def get_context(self) -> Any:
|
70
|
-
"""Return the operations context."""
|
71
|
-
|
72
|
-
@abstractmethod
|
73
|
-
async def get_root_value(self) -> Any:
|
74
|
-
"""Return the schemas root value."""
|
75
|
-
|
76
|
-
@abstractmethod
|
77
|
-
async def send_json(self, data: dict) -> None:
|
78
|
-
"""Send the data JSON encoded to the WebSocket client."""
|
79
|
-
|
80
|
-
@abstractmethod
|
81
|
-
async def close(self, code: int, reason: str) -> None:
|
82
|
-
"""Close the WebSocket with the passed code and reason."""
|
83
|
-
|
84
|
-
@abstractmethod
|
85
|
-
async def handle_request(self) -> Any:
|
86
|
-
"""Handle the request this instance was created for."""
|
87
|
-
|
88
75
|
async def handle(self) -> Any:
|
89
|
-
|
76
|
+
self.on_request_accepted()
|
77
|
+
|
78
|
+
try:
|
79
|
+
async for message in self.websocket.iter_json():
|
80
|
+
await self.handle_message(message)
|
81
|
+
except NonJsonMessageReceived:
|
82
|
+
await self.handle_invalid_message("WebSocket message type must be text")
|
83
|
+
finally:
|
84
|
+
await self.shutdown()
|
90
85
|
|
91
86
|
async def shutdown(self) -> None:
|
92
87
|
if self.connection_init_timeout_task:
|
@@ -118,7 +113,7 @@ class BaseGraphQLTransportWSHandler(ABC):
|
|
118
113
|
|
119
114
|
self.connection_timed_out = True
|
120
115
|
reason = "Connection initialisation timeout"
|
121
|
-
await self.close(code=4408, reason=reason)
|
116
|
+
await self.websocket.close(code=4408, reason=reason)
|
122
117
|
except Exception as error:
|
123
118
|
await self.handle_task_exception(error) # pragma: no cover
|
124
119
|
finally:
|
@@ -189,14 +184,16 @@ class BaseGraphQLTransportWSHandler(ABC):
|
|
189
184
|
)
|
190
185
|
|
191
186
|
if not isinstance(payload, dict):
|
192
|
-
await self.close(
|
187
|
+
await self.websocket.close(
|
188
|
+
code=4400, reason="Invalid connection init payload"
|
189
|
+
)
|
193
190
|
return
|
194
191
|
|
195
192
|
self.connection_params = payload
|
196
193
|
|
197
194
|
if self.connection_init_received:
|
198
195
|
reason = "Too many initialisation requests"
|
199
|
-
await self.close(code=4429, reason=reason)
|
196
|
+
await self.websocket.close(code=4429, reason=reason)
|
200
197
|
return
|
201
198
|
|
202
199
|
self.connection_init_received = True
|
@@ -211,13 +208,13 @@ class BaseGraphQLTransportWSHandler(ABC):
|
|
211
208
|
|
212
209
|
async def handle_subscribe(self, message: SubscribeMessage) -> None:
|
213
210
|
if not self.connection_acknowledged:
|
214
|
-
await self.close(code=4401, reason="Unauthorized")
|
211
|
+
await self.websocket.close(code=4401, reason="Unauthorized")
|
215
212
|
return
|
216
213
|
|
217
214
|
try:
|
218
215
|
graphql_document = parse(message.payload.query)
|
219
216
|
except GraphQLSyntaxError as exc:
|
220
|
-
await self.close(code=4400, reason=exc.message)
|
217
|
+
await self.websocket.close(code=4400, reason=exc.message)
|
221
218
|
return
|
222
219
|
|
223
220
|
try:
|
@@ -225,12 +222,14 @@ class BaseGraphQLTransportWSHandler(ABC):
|
|
225
222
|
graphql_document, message.payload.operationName
|
226
223
|
)
|
227
224
|
except RuntimeError:
|
228
|
-
await self.close(
|
225
|
+
await self.websocket.close(
|
226
|
+
code=4400, reason="Can't get GraphQL operation type"
|
227
|
+
)
|
229
228
|
return
|
230
229
|
|
231
230
|
if message.id in self.operations:
|
232
231
|
reason = f"Subscriber for {message.id} already exists"
|
233
|
-
await self.close(code=4409, reason=reason)
|
232
|
+
await self.websocket.close(code=4409, reason=reason)
|
234
233
|
return
|
235
234
|
|
236
235
|
if self.debug: # pragma: no cover
|
@@ -240,26 +239,28 @@ class BaseGraphQLTransportWSHandler(ABC):
|
|
240
239
|
message.payload.variables,
|
241
240
|
)
|
242
241
|
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
242
|
+
if isinstance(self.context, dict):
|
243
|
+
self.context["connection_params"] = self.connection_params
|
244
|
+
elif hasattr(self.context, "connection_params"):
|
245
|
+
self.context.connection_params = self.connection_params
|
246
|
+
|
247
247
|
result_source: Awaitable[ExecutionResult] | Awaitable[SubscriptionResult]
|
248
|
+
|
248
249
|
# Get an AsyncGenerator yielding the results
|
249
250
|
if operation_type == OperationType.SUBSCRIPTION:
|
250
251
|
result_source = self.schema.subscribe(
|
251
252
|
query=message.payload.query,
|
252
253
|
variable_values=message.payload.variables,
|
253
254
|
operation_name=message.payload.operationName,
|
254
|
-
context_value=context,
|
255
|
-
root_value=root_value,
|
255
|
+
context_value=self.context,
|
256
|
+
root_value=self.root_value,
|
256
257
|
)
|
257
258
|
else:
|
258
259
|
result_source = self.schema.execute(
|
259
260
|
query=message.payload.query,
|
260
261
|
variable_values=message.payload.variables,
|
261
|
-
context_value=context,
|
262
|
-
root_value=root_value,
|
262
|
+
context_value=self.context,
|
263
|
+
root_value=self.root_value,
|
263
264
|
operation_name=message.payload.operationName,
|
264
265
|
)
|
265
266
|
|
@@ -312,11 +313,11 @@ class BaseGraphQLTransportWSHandler(ABC):
|
|
312
313
|
await self.cleanup_operation(operation_id=message.id)
|
313
314
|
|
314
315
|
async def handle_invalid_message(self, error_message: str) -> None:
|
315
|
-
await self.close(code=4400, reason=error_message)
|
316
|
+
await self.websocket.close(code=4400, reason=error_message)
|
316
317
|
|
317
318
|
async def send_message(self, message: GraphQLTransportMessage) -> None:
|
318
319
|
data = message.as_dict()
|
319
|
-
await self.send_json(data)
|
320
|
+
await self.websocket.send_json(data)
|
320
321
|
|
321
322
|
async def cleanup_operation(self, operation_id: str) -> None:
|
322
323
|
if operation_id not in self.operations:
|