strawberry-graphql 0.243.0__py3-none-any.whl → 0.244.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- strawberry/aiohttp/views.py +58 -44
- strawberry/asgi/__init__.py +62 -61
- strawberry/channels/__init__.py +1 -6
- strawberry/channels/handlers/base.py +2 -2
- strawberry/channels/handlers/http_handler.py +18 -1
- strawberry/channels/handlers/ws_handler.py +113 -58
- strawberry/django/views.py +19 -1
- strawberry/ext/mypy_plugin.py +8 -1
- strawberry/fastapi/router.py +27 -45
- strawberry/flask/views.py +15 -1
- strawberry/http/async_base_view.py +136 -16
- strawberry/http/exceptions.py +4 -0
- strawberry/http/typevars.py +11 -1
- strawberry/litestar/controller.py +77 -86
- strawberry/quart/views.py +15 -1
- strawberry/sanic/views.py +21 -1
- strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +41 -40
- strawberry/subscriptions/protocols/graphql_ws/handlers.py +46 -45
- {strawberry_graphql-0.243.0.dist-info → strawberry_graphql-0.244.0.dist-info}/METADATA +1 -1
- {strawberry_graphql-0.243.0.dist-info → strawberry_graphql-0.244.0.dist-info}/RECORD +23 -37
- strawberry/aiohttp/handlers/__init__.py +0 -6
- strawberry/aiohttp/handlers/graphql_transport_ws_handler.py +0 -62
- strawberry/aiohttp/handlers/graphql_ws_handler.py +0 -69
- strawberry/asgi/handlers/__init__.py +0 -6
- strawberry/asgi/handlers/graphql_transport_ws_handler.py +0 -66
- strawberry/asgi/handlers/graphql_ws_handler.py +0 -71
- strawberry/channels/handlers/graphql_transport_ws_handler.py +0 -62
- strawberry/channels/handlers/graphql_ws_handler.py +0 -72
- strawberry/fastapi/handlers/__init__.py +0 -6
- strawberry/fastapi/handlers/graphql_transport_ws_handler.py +0 -20
- strawberry/fastapi/handlers/graphql_ws_handler.py +0 -18
- strawberry/litestar/handlers/__init__.py +0 -0
- strawberry/litestar/handlers/graphql_transport_ws_handler.py +0 -60
- strawberry/litestar/handlers/graphql_ws_handler.py +0 -66
- {strawberry_graphql-0.243.0.dist-info → strawberry_graphql-0.244.0.dist-info}/LICENSE +0 -0
- {strawberry_graphql-0.243.0.dist-info → strawberry_graphql-0.244.0.dist-info}/WHEEL +0 -0
- {strawberry_graphql-0.243.0.dist-info → strawberry_graphql-0.244.0.dist-info}/entry_points.txt +0 -0
strawberry/aiohttp/views.py
CHANGED
@@ -4,6 +4,7 @@ import asyncio
|
|
4
4
|
import warnings
|
5
5
|
from datetime import timedelta
|
6
6
|
from io import BytesIO
|
7
|
+
from json.decoder import JSONDecodeError
|
7
8
|
from typing import (
|
8
9
|
TYPE_CHECKING,
|
9
10
|
Any,
|
@@ -16,15 +17,16 @@ from typing import (
|
|
16
17
|
Union,
|
17
18
|
cast,
|
18
19
|
)
|
20
|
+
from typing_extensions import TypeGuard
|
19
21
|
|
20
|
-
from aiohttp import web
|
22
|
+
from aiohttp import http, web
|
21
23
|
from aiohttp.multipart import BodyPartReader
|
22
|
-
from strawberry.
|
23
|
-
|
24
|
-
|
24
|
+
from strawberry.http.async_base_view import (
|
25
|
+
AsyncBaseHTTPView,
|
26
|
+
AsyncHTTPRequestAdapter,
|
27
|
+
AsyncWebSocketAdapter,
|
25
28
|
)
|
26
|
-
from strawberry.http.
|
27
|
-
from strawberry.http.exceptions import HTTPException
|
29
|
+
from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived
|
28
30
|
from strawberry.http.types import FormData, HTTPMethod, QueryParams
|
29
31
|
from strawberry.http.typevars import (
|
30
32
|
Context,
|
@@ -79,11 +81,36 @@ class AioHTTPRequestAdapter(AsyncHTTPRequestAdapter):
|
|
79
81
|
return self.headers.get("content-type")
|
80
82
|
|
81
83
|
|
84
|
+
class AioHTTPWebSocketAdapter(AsyncWebSocketAdapter):
|
85
|
+
def __init__(self, request: web.Request, ws: web.WebSocketResponse) -> None:
|
86
|
+
self.request = request
|
87
|
+
self.ws = ws
|
88
|
+
|
89
|
+
async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
|
90
|
+
async for ws_message in self.ws:
|
91
|
+
if ws_message.type == http.WSMsgType.TEXT:
|
92
|
+
try:
|
93
|
+
yield ws_message.json()
|
94
|
+
except JSONDecodeError:
|
95
|
+
raise NonJsonMessageReceived()
|
96
|
+
|
97
|
+
elif ws_message.type == http.WSMsgType.BINARY:
|
98
|
+
raise NonJsonMessageReceived()
|
99
|
+
|
100
|
+
async def send_json(self, message: Mapping[str, object]) -> None:
|
101
|
+
await self.ws.send_json(message)
|
102
|
+
|
103
|
+
async def close(self, code: int, reason: str) -> None:
|
104
|
+
await self.ws.close(code=code, message=reason.encode())
|
105
|
+
|
106
|
+
|
82
107
|
class GraphQLView(
|
83
108
|
AsyncBaseHTTPView[
|
84
109
|
web.Request,
|
85
110
|
Union[web.Response, web.StreamResponse],
|
86
111
|
web.Response,
|
112
|
+
web.Request,
|
113
|
+
web.WebSocketResponse,
|
87
114
|
Context,
|
88
115
|
RootValue,
|
89
116
|
]
|
@@ -92,10 +119,9 @@ class GraphQLView(
|
|
92
119
|
# bare handler function.
|
93
120
|
_is_coroutine = asyncio.coroutines._is_coroutine # type: ignore[attr-defined]
|
94
121
|
|
95
|
-
graphql_transport_ws_handler_class = GraphQLTransportWSHandler
|
96
|
-
graphql_ws_handler_class = GraphQLWSHandler
|
97
122
|
allow_queries_via_get = True
|
98
123
|
request_adapter_class = AioHTTPRequestAdapter
|
124
|
+
websocket_adapter_class = AioHTTPWebSocketAdapter
|
99
125
|
|
100
126
|
def __init__(
|
101
127
|
self,
|
@@ -138,48 +164,36 @@ class GraphQLView(
|
|
138
164
|
async def get_sub_response(self, request: web.Request) -> web.Response:
|
139
165
|
return web.Response()
|
140
166
|
|
141
|
-
|
167
|
+
def is_websocket_request(self, request: web.Request) -> TypeGuard[web.Request]:
|
142
168
|
ws = web.WebSocketResponse(protocols=self.subscription_protocols)
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
schema=self.schema,
|
166
|
-
debug=self.debug,
|
167
|
-
keep_alive=self.keep_alive,
|
168
|
-
keep_alive_interval=self.keep_alive_interval,
|
169
|
-
get_context=self.get_context,
|
170
|
-
get_root_value=self.get_root_value,
|
171
|
-
request=request,
|
172
|
-
).handle()
|
173
|
-
else:
|
174
|
-
await ws.prepare(request)
|
175
|
-
await ws.close(code=4406, message=b"Subprotocol not acceptable")
|
176
|
-
return ws
|
169
|
+
return ws.can_prepare(request).ok
|
170
|
+
|
171
|
+
async def pick_websocket_subprotocol(self, request: web.Request) -> Optional[str]:
|
172
|
+
ws = web.WebSocketResponse(protocols=self.subscription_protocols)
|
173
|
+
return ws.can_prepare(request).protocol
|
174
|
+
|
175
|
+
async def create_websocket_response(
|
176
|
+
self, request: web.Request, subprotocol: Optional[str]
|
177
|
+
) -> web.WebSocketResponse:
|
178
|
+
protocols = [subprotocol] if subprotocol else []
|
179
|
+
ws = web.WebSocketResponse(protocols=protocols)
|
180
|
+
await ws.prepare(request)
|
181
|
+
return ws
|
182
|
+
|
183
|
+
async def __call__(self, request: web.Request) -> web.StreamResponse:
|
184
|
+
try:
|
185
|
+
return await self.run(request=request)
|
186
|
+
except HTTPException as e:
|
187
|
+
return web.Response(
|
188
|
+
body=e.reason,
|
189
|
+
status=e.status_code,
|
190
|
+
)
|
177
191
|
|
178
192
|
async def get_root_value(self, request: web.Request) -> Optional[RootValue]:
|
179
193
|
return None
|
180
194
|
|
181
195
|
async def get_context(
|
182
|
-
self, request: web.Request, response: web.Response
|
196
|
+
self, request: web.Request, response: Union[web.Response, web.WebSocketResponse]
|
183
197
|
) -> Context:
|
184
198
|
return {"request": request, "response": response} # type: ignore
|
185
199
|
|
strawberry/asgi/__init__.py
CHANGED
@@ -2,9 +2,11 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import warnings
|
4
4
|
from datetime import timedelta
|
5
|
+
from json import JSONDecodeError
|
5
6
|
from typing import (
|
6
7
|
TYPE_CHECKING,
|
7
8
|
Any,
|
9
|
+
AsyncGenerator,
|
8
10
|
AsyncIterator,
|
9
11
|
Callable,
|
10
12
|
Dict,
|
@@ -14,6 +16,7 @@ from typing import (
|
|
14
16
|
Union,
|
15
17
|
cast,
|
16
18
|
)
|
19
|
+
from typing_extensions import TypeGuard
|
17
20
|
|
18
21
|
from starlette import status
|
19
22
|
from starlette.requests import Request
|
@@ -23,14 +26,14 @@ from starlette.responses import (
|
|
23
26
|
Response,
|
24
27
|
StreamingResponse,
|
25
28
|
)
|
26
|
-
from starlette.websockets import WebSocket
|
29
|
+
from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState
|
27
30
|
|
28
|
-
from strawberry.
|
29
|
-
|
30
|
-
|
31
|
+
from strawberry.http.async_base_view import (
|
32
|
+
AsyncBaseHTTPView,
|
33
|
+
AsyncHTTPRequestAdapter,
|
34
|
+
AsyncWebSocketAdapter,
|
31
35
|
)
|
32
|
-
from strawberry.http.
|
33
|
-
from strawberry.http.exceptions import HTTPException
|
36
|
+
from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived
|
34
37
|
from strawberry.http.types import FormData, HTTPMethod, QueryParams
|
35
38
|
from strawberry.http.typevars import (
|
36
39
|
Context,
|
@@ -78,19 +81,41 @@ class ASGIRequestAdapter(AsyncHTTPRequestAdapter):
|
|
78
81
|
)
|
79
82
|
|
80
83
|
|
84
|
+
class ASGIWebSocketAdapter(AsyncWebSocketAdapter):
|
85
|
+
def __init__(self, request: WebSocket, response: WebSocket) -> None:
|
86
|
+
self.ws = response
|
87
|
+
|
88
|
+
async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
|
89
|
+
try:
|
90
|
+
try:
|
91
|
+
while self.ws.application_state != WebSocketState.DISCONNECTED:
|
92
|
+
yield await self.ws.receive_json()
|
93
|
+
except (KeyError, JSONDecodeError):
|
94
|
+
raise NonJsonMessageReceived()
|
95
|
+
except WebSocketDisconnect: # pragma: no cover
|
96
|
+
pass
|
97
|
+
|
98
|
+
async def send_json(self, message: Mapping[str, object]) -> None:
|
99
|
+
await self.ws.send_json(message)
|
100
|
+
|
101
|
+
async def close(self, code: int, reason: str) -> None:
|
102
|
+
await self.ws.close(code=code, reason=reason)
|
103
|
+
|
104
|
+
|
81
105
|
class GraphQL(
|
82
106
|
AsyncBaseHTTPView[
|
83
|
-
|
107
|
+
Request,
|
84
108
|
Response,
|
85
109
|
Response,
|
110
|
+
WebSocket,
|
111
|
+
WebSocket,
|
86
112
|
Context,
|
87
113
|
RootValue,
|
88
114
|
]
|
89
115
|
):
|
90
|
-
graphql_transport_ws_handler_class = GraphQLTransportWSHandler
|
91
|
-
graphql_ws_handler_class = GraphQLWSHandler
|
92
116
|
allow_queries_via_get = True
|
93
|
-
request_adapter_class = ASGIRequestAdapter
|
117
|
+
request_adapter_class = ASGIRequestAdapter
|
118
|
+
websocket_adapter_class = ASGIWebSocketAdapter
|
94
119
|
|
95
120
|
def __init__(
|
96
121
|
self,
|
@@ -129,51 +154,25 @@ class GraphQL(
|
|
129
154
|
|
130
155
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
131
156
|
if scope["type"] == "http":
|
132
|
-
|
157
|
+
http_request = Request(scope=scope, receive=receive)
|
133
158
|
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
|
139
|
-
await self.graphql_transport_ws_handler_class(
|
140
|
-
schema=self.schema,
|
141
|
-
debug=self.debug,
|
142
|
-
connection_init_wait_timeout=self.connection_init_wait_timeout,
|
143
|
-
get_context=self.get_context,
|
144
|
-
get_root_value=self.get_root_value,
|
145
|
-
ws=ws,
|
146
|
-
).handle()
|
147
|
-
|
148
|
-
elif preferred_protocol == GRAPHQL_WS_PROTOCOL:
|
149
|
-
await self.graphql_ws_handler_class(
|
150
|
-
schema=self.schema,
|
151
|
-
debug=self.debug,
|
152
|
-
keep_alive=self.keep_alive,
|
153
|
-
keep_alive_interval=self.keep_alive_interval,
|
154
|
-
get_context=self.get_context,
|
155
|
-
get_root_value=self.get_root_value,
|
156
|
-
ws=ws,
|
157
|
-
).handle()
|
158
|
-
|
159
|
-
else:
|
160
|
-
# Subprotocol not acceptable
|
161
|
-
await ws.close(code=4406)
|
159
|
+
try:
|
160
|
+
response = await self.run(http_request)
|
161
|
+
except HTTPException as e:
|
162
|
+
response = PlainTextResponse(e.reason, status_code=e.status_code)
|
162
163
|
|
164
|
+
await response(scope, receive, send)
|
165
|
+
elif scope["type"] == "websocket":
|
166
|
+
ws_request = WebSocket(scope, receive=receive, send=send)
|
167
|
+
await self.run(ws_request)
|
163
168
|
else: # pragma: no cover
|
164
169
|
raise ValueError("Unknown scope type: {!r}".format(scope["type"]))
|
165
170
|
|
166
|
-
def pick_preferred_protocol(self, ws: WebSocket) -> Optional[str]:
|
167
|
-
protocols = ws["subprotocols"]
|
168
|
-
intersection = set(protocols) & set(self.protocols)
|
169
|
-
sorted_intersection = sorted(intersection, key=protocols.index)
|
170
|
-
return next(iter(sorted_intersection), None)
|
171
|
-
|
172
171
|
async def get_root_value(self, request: Union[Request, WebSocket]) -> Optional[Any]:
|
173
172
|
return None
|
174
173
|
|
175
174
|
async def get_context(
|
176
|
-
self, request: Union[Request, WebSocket], response: Response
|
175
|
+
self, request: Union[Request, WebSocket], response: Union[Response, WebSocket]
|
177
176
|
) -> Context:
|
178
177
|
return {"request": request, "response": response} # type: ignore
|
179
178
|
|
@@ -187,21 +186,6 @@ class GraphQL(
|
|
187
186
|
|
188
187
|
return sub_response
|
189
188
|
|
190
|
-
async def handle_http(
|
191
|
-
self,
|
192
|
-
scope: Scope,
|
193
|
-
receive: Receive,
|
194
|
-
send: Send,
|
195
|
-
) -> None:
|
196
|
-
request = Request(scope=scope, receive=receive)
|
197
|
-
|
198
|
-
try:
|
199
|
-
response = await self.run(request)
|
200
|
-
except HTTPException as e:
|
201
|
-
response = PlainTextResponse(e.reason, status_code=e.status_code) # pyright: ignore
|
202
|
-
|
203
|
-
await response(scope, receive, send)
|
204
|
-
|
205
189
|
async def render_graphql_ide(self, request: Union[Request, WebSocket]) -> Response:
|
206
190
|
return HTMLResponse(self.graphql_ide_html)
|
207
191
|
|
@@ -239,3 +223,20 @@ class GraphQL(
|
|
239
223
|
**headers,
|
240
224
|
},
|
241
225
|
)
|
226
|
+
|
227
|
+
def is_websocket_request(
|
228
|
+
self, request: Union[Request, WebSocket]
|
229
|
+
) -> TypeGuard[WebSocket]:
|
230
|
+
return request.scope["type"] == "websocket"
|
231
|
+
|
232
|
+
async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]:
|
233
|
+
protocols = request["subprotocols"]
|
234
|
+
intersection = set(protocols) & set(self.protocols)
|
235
|
+
sorted_intersection = sorted(intersection, key=protocols.index)
|
236
|
+
return next(iter(sorted_intersection), None)
|
237
|
+
|
238
|
+
async def create_websocket_response(
|
239
|
+
self, request: WebSocket, subprotocol: Optional[str]
|
240
|
+
) -> WebSocket:
|
241
|
+
await request.accept(subprotocol=subprotocol)
|
242
|
+
return request
|
strawberry/channels/__init__.py
CHANGED
@@ -1,6 +1,4 @@
|
|
1
|
-
from .handlers.base import ChannelsConsumer
|
2
|
-
from .handlers.graphql_transport_ws_handler import GraphQLTransportWSHandler
|
3
|
-
from .handlers.graphql_ws_handler import GraphQLWSHandler
|
1
|
+
from .handlers.base import ChannelsConsumer
|
4
2
|
from .handlers.http_handler import (
|
5
3
|
ChannelsRequest,
|
6
4
|
GraphQLHTTPConsumer,
|
@@ -12,10 +10,7 @@ from .router import GraphQLProtocolTypeRouter
|
|
12
10
|
__all__ = [
|
13
11
|
"ChannelsConsumer",
|
14
12
|
"ChannelsRequest",
|
15
|
-
"ChannelsWSConsumer",
|
16
13
|
"GraphQLProtocolTypeRouter",
|
17
|
-
"GraphQLWSHandler",
|
18
|
-
"GraphQLTransportWSHandler",
|
19
14
|
"GraphQLHTTPConsumer",
|
20
15
|
"GraphQLWSConsumer",
|
21
16
|
"SyncGraphQLHTTPConsumer",
|
@@ -16,7 +16,7 @@ from typing_extensions import Literal, Protocol, TypedDict
|
|
16
16
|
from weakref import WeakSet
|
17
17
|
|
18
18
|
from channels.consumer import AsyncConsumer
|
19
|
-
from channels.generic.websocket import
|
19
|
+
from channels.generic.websocket import AsyncWebsocketConsumer
|
20
20
|
|
21
21
|
|
22
22
|
class ChannelsMessage(TypedDict, total=False):
|
@@ -210,7 +210,7 @@ class ChannelsConsumer(AsyncConsumer):
|
|
210
210
|
return
|
211
211
|
|
212
212
|
|
213
|
-
class ChannelsWSConsumer(ChannelsConsumer,
|
213
|
+
class ChannelsWSConsumer(ChannelsConsumer, AsyncWebsocketConsumer):
|
214
214
|
"""Base channels websocket async consumer."""
|
215
215
|
|
216
216
|
|
@@ -15,7 +15,7 @@ from typing import (
|
|
15
15
|
Optional,
|
16
16
|
Union,
|
17
17
|
)
|
18
|
-
from typing_extensions import assert_never
|
18
|
+
from typing_extensions import TypeGuard, assert_never
|
19
19
|
from urllib.parse import parse_qs
|
20
20
|
|
21
21
|
from django.conf import settings
|
@@ -233,6 +233,8 @@ class GraphQLHTTPConsumer(
|
|
233
233
|
ChannelsRequest,
|
234
234
|
Union[ChannelsResponse, MultipartChannelsResponse],
|
235
235
|
TemporalResponse,
|
236
|
+
ChannelsRequest,
|
237
|
+
TemporalResponse,
|
236
238
|
Context,
|
237
239
|
RootValue,
|
238
240
|
],
|
@@ -298,6 +300,21 @@ class GraphQLHTTPConsumer(
|
|
298
300
|
content=self.graphql_ide_html.encode(), content_type="text/html"
|
299
301
|
)
|
300
302
|
|
303
|
+
def is_websocket_request(
|
304
|
+
self, request: ChannelsRequest
|
305
|
+
) -> TypeGuard[ChannelsRequest]:
|
306
|
+
return False
|
307
|
+
|
308
|
+
async def pick_websocket_subprotocol(
|
309
|
+
self, request: ChannelsRequest
|
310
|
+
) -> Optional[str]:
|
311
|
+
return None
|
312
|
+
|
313
|
+
async def create_websocket_response(
|
314
|
+
self, request: ChannelsRequest, subprotocol: Optional[str]
|
315
|
+
) -> TemporalResponse:
|
316
|
+
raise NotImplementedError
|
317
|
+
|
301
318
|
|
302
319
|
class SyncGraphQLHTTPConsumer(
|
303
320
|
BaseGraphQLHTTPConsumer,
|
@@ -1,20 +1,76 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import asyncio
|
3
4
|
import datetime
|
4
|
-
|
5
|
-
|
5
|
+
import json
|
6
|
+
from typing import (
|
7
|
+
TYPE_CHECKING,
|
8
|
+
AsyncGenerator,
|
9
|
+
Dict,
|
10
|
+
Mapping,
|
11
|
+
Optional,
|
12
|
+
Tuple,
|
13
|
+
TypedDict,
|
14
|
+
Union,
|
15
|
+
)
|
16
|
+
from typing_extensions import TypeGuard
|
17
|
+
|
18
|
+
from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncWebSocketAdapter
|
19
|
+
from strawberry.http.exceptions import NonJsonMessageReceived
|
20
|
+
from strawberry.http.typevars import Context, RootValue
|
6
21
|
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
|
7
22
|
|
8
|
-
from .base import
|
9
|
-
from .graphql_transport_ws_handler import GraphQLTransportWSHandler
|
10
|
-
from .graphql_ws_handler import GraphQLWSHandler
|
23
|
+
from .base import ChannelsWSConsumer
|
11
24
|
|
12
25
|
if TYPE_CHECKING:
|
13
|
-
from strawberry.http
|
26
|
+
from strawberry.http import GraphQLHTTPResponse
|
14
27
|
from strawberry.schema import BaseSchema
|
15
28
|
|
16
29
|
|
17
|
-
class
|
30
|
+
class ChannelsWebSocketAdapter(AsyncWebSocketAdapter):
|
31
|
+
def __init__(self, request: GraphQLWSConsumer, response: GraphQLWSConsumer) -> None:
|
32
|
+
self.ws_consumer = response
|
33
|
+
|
34
|
+
async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
|
35
|
+
while True:
|
36
|
+
message = await self.ws_consumer.message_queue.get()
|
37
|
+
|
38
|
+
if message["disconnected"]:
|
39
|
+
break
|
40
|
+
|
41
|
+
if message["message"] is None:
|
42
|
+
raise NonJsonMessageReceived()
|
43
|
+
|
44
|
+
try:
|
45
|
+
yield json.loads(message["message"])
|
46
|
+
except json.JSONDecodeError:
|
47
|
+
raise NonJsonMessageReceived()
|
48
|
+
|
49
|
+
async def send_json(self, message: Mapping[str, object]) -> None:
|
50
|
+
serialized_message = json.dumps(message)
|
51
|
+
await self.ws_consumer.send(serialized_message)
|
52
|
+
|
53
|
+
async def close(self, code: int, reason: str) -> None:
|
54
|
+
await self.ws_consumer.close(code=code, reason=reason)
|
55
|
+
|
56
|
+
|
57
|
+
class MessageQueueData(TypedDict):
|
58
|
+
message: Union[str, None]
|
59
|
+
disconnected: bool
|
60
|
+
|
61
|
+
|
62
|
+
class GraphQLWSConsumer(
|
63
|
+
ChannelsWSConsumer,
|
64
|
+
AsyncBaseHTTPView[
|
65
|
+
"GraphQLWSConsumer",
|
66
|
+
"GraphQLWSConsumer",
|
67
|
+
"GraphQLWSConsumer",
|
68
|
+
"GraphQLWSConsumer",
|
69
|
+
"GraphQLWSConsumer",
|
70
|
+
Context,
|
71
|
+
RootValue,
|
72
|
+
],
|
73
|
+
):
|
18
74
|
"""A channels websocket consumer for GraphQL.
|
19
75
|
|
20
76
|
This handles the connections, then hands off to the appropriate
|
@@ -39,9 +95,7 @@ class GraphQLWSConsumer(ChannelsWSConsumer):
|
|
39
95
|
```
|
40
96
|
"""
|
41
97
|
|
42
|
-
|
43
|
-
graphql_ws_handler_class = GraphQLWSHandler
|
44
|
-
_handler: Union[GraphQLWSHandler, GraphQLTransportWSHandler]
|
98
|
+
websocket_adapter_class = ChannelsWebSocketAdapter
|
45
99
|
|
46
100
|
def __init__(
|
47
101
|
self,
|
@@ -63,70 +117,71 @@ class GraphQLWSConsumer(ChannelsWSConsumer):
|
|
63
117
|
self.keep_alive_interval = keep_alive_interval
|
64
118
|
self.debug = debug
|
65
119
|
self.protocols = subscription_protocols
|
120
|
+
self.message_queue: asyncio.Queue[MessageQueueData] = asyncio.Queue()
|
121
|
+
self.run_task: Optional[asyncio.Task] = None
|
66
122
|
|
67
123
|
super().__init__()
|
68
124
|
|
69
|
-
def pick_preferred_protocol(
|
70
|
-
self, accepted_subprotocols: Sequence[str]
|
71
|
-
) -> Optional[str]:
|
72
|
-
intersection = set(accepted_subprotocols) & set(self.protocols)
|
73
|
-
sorted_intersection = sorted(intersection, key=accepted_subprotocols.index)
|
74
|
-
return next(iter(sorted_intersection), None)
|
75
|
-
|
76
125
|
async def connect(self) -> None:
|
77
|
-
|
78
|
-
|
79
|
-
if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
|
80
|
-
self._handler = self.graphql_transport_ws_handler_class(
|
81
|
-
schema=self.schema,
|
82
|
-
debug=self.debug,
|
83
|
-
connection_init_wait_timeout=self.connection_init_wait_timeout,
|
84
|
-
get_context=self.get_context,
|
85
|
-
get_root_value=self.get_root_value,
|
86
|
-
ws=self,
|
87
|
-
)
|
88
|
-
elif preferred_protocol == GRAPHQL_WS_PROTOCOL:
|
89
|
-
self._handler = self.graphql_ws_handler_class(
|
90
|
-
schema=self.schema,
|
91
|
-
debug=self.debug,
|
92
|
-
keep_alive=self.keep_alive,
|
93
|
-
keep_alive_interval=self.keep_alive_interval,
|
94
|
-
get_context=self.get_context,
|
95
|
-
get_root_value=self.get_root_value,
|
96
|
-
ws=self,
|
97
|
-
)
|
98
|
-
else:
|
99
|
-
# Subprotocol not acceptable
|
100
|
-
return await self.close(code=4406)
|
126
|
+
self.run_task = asyncio.create_task(self.run(self))
|
101
127
|
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
except ValueError:
|
110
|
-
reason = "WebSocket message type must be text"
|
111
|
-
await self._handler.handle_invalid_message(reason)
|
112
|
-
|
113
|
-
async def receive_json(self, content: Any, **kwargs: Any) -> None:
|
114
|
-
await self._handler.handle_message(content)
|
128
|
+
async def receive(
|
129
|
+
self, text_data: Optional[str] = None, bytes_data: Optional[bytes] = None
|
130
|
+
) -> None:
|
131
|
+
if text_data:
|
132
|
+
self.message_queue.put_nowait({"message": text_data, "disconnected": False})
|
133
|
+
else:
|
134
|
+
self.message_queue.put_nowait({"message": None, "disconnected": False})
|
115
135
|
|
116
136
|
async def disconnect(self, code: int) -> None:
|
117
|
-
|
137
|
+
self.message_queue.put_nowait({"message": None, "disconnected": True})
|
138
|
+
assert self.run_task
|
139
|
+
await self.run_task
|
118
140
|
|
119
|
-
async def get_root_value(self, request:
|
141
|
+
async def get_root_value(self, request: GraphQLWSConsumer) -> Optional[RootValue]:
|
120
142
|
return None
|
121
143
|
|
122
144
|
async def get_context(
|
123
|
-
self, request:
|
145
|
+
self, request: GraphQLWSConsumer, response: GraphQLWSConsumer
|
124
146
|
) -> Context:
|
125
147
|
return {
|
126
148
|
"request": request,
|
127
|
-
"connection_params": connection_params,
|
128
149
|
"ws": request,
|
129
150
|
} # type: ignore
|
130
151
|
|
152
|
+
@property
|
153
|
+
def allow_queries_via_get(self) -> bool:
|
154
|
+
return False
|
155
|
+
|
156
|
+
async def get_sub_response(self, request: GraphQLWSConsumer) -> GraphQLWSConsumer:
|
157
|
+
raise NotImplementedError
|
158
|
+
|
159
|
+
def create_response(
|
160
|
+
self, response_data: GraphQLHTTPResponse, sub_response: GraphQLWSConsumer
|
161
|
+
) -> GraphQLWSConsumer:
|
162
|
+
raise NotImplementedError
|
163
|
+
|
164
|
+
async def render_graphql_ide(self, request: GraphQLWSConsumer) -> GraphQLWSConsumer:
|
165
|
+
raise NotImplementedError
|
166
|
+
|
167
|
+
def is_websocket_request(
|
168
|
+
self, request: GraphQLWSConsumer
|
169
|
+
) -> TypeGuard[GraphQLWSConsumer]:
|
170
|
+
return True
|
171
|
+
|
172
|
+
async def pick_websocket_subprotocol(
|
173
|
+
self, request: GraphQLWSConsumer
|
174
|
+
) -> Optional[str]:
|
175
|
+
protocols = request.scope["subprotocols"]
|
176
|
+
intersection = set(protocols) & set(self.protocols)
|
177
|
+
sorted_intersection = sorted(intersection, key=protocols.index)
|
178
|
+
return next(iter(sorted_intersection), None)
|
179
|
+
|
180
|
+
async def create_websocket_response(
|
181
|
+
self, request: GraphQLWSConsumer, subprotocol: Optional[str]
|
182
|
+
) -> GraphQLWSConsumer:
|
183
|
+
await request.accept(subprotocol=subprotocol)
|
184
|
+
return request
|
185
|
+
|
131
186
|
|
132
187
|
__all__ = ["GraphQLWSConsumer"]
|
strawberry/django/views.py
CHANGED
@@ -13,6 +13,7 @@ from typing import (
|
|
13
13
|
Union,
|
14
14
|
cast,
|
15
15
|
)
|
16
|
+
from typing_extensions import TypeGuard
|
16
17
|
|
17
18
|
from asgiref.sync import markcoroutinefunction
|
18
19
|
from django.core.serializers.json import DjangoJSONEncoder
|
@@ -258,7 +259,13 @@ class GraphQLView(
|
|
258
259
|
class AsyncGraphQLView(
|
259
260
|
BaseView,
|
260
261
|
AsyncBaseHTTPView[
|
261
|
-
HttpRequest,
|
262
|
+
HttpRequest,
|
263
|
+
HttpResponseBase,
|
264
|
+
TemporalHttpResponse,
|
265
|
+
HttpRequest,
|
266
|
+
TemporalHttpResponse,
|
267
|
+
Context,
|
268
|
+
RootValue,
|
262
269
|
],
|
263
270
|
View,
|
264
271
|
):
|
@@ -312,5 +319,16 @@ class AsyncGraphQLView(
|
|
312
319
|
|
313
320
|
return response
|
314
321
|
|
322
|
+
def is_websocket_request(self, request: HttpRequest) -> TypeGuard[HttpRequest]:
|
323
|
+
return False
|
324
|
+
|
325
|
+
async def pick_websocket_subprotocol(self, request: HttpRequest) -> Optional[str]:
|
326
|
+
raise NotImplementedError
|
327
|
+
|
328
|
+
async def create_websocket_response(
|
329
|
+
self, request: HttpRequest, subprotocol: Optional[str]
|
330
|
+
) -> TemporalHttpResponse:
|
331
|
+
raise NotImplementedError
|
332
|
+
|
315
333
|
|
316
334
|
__all__ = ["GraphQLView", "AsyncGraphQLView"]
|
strawberry/ext/mypy_plugin.py
CHANGED
@@ -481,7 +481,14 @@ def strawberry_pydantic_class_callback(ctx: ClassDefContext) -> None:
|
|
481
481
|
# Based on pydantic's default value
|
482
482
|
# https://github.com/pydantic/pydantic/pull/9606/files#diff-469037bbe55bbf9aa359480a16040d368c676adad736e133fb07e5e20d6ac523R1066
|
483
483
|
extra["force_typevars_invariant"] = False
|
484
|
-
|
484
|
+
if PYDANTIC_VERSION >= (2, 9, 0):
|
485
|
+
extra["model_strict"] = model_type.type.metadata[
|
486
|
+
PYDANTIC_METADATA_KEY
|
487
|
+
]["config"].get("strict", False)
|
488
|
+
extra["is_root_model_root"] = any(
|
489
|
+
"pydantic.root_model.RootModel" in base.fullname
|
490
|
+
for base in model_type.type.mro[:-1]
|
491
|
+
)
|
485
492
|
add_method(
|
486
493
|
ctx,
|
487
494
|
"to_pydantic",
|