strawberry-graphql 0.168.2__py3-none-any.whl → 0.170.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- import json
4
3
  from datetime import timedelta
5
4
  from inspect import signature
6
5
  from typing import (
@@ -8,8 +7,7 @@ from typing import (
8
7
  Any,
9
8
  Awaitable,
10
9
  Callable,
11
- Dict,
12
- Iterable,
10
+ Mapping,
13
11
  Optional,
14
12
  Sequence,
15
13
  Union,
@@ -27,19 +25,20 @@ from starlette.responses import (
27
25
  from starlette.websockets import WebSocket
28
26
 
29
27
  from fastapi import APIRouter, Depends
30
- from strawberry.exceptions import InvalidCustomContext, MissingQueryError
28
+ from strawberry.exceptions import InvalidCustomContext
31
29
  from strawberry.fastapi.context import BaseContext, CustomContext
32
30
  from strawberry.fastapi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler
33
- from strawberry.file_uploads.utils import replace_placeholders_with_files
34
31
  from strawberry.http import (
35
- parse_query_params,
36
- parse_request_data,
37
32
  process_result,
38
33
  )
39
- from strawberry.schema.exceptions import InvalidOperationTypeError
34
+ from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter
35
+ from strawberry.http.exceptions import HTTPException
36
+ from strawberry.http.types import FormData, HTTPMethod, QueryParams
37
+ from strawberry.http.typevars import (
38
+ Context,
39
+ RootValue,
40
+ )
40
41
  from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
41
- from strawberry.types.graphql import OperationType
42
- from strawberry.utils.debug import pretty_print_graphql_operation
43
42
  from strawberry.utils.graphiql import get_graphiql_html
44
43
 
45
44
  if TYPE_CHECKING:
@@ -51,9 +50,42 @@ if TYPE_CHECKING:
51
50
  from strawberry.types import ExecutionResult
52
51
 
53
52
 
54
- class GraphQLRouter(APIRouter):
53
+ class FastAPIRequestAdapter(AsyncHTTPRequestAdapter):
54
+ def __init__(self, request: Request):
55
+ self.request = request
56
+
57
+ @property
58
+ def query_params(self) -> QueryParams:
59
+ return dict(self.request.query_params)
60
+
61
+ @property
62
+ def method(self) -> HTTPMethod:
63
+ return cast(HTTPMethod, self.request.method.upper())
64
+
65
+ @property
66
+ def headers(self) -> Mapping[str, str]:
67
+ return self.request.headers
68
+
69
+ @property
70
+ def content_type(self) -> Optional[str]:
71
+ return self.request.headers.get("Content-Type", None)
72
+
73
+ async def get_body(self) -> bytes:
74
+ return await self.request.body()
75
+
76
+ async def get_form_data(self) -> FormData:
77
+ multipart_data = await self.request.form()
78
+
79
+ return FormData(files=multipart_data, form=multipart_data)
80
+
81
+
82
+ class GraphQLRouter(
83
+ AsyncBaseHTTPView[Request, Response, Response, Context, RootValue], APIRouter
84
+ ):
55
85
  graphql_ws_handler_class = GraphQLWSHandler
56
86
  graphql_transport_ws_handler_class = GraphQLTransportWSHandler
87
+ allow_queries_via_get = True
88
+ request_adapter_class = FastAPIRequestAdapter
57
89
 
58
90
  @staticmethod
59
91
  async def __get_root_value():
@@ -119,9 +151,12 @@ class GraphQLRouter(APIRouter):
119
151
  keep_alive: bool = False,
120
152
  keep_alive_interval: float = 1,
121
153
  debug: bool = False,
122
- root_value_getter=None,
123
- context_getter=None,
124
- subscription_protocols=(GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL),
154
+ root_value_getter: Optional[Callable[[], RootValue]] = None,
155
+ context_getter: Optional[Callable[..., Optional[Context]]] = None,
156
+ subscription_protocols: Sequence[str] = (
157
+ GRAPHQL_TRANSPORT_WS_PROTOCOL,
158
+ GRAPHQL_WS_PROTOCOL,
159
+ ),
125
160
  connection_init_wait_timeout: timedelta = timedelta(minutes=1),
126
161
  default: Optional[ASGIApp] = None,
127
162
  on_startup: Optional[Sequence[Callable[[], Any]]] = None,
@@ -139,8 +174,9 @@ class GraphQLRouter(APIRouter):
139
174
  self.keep_alive_interval = keep_alive_interval
140
175
  self.debug = debug
141
176
  self.root_value_getter = root_value_getter or self.__get_root_value
177
+ # TODO: clean this type up
142
178
  self.context_getter = self.__get_context_getter(
143
- context_getter or (lambda: None)
179
+ context_getter or (lambda: None) # type: ignore
144
180
  )
145
181
  self.protocols = subscription_protocols
146
182
  self.connection_init_wait_timeout = connection_init_wait_timeout
@@ -156,100 +192,49 @@ class GraphQLRouter(APIRouter):
156
192
  },
157
193
  },
158
194
  )
159
- async def handle_http_get(
195
+ async def handle_http_get( # pyright: ignore
160
196
  request: Request,
161
197
  response: Response,
162
- context=Depends(self.context_getter),
163
- root_value=Depends(self.root_value_getter),
198
+ context: Context = Depends(self.context_getter),
199
+ root_value: RootValue = Depends(self.root_value_getter),
164
200
  ) -> Response:
165
- if request.query_params:
166
- try:
167
- query_data = parse_query_params(request.query_params._dict)
168
-
169
- except json.JSONDecodeError:
170
- return PlainTextResponse(
171
- "Unable to parse request body as JSON",
172
- status_code=status.HTTP_400_BAD_REQUEST,
173
- )
174
-
175
- return await self.execute_request(
176
- request=request,
177
- response=response,
178
- data=query_data,
179
- context=context,
180
- root_value=root_value,
201
+ self.temporal_response = response
202
+
203
+ try:
204
+ return await self.run(
205
+ request=request, context=context, root_value=root_value
206
+ )
207
+ except HTTPException as e:
208
+ return PlainTextResponse(
209
+ e.reason,
210
+ status_code=e.status_code,
181
211
  )
182
- elif self.should_render_graphiql(request):
183
- return self.get_graphiql_response()
184
- return Response(status_code=status.HTTP_404_NOT_FOUND)
185
212
 
186
213
  @self.post(path)
187
- async def handle_http_post(
214
+ async def handle_http_post( # pyright: ignore
188
215
  request: Request,
189
216
  response: Response,
190
- context=Depends(self.context_getter),
191
- root_value=Depends(self.root_value_getter),
217
+ # TODO: use Annotated in future
218
+ context: Context = Depends(self.context_getter),
219
+ root_value: RootValue = Depends(self.root_value_getter),
192
220
  ) -> Response:
193
- actual_response: Response
194
-
195
- content_type = request.headers.get("content-type", "")
196
-
197
- if "application/json" in content_type:
198
- try:
199
- data = await request.json()
200
- except json.JSONDecodeError:
201
- actual_response = PlainTextResponse(
202
- "Unable to parse request body as JSON",
203
- status_code=status.HTTP_400_BAD_REQUEST,
204
- )
205
-
206
- return self._merge_responses(response, actual_response)
207
- elif content_type.startswith("multipart/form-data"):
208
- multipart_data = await request.form()
209
- try:
210
- operations_text = multipart_data.get("operations", "{}")
211
- operations = json.loads(operations_text) # type: ignore
212
- files_map = json.loads(multipart_data.get("map", "{}")) # type: ignore # noqa: E501
213
- except json.JSONDecodeError:
214
- actual_response = PlainTextResponse(
215
- "Unable to parse request body as JSON",
216
- status_code=status.HTTP_400_BAD_REQUEST,
217
- )
218
-
219
- return self._merge_responses(response, actual_response)
220
-
221
- try:
222
- data = replace_placeholders_with_files(
223
- operations, files_map, multipart_data
224
- )
225
- except KeyError:
226
- actual_response = PlainTextResponse(
227
- "File(s) missing in form data",
228
- status_code=status.HTTP_400_BAD_REQUEST,
229
- )
230
-
231
- return self._merge_responses(response, actual_response)
232
- else:
233
- actual_response = PlainTextResponse(
234
- "Unsupported Media Type",
235
- status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
236
- )
237
-
238
- return self._merge_responses(response, actual_response)
221
+ self.temporal_response = response
239
222
 
240
- return await self.execute_request(
241
- request=request,
242
- response=response,
243
- data=data,
244
- context=context,
245
- root_value=root_value,
246
- )
223
+ try:
224
+ return await self.run(
225
+ request=request, context=context, root_value=root_value
226
+ )
227
+ except HTTPException as e:
228
+ return PlainTextResponse(
229
+ e.reason,
230
+ status_code=e.status_code,
231
+ )
247
232
 
248
233
  @self.websocket(path)
249
- async def websocket_endpoint(
234
+ async def websocket_endpoint( # pyright: ignore
250
235
  websocket: WebSocket,
251
- context=Depends(self.context_getter),
252
- root_value=Depends(self.root_value_getter),
236
+ context: Context = Depends(self.context_getter),
237
+ root_value: RootValue = Depends(self.root_value_getter),
253
238
  ):
254
239
  async def _get_context():
255
240
  return context
@@ -290,15 +275,7 @@ class GraphQLRouter(APIRouter):
290
275
  default=None,
291
276
  )
292
277
 
293
- def should_render_graphiql(self, request: Request) -> bool:
294
- if not self.graphiql:
295
- return False
296
- return any(
297
- supported_header in request.headers.get("accept", "")
298
- for supported_header in ("text/html", "*/*")
299
- )
300
-
301
- def get_graphiql_response(self) -> HTMLResponse:
278
+ def render_graphiql(self, request: Request) -> HTMLResponse:
302
279
  html = get_graphiql_html()
303
280
  return HTMLResponse(html)
304
281
 
@@ -310,73 +287,33 @@ class GraphQLRouter(APIRouter):
310
287
 
311
288
  return actual_response
312
289
 
313
- async def execute(
314
- self,
315
- query: Optional[str],
316
- variables: Optional[Dict[str, Any]] = None,
317
- context: Any = None,
318
- operation_name: Optional[str] = None,
319
- root_value: Any = None,
320
- allowed_operation_types: Optional[Iterable[OperationType]] = None,
321
- ) -> ExecutionResult:
322
- if self.debug and query:
323
- pretty_print_graphql_operation(operation_name, query, variables)
324
-
325
- return await self.schema.execute(
326
- query,
327
- root_value=root_value,
328
- variable_values=variables,
329
- operation_name=operation_name,
330
- context_value=context,
331
- allowed_operation_types=allowed_operation_types,
332
- )
333
-
334
290
  async def process_result(
335
291
  self, request: Request, result: ExecutionResult
336
292
  ) -> GraphQLHTTPResponse:
337
293
  return process_result(result)
338
294
 
339
- async def execute_request(
340
- self, request: Request, response: Response, data: dict, context, root_value
295
+ async def get_context(
296
+ self, request: Request, response: Response
297
+ ) -> Context: # pragma: no cover
298
+ raise ValueError("`get_context` is not used by FastAPI GraphQL Router")
299
+
300
+ async def get_root_value(
301
+ self, request: Request
302
+ ) -> Optional[RootValue]: # pragma: no cover
303
+ raise ValueError("`get_root_value` is not used by FastAPI GraphQL Router")
304
+
305
+ async def get_sub_response(self, request: Request) -> Response:
306
+ return self.temporal_response
307
+
308
+ def create_response(
309
+ self, response_data: GraphQLHTTPResponse, sub_response: Response
341
310
  ) -> Response:
342
- request_data = parse_request_data(data)
343
-
344
- method = request.method
345
- allowed_operation_types = OperationType.from_http(method)
346
-
347
- if not self.allow_queries_via_get and method == "GET":
348
- allowed_operation_types = allowed_operation_types - {OperationType.QUERY}
349
-
350
- try:
351
- result = await self.execute(
352
- request_data.query,
353
- variables=request_data.variables,
354
- context=context,
355
- operation_name=request_data.operation_name,
356
- root_value=root_value,
357
- allowed_operation_types=allowed_operation_types,
358
- )
359
- except InvalidOperationTypeError as e:
360
- return PlainTextResponse(
361
- e.as_http_error_reason(method),
362
- status_code=status.HTTP_400_BAD_REQUEST,
363
- )
364
- except MissingQueryError:
365
- missing_query_response = PlainTextResponse(
366
- "No GraphQL query found in the request",
367
- status_code=status.HTTP_400_BAD_REQUEST,
368
- )
369
- return self._merge_responses(response, missing_query_response)
370
-
371
- response_data = await self.process_result(request, result)
372
-
373
- actual_response = Response(
311
+ response = Response(
374
312
  self.encode_json(response_data),
375
313
  media_type="application/json",
376
- status_code=status.HTTP_200_OK,
314
+ status_code=sub_response.status_code or status.HTTP_200_OK,
377
315
  )
378
316
 
379
- return self._merge_responses(response, actual_response)
317
+ response.headers.raw.extend(sub_response.headers.raw)
380
318
 
381
- def encode_json(self, response_data: GraphQLHTTPResponse) -> str:
382
- return json.dumps(response_data)
319
+ return response
@@ -1,10 +1,10 @@
1
1
  import copy
2
- from typing import Any, Dict, List, Mapping
2
+ from typing import Any, Dict, Mapping
3
3
 
4
4
 
5
5
  def replace_placeholders_with_files(
6
6
  operations_with_placeholders: Dict[str, Any],
7
- files_map: Mapping[str, List[str]],
7
+ files_map: Mapping[str, Any],
8
8
  files: Mapping[str, Any],
9
9
  ) -> Dict[str, Any]:
10
10
  # TODO: test this with missing variables in operations_with_placeholders