schemathesis 3.19.7__py3-none-any.whl → 3.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. schemathesis/_compat.py +3 -2
  2. schemathesis/_hypothesis.py +21 -6
  3. schemathesis/_xml.py +177 -0
  4. schemathesis/auths.py +48 -10
  5. schemathesis/cli/__init__.py +77 -19
  6. schemathesis/cli/callbacks.py +42 -18
  7. schemathesis/cli/context.py +2 -1
  8. schemathesis/cli/output/default.py +102 -34
  9. schemathesis/cli/sanitization.py +15 -0
  10. schemathesis/code_samples.py +141 -0
  11. schemathesis/constants.py +1 -24
  12. schemathesis/exceptions.py +127 -26
  13. schemathesis/experimental/__init__.py +85 -0
  14. schemathesis/extra/pytest_plugin.py +10 -4
  15. schemathesis/fixups/__init__.py +8 -2
  16. schemathesis/fixups/fast_api.py +11 -1
  17. schemathesis/fixups/utf8_bom.py +7 -1
  18. schemathesis/hooks.py +63 -0
  19. schemathesis/lazy.py +10 -4
  20. schemathesis/loaders.py +57 -0
  21. schemathesis/models.py +120 -96
  22. schemathesis/parameters.py +3 -0
  23. schemathesis/runner/__init__.py +3 -0
  24. schemathesis/runner/events.py +55 -20
  25. schemathesis/runner/impl/core.py +54 -54
  26. schemathesis/runner/serialization.py +75 -34
  27. schemathesis/sanitization.py +248 -0
  28. schemathesis/schemas.py +21 -6
  29. schemathesis/serializers.py +32 -3
  30. schemathesis/service/serialization.py +5 -1
  31. schemathesis/specs/graphql/loaders.py +44 -13
  32. schemathesis/specs/graphql/schemas.py +56 -25
  33. schemathesis/specs/openapi/_hypothesis.py +11 -23
  34. schemathesis/specs/openapi/definitions.py +572 -0
  35. schemathesis/specs/openapi/loaders.py +100 -49
  36. schemathesis/specs/openapi/parameters.py +2 -2
  37. schemathesis/specs/openapi/schemas.py +87 -13
  38. schemathesis/specs/openapi/security.py +1 -0
  39. schemathesis/stateful.py +2 -2
  40. schemathesis/utils.py +30 -9
  41. schemathesis-3.20.1.dist-info/METADATA +342 -0
  42. {schemathesis-3.19.7.dist-info → schemathesis-3.20.1.dist-info}/RECORD +45 -39
  43. schemathesis-3.19.7.dist-info/METADATA +0 -291
  44. {schemathesis-3.19.7.dist-info → schemathesis-3.20.1.dist-info}/WHEEL +0 -0
  45. {schemathesis-3.19.7.dist-info → schemathesis-3.20.1.dist-info}/entry_points.txt +0 -0
  46. {schemathesis-3.19.7.dist-info → schemathesis-3.20.1.dist-info}/licenses/LICENSE +0 -0
schemathesis/schemas.py CHANGED
@@ -37,8 +37,9 @@ from requests.structures import CaseInsensitiveDict
37
37
 
38
38
  from ._hypothesis import create_test
39
39
  from .auths import AuthStorage
40
- from .constants import DEFAULT_DATA_GENERATION_METHODS, CodeSampleStyle, DataGenerationMethod
41
- from .exceptions import InvalidSchema, UsageError
40
+ from .code_samples import CodeSampleStyle
41
+ from .constants import DEFAULT_DATA_GENERATION_METHODS, DataGenerationMethod
42
+ from .exceptions import OperationSchemaError, UsageError
42
43
  from .hooks import HookContext, HookDispatcher, HookScope, dispatch
43
44
  from .models import APIOperation, Case
44
45
  from .stateful import APIStateMachine, Stateful, StatefulTest
@@ -100,6 +101,7 @@ class BaseSchema(Mapping):
100
101
  )
101
102
  code_sample_style: CodeSampleStyle = CodeSampleStyle.default()
102
103
  rate_limiter: Optional[Limiter] = None
104
+ sanitize_output: bool = True
103
105
 
104
106
  def __iter__(self) -> Iterator[str]:
105
107
  return iter(self.operations)
@@ -155,6 +157,9 @@ class BaseSchema(Mapping):
155
157
  return base_url.rstrip("/")
156
158
  return self._build_base_url()
157
159
 
160
+ def validate(self) -> None:
161
+ raise NotImplementedError
162
+
158
163
  @property
159
164
  def operations(self) -> Dict[str, MethodsDict]:
160
165
  if not hasattr(self, "_operations"):
@@ -166,7 +171,9 @@ class BaseSchema(Mapping):
166
171
  def operations_count(self) -> int:
167
172
  raise NotImplementedError
168
173
 
169
- def get_all_operations(self) -> Generator[Result[APIOperation, InvalidSchema], None, None]:
174
+ def get_all_operations(
175
+ self, hooks: Optional[HookDispatcher] = None
176
+ ) -> Generator[Result[APIOperation, OperationSchemaError], None, None]:
170
177
  raise NotImplementedError
171
178
 
172
179
  def get_strategies_from_examples(self, operation: APIOperation) -> List[SearchStrategy[Case]]:
@@ -193,10 +200,11 @@ class BaseSchema(Mapping):
193
200
  settings: Optional[hypothesis.settings] = None,
194
201
  seed: Optional[int] = None,
195
202
  as_strategy_kwargs: Optional[Dict[str, Any]] = None,
203
+ hooks: Optional[HookDispatcher] = None,
196
204
  _given_kwargs: Optional[Dict[str, GivenInput]] = None,
197
- ) -> Generator[Result[Tuple[APIOperation, Callable], InvalidSchema], None, None]:
205
+ ) -> Generator[Result[Tuple[APIOperation, Callable], OperationSchemaError], None, None]:
198
206
  """Generate all operations and Hypothesis tests for them."""
199
- for result in self.get_all_operations():
207
+ for result in self.get_all_operations(hooks=hooks):
200
208
  if isinstance(result, Ok):
201
209
  test = create_test(
202
210
  operation=result.ok(),
@@ -276,6 +284,7 @@ class BaseSchema(Mapping):
276
284
  data_generation_methods: Union[DataGenerationMethodInput, NotSet] = NOT_SET,
277
285
  code_sample_style: Union[CodeSampleStyle, NotSet] = NOT_SET,
278
286
  rate_limiter: Optional[Limiter] = NOT_SET,
287
+ sanitize_output: Optional[Union[bool, NotSet]] = NOT_SET,
279
288
  ) -> "BaseSchema":
280
289
  if base_url is NOT_SET:
281
290
  base_url = self.base_url
@@ -303,6 +312,8 @@ class BaseSchema(Mapping):
303
312
  code_sample_style = self.code_sample_style
304
313
  if rate_limiter is NOT_SET:
305
314
  rate_limiter = self.rate_limiter
315
+ if sanitize_output is NOT_SET:
316
+ sanitize_output = self.sanitize_output
306
317
 
307
318
  return self.__class__(
308
319
  self.raw_schema,
@@ -321,6 +332,7 @@ class BaseSchema(Mapping):
321
332
  data_generation_methods=data_generation_methods, # type: ignore
322
333
  code_sample_style=code_sample_style, # type: ignore
323
334
  rate_limiter=rate_limiter, # type: ignore
335
+ sanitize_output=sanitize_output, # type: ignore
324
336
  )
325
337
 
326
338
  def get_local_hook_dispatcher(self) -> Optional[HookDispatcher]:
@@ -398,9 +410,12 @@ class BaseSchema(Mapping):
398
410
  return self.rate_limiter.ratelimit(label, delay=True, max_delay=0)
399
411
  return nullcontext()
400
412
 
413
+ def _get_payload_schema(self, definition: Dict[str, Any], media_type: str) -> Optional[Dict[str, Any]]:
414
+ raise NotImplementedError
415
+
401
416
 
402
417
  def operations_to_dict(
403
- operations: Generator[Result[APIOperation, InvalidSchema], None, None]
418
+ operations: Generator[Result[APIOperation, OperationSchemaError], None, None]
404
419
  ) -> Dict[str, MethodsDict]:
405
420
  output: Dict[str, MethodsDict] = {}
406
421
  for result in operations:
@@ -2,17 +2,17 @@ import binascii
2
2
  import os
3
3
  from dataclasses import dataclass
4
4
  from io import BytesIO
5
- from typing import TYPE_CHECKING, Any, Callable, Collection, Dict, Generator, Optional, Type
5
+ from typing import TYPE_CHECKING, Any, Callable, Collection, Dict, Generator, Optional, Type, cast
6
6
 
7
7
  import yaml
8
8
  from typing_extensions import Protocol, runtime_checkable
9
9
 
10
- from .utils import is_json_media_type, is_plain_text_media_type, parse_content_type
10
+ from ._xml import _to_xml
11
+ from .utils import is_json_media_type, is_plain_text_media_type, is_xml_media_type, parse_content_type
11
12
 
12
13
  if TYPE_CHECKING:
13
14
  from .models import Case
14
15
 
15
-
16
16
  try:
17
17
  from yaml import CSafeDumper as SafeDumper
18
18
  except ImportError:
@@ -31,6 +31,24 @@ class SerializerContext:
31
31
 
32
32
  case: "Case"
33
33
 
34
+ @property
35
+ def media_type(self) -> str:
36
+ # `media_type` is a string, otherwise we won't serialize anything
37
+ return cast(str, self.case.media_type)
38
+
39
+ # Note on type casting below.
40
+ # If we serialize data, then there should be non-empty definition for it in the first place
41
+ # Therefore `schema` is never `None` if called from here. However, `APIOperation.get_raw_payload_schema` is
42
+ # generic and can be called from other places where it may return `None`
43
+
44
+ def get_raw_payload_schema(self) -> Dict[str, Any]:
45
+ schema = self.case.operation.get_raw_payload_schema(self.media_type)
46
+ return cast(Dict[str, Any], schema)
47
+
48
+ def get_resolved_payload_schema(self) -> Dict[str, Any]:
49
+ schema = self.case.operation.get_resolved_payload_schema(self.media_type)
50
+ return cast(Dict[str, Any], schema)
51
+
34
52
 
35
53
  @runtime_checkable
36
54
  class Serializer(Protocol):
@@ -125,6 +143,15 @@ class YAMLSerializer:
125
143
  return _to_yaml(value)
126
144
 
127
145
 
146
+ @register("application/xml")
147
+ class XMLSerializer:
148
+ def as_requests(self, context: SerializerContext, value: Any) -> Dict[str, Any]:
149
+ return _to_xml(value, context.get_raw_payload_schema(), context.get_resolved_payload_schema())
150
+
151
+ def as_werkzeug(self, context: SerializerContext, value: Any) -> Dict[str, Any]:
152
+ return _to_xml(value, context.get_raw_payload_schema(), context.get_resolved_payload_schema())
153
+
154
+
128
155
  def _should_coerce_to_bytes(item: Any) -> bool:
129
156
  """Whether the item should be converted to bytes."""
130
157
  # These types are OK in forms, others should be coerced to bytes
@@ -249,4 +276,6 @@ def get(media_type: str) -> Optional[Type[Serializer]]:
249
276
  media_type = "application/json"
250
277
  if is_plain_text_media_type(media_type):
251
278
  media_type = "text/plain"
279
+ if is_xml_media_type(media_type):
280
+ media_type = "application/xml"
252
281
  return SERIALIZERS.get(media_type)
@@ -81,7 +81,6 @@ def serialize_after_execution(event: events.AfterExecution) -> Optional[Dict[str
81
81
  {
82
82
  "exception": error.exception,
83
83
  "exception_with_traceback": error.exception_with_traceback,
84
- "example": None if error.example is None else _serialize_case(error.example),
85
84
  }
86
85
  for error in event.result.errors
87
86
  ],
@@ -95,8 +94,13 @@ def serialize_interrupted(_: events.Interrupted) -> Optional[Dict[str, Any]]:
95
94
 
96
95
  def serialize_internal_error(event: events.InternalError) -> Optional[Dict[str, Any]]:
97
96
  return {
97
+ "type": event.type.value,
98
+ "subtype": event.subtype.value if event.subtype else event.subtype,
99
+ "title": event.title,
98
100
  "message": event.message,
101
+ "extras": event.extras,
99
102
  "exception_type": event.exception_type,
103
+ "exception": event.exception,
100
104
  "exception_with_traceback": event.exception_with_traceback,
101
105
  }
102
106
 
@@ -1,4 +1,5 @@
1
1
  import pathlib
2
+ from json import JSONDecodeError
2
3
  from typing import IO, Any, Callable, Dict, Optional, Union, cast
3
4
 
4
5
  import backoff
@@ -11,12 +12,14 @@ from starlette_testclient import TestClient as ASGIClient
11
12
  from werkzeug import Client
12
13
  from yarl import URL
13
14
 
14
- from ...constants import DEFAULT_DATA_GENERATION_METHODS, WAIT_FOR_SCHEMA_INTERVAL, CodeSampleStyle
15
- from ...exceptions import HTTPError
15
+ from ...code_samples import CodeSampleStyle
16
+ from ...constants import DEFAULT_DATA_GENERATION_METHODS, WAIT_FOR_SCHEMA_INTERVAL
17
+ from ...exceptions import SchemaError, SchemaErrorType
16
18
  from ...hooks import HookContext, dispatch
19
+ from ...loaders import load_schema_from_url
17
20
  from ...throttling import build_limiter
18
21
  from ...types import DataGenerationMethodInput, PathLike
19
- from ...utils import WSGIResponse, prepare_data_generation_methods, require_relative_url, setup_headers
22
+ from ...utils import GenericResponse, WSGIResponse, prepare_data_generation_methods, require_relative_url, setup_headers
20
23
  from .schemas import GraphQLSchema
21
24
 
22
25
  INTROSPECTION_QUERY = graphql.get_introspection_query()
@@ -32,6 +35,7 @@ def from_path(
32
35
  code_sample_style: str = CodeSampleStyle.default().name,
33
36
  rate_limit: Optional[str] = None,
34
37
  encoding: str = "utf8",
38
+ sanitize_output: bool = True,
35
39
  ) -> GraphQLSchema:
36
40
  """Load GraphQL schema via a file from an OS path.
37
41
 
@@ -47,9 +51,24 @@ def from_path(
47
51
  code_sample_style=code_sample_style,
48
52
  location=pathlib.Path(path).absolute().as_uri(),
49
53
  rate_limit=rate_limit,
54
+ sanitize_output=sanitize_output,
50
55
  )
51
56
 
52
57
 
58
+ def extract_schema_from_response(response: GenericResponse) -> Dict[str, Any]:
59
+ try:
60
+ if isinstance(response, requests.Response):
61
+ decoded = response.json()
62
+ else:
63
+ decoded = response.json
64
+ except JSONDecodeError as exc:
65
+ raise SchemaError(
66
+ SchemaErrorType.UNEXPECTED_CONTENT_TYPE,
67
+ "Received unsupported content while expecting a JSON payload for GraphQL",
68
+ ) from exc
69
+ return decoded
70
+
71
+
53
72
  def from_url(
54
73
  url: str,
55
74
  *,
@@ -60,6 +79,7 @@ def from_url(
60
79
  code_sample_style: str = CodeSampleStyle.default().name,
61
80
  wait_for_schema: Optional[float] = None,
62
81
  rate_limit: Optional[str] = None,
82
+ sanitize_output: bool = True,
63
83
  **kwargs: Any,
64
84
  ) -> GraphQLSchema:
65
85
  """Load GraphQL schema from the network.
@@ -91,17 +111,18 @@ def from_url(
91
111
 
92
112
  else:
93
113
  _load_schema = requests.post
94
- response = _load_schema(url, **kwargs)
95
- HTTPError.raise_for_status(response)
96
- decoded = response.json()
114
+
115
+ response = load_schema_from_url(lambda: _load_schema(url, **kwargs))
116
+ raw_schema = extract_schema_from_response(response)
97
117
  return from_dict(
98
- raw_schema=decoded["data"],
118
+ raw_schema=raw_schema,
99
119
  location=url,
100
120
  base_url=base_url,
101
121
  app=app,
102
122
  data_generation_methods=data_generation_methods,
103
123
  code_sample_style=code_sample_style,
104
124
  rate_limit=rate_limit,
125
+ sanitize_output=sanitize_output,
105
126
  )
106
127
 
107
128
 
@@ -114,6 +135,7 @@ def from_file(
114
135
  code_sample_style: str = CodeSampleStyle.default().name,
115
136
  location: Optional[str] = None,
116
137
  rate_limit: Optional[str] = None,
138
+ sanitize_output: bool = True,
117
139
  ) -> GraphQLSchema:
118
140
  """Load GraphQL schema from a file descriptor or a string.
119
141
 
@@ -140,6 +162,7 @@ def from_file(
140
162
  code_sample_style=code_sample_style,
141
163
  location=location,
142
164
  rate_limit=rate_limit,
165
+ sanitize_output=sanitize_output,
143
166
  )
144
167
 
145
168
 
@@ -152,6 +175,7 @@ def from_dict(
152
175
  data_generation_methods: DataGenerationMethodInput = DEFAULT_DATA_GENERATION_METHODS,
153
176
  code_sample_style: str = CodeSampleStyle.default().name,
154
177
  rate_limit: Optional[str] = None,
178
+ sanitize_output: bool = True,
155
179
  ) -> GraphQLSchema:
156
180
  """Load GraphQL schema from a Python dictionary.
157
181
 
@@ -163,6 +187,8 @@ def from_dict(
163
187
  """
164
188
  _code_sample_style = CodeSampleStyle.from_str(code_sample_style)
165
189
  hook_context = HookContext()
190
+ if "data" in raw_schema:
191
+ raw_schema = raw_schema["data"]
166
192
  dispatch("before_load_schema", hook_context, raw_schema)
167
193
  rate_limiter: Optional[Limiter] = None
168
194
  if rate_limit is not None:
@@ -175,6 +201,7 @@ def from_dict(
175
201
  data_generation_methods=prepare_data_generation_methods(data_generation_methods),
176
202
  code_sample_style=_code_sample_style,
177
203
  rate_limiter=rate_limiter,
204
+ sanitize_output=sanitize_output,
178
205
  ) # type: ignore
179
206
  dispatch("after_load_schema", hook_context, instance)
180
207
  return instance
@@ -188,6 +215,7 @@ def from_wsgi(
188
215
  data_generation_methods: DataGenerationMethodInput = DEFAULT_DATA_GENERATION_METHODS,
189
216
  code_sample_style: str = CodeSampleStyle.default().name,
190
217
  rate_limit: Optional[str] = None,
218
+ sanitize_output: bool = True,
191
219
  **kwargs: Any,
192
220
  ) -> GraphQLSchema:
193
221
  """Load GraphQL schema from a WSGI app.
@@ -201,16 +229,17 @@ def from_wsgi(
201
229
  setup_headers(kwargs)
202
230
  kwargs.setdefault("json", {"query": INTROSPECTION_QUERY})
203
231
  client = Client(app, WSGIResponse)
204
- response = client.post(schema_path, **kwargs)
205
- HTTPError.check_response(response, schema_path)
232
+ response = load_schema_from_url(lambda: client.post(schema_path, **kwargs))
233
+ raw_schema = extract_schema_from_response(response)
206
234
  return from_dict(
207
- raw_schema=response.json["data"],
235
+ raw_schema=raw_schema,
208
236
  location=schema_path,
209
237
  base_url=base_url,
210
238
  app=app,
211
239
  data_generation_methods=data_generation_methods,
212
240
  code_sample_style=code_sample_style,
213
241
  rate_limit=rate_limit,
242
+ sanitize_output=sanitize_output,
214
243
  )
215
244
 
216
245
 
@@ -222,6 +251,7 @@ def from_asgi(
222
251
  data_generation_methods: DataGenerationMethodInput = DEFAULT_DATA_GENERATION_METHODS,
223
252
  code_sample_style: str = CodeSampleStyle.default().name,
224
253
  rate_limit: Optional[str] = None,
254
+ sanitize_output: bool = True,
225
255
  **kwargs: Any,
226
256
  ) -> GraphQLSchema:
227
257
  """Load GraphQL schema from an ASGI app.
@@ -234,16 +264,17 @@ def from_asgi(
234
264
  setup_headers(kwargs)
235
265
  kwargs.setdefault("json", {"query": INTROSPECTION_QUERY})
236
266
  client = ASGIClient(app)
237
- response = client.post(schema_path, **kwargs)
238
- HTTPError.check_response(response, schema_path)
267
+ response = load_schema_from_url(lambda: client.post(schema_path, **kwargs))
268
+ raw_schema = extract_schema_from_response(response)
239
269
  return from_dict(
240
- response.json()["data"],
270
+ raw_schema=raw_schema,
241
271
  location=schema_path,
242
272
  base_url=base_url,
243
273
  app=app,
244
274
  data_generation_methods=data_generation_methods,
245
275
  code_sample_style=code_sample_style,
246
276
  rate_limit=rate_limit,
277
+ sanitize_output=sanitize_output,
247
278
  )
248
279
 
249
280
 
@@ -15,8 +15,14 @@ from ... import auths
15
15
  from ...auths import AuthStorage
16
16
  from ...checks import not_a_server_error
17
17
  from ...constants import DataGenerationMethod
18
- from ...exceptions import InvalidSchema
19
- from ...hooks import HookDispatcher
18
+ from ...exceptions import OperationSchemaError
19
+ from ...hooks import (
20
+ GLOBAL_HOOK_DISPATCHER,
21
+ HookContext,
22
+ HookDispatcher,
23
+ apply_to_all_dispatchers,
24
+ should_skip_operation,
25
+ )
20
26
  from ...models import APIOperation, Case, CheckFunction, OperationDefinition
21
27
  from ...schemas import BaseSchema
22
28
  from ...stateful import Stateful, StatefulTest
@@ -92,6 +98,14 @@ class GraphQLOperationDefinition(OperationDefinition):
92
98
  type_: graphql.GraphQLType
93
99
  root_type: RootType
94
100
 
101
+ @property
102
+ def is_query(self) -> bool:
103
+ return self.root_type == RootType.QUERY
104
+
105
+ @property
106
+ def is_mutation(self) -> bool:
107
+ return self.root_type == RootType.MUTATION
108
+
95
109
 
96
110
  @dataclass
97
111
  class GraphQLSchema(BaseSchema):
@@ -130,7 +144,9 @@ class GraphQLSchema(BaseSchema):
130
144
  total += len(type_def["fields"])
131
145
  return total
132
146
 
133
- def get_all_operations(self) -> Generator[Result[APIOperation, InvalidSchema], None, None]:
147
+ def get_all_operations(
148
+ self, hooks: Optional[HookDispatcher] = None
149
+ ) -> Generator[Result[APIOperation, OperationSchemaError], None, None]:
134
150
  schema = self.client_schema
135
151
  for root_type, operation_type in (
136
152
  (RootType.QUERY, schema.query_type),
@@ -139,27 +155,33 @@ class GraphQLSchema(BaseSchema):
139
155
  if operation_type is None:
140
156
  continue
141
157
  for field_name, definition in operation_type.fields.items():
142
- yield Ok(
143
- APIOperation(
144
- base_url=self.get_base_url(),
145
- path=self.base_path,
146
- verbose_name=f"{operation_type.name}.{field_name}",
147
- method="POST",
148
- app=self.app,
149
- schema=self,
150
- # Parameters are not yet supported
151
- definition=GraphQLOperationDefinition(
152
- raw=definition,
153
- resolved=definition,
154
- scope="",
155
- parameters=[],
156
- type_=operation_type,
157
- field_name=field_name,
158
- root_type=root_type,
159
- ),
160
- case_cls=GraphQLCase,
161
- )
158
+ operation: APIOperation = APIOperation(
159
+ base_url=self.get_base_url(),
160
+ path=self.base_path,
161
+ verbose_name=f"{operation_type.name}.{field_name}",
162
+ method="POST",
163
+ app=self.app,
164
+ schema=self,
165
+ # Parameters are not yet supported
166
+ definition=GraphQLOperationDefinition(
167
+ raw=definition,
168
+ resolved=definition,
169
+ scope="",
170
+ parameters=[],
171
+ type_=operation_type,
172
+ field_name=field_name,
173
+ root_type=root_type,
174
+ ),
175
+ case_cls=GraphQLCase,
162
176
  )
177
+ context = HookContext(operation=operation)
178
+ if (
179
+ should_skip_operation(GLOBAL_HOOK_DISPATCHER, context)
180
+ or should_skip_operation(self.hooks, context)
181
+ or (hooks and should_skip_operation(hooks, context))
182
+ ):
183
+ continue
184
+ yield Ok(operation)
163
185
 
164
186
  def get_case_strategy(
165
187
  self,
@@ -220,11 +242,16 @@ def get_case_strategy(
220
242
  **kwargs: Any,
221
243
  ) -> Any:
222
244
  definition = cast(GraphQLOperationDefinition, operation.definition)
223
- strategy = {
245
+ strategy_factory = {
224
246
  RootType.QUERY: gql_st.queries,
225
247
  RootType.MUTATION: gql_st.mutations,
226
248
  }[definition.root_type]
227
- body = draw(strategy(client_schema, fields=[definition.field_name], custom_scalars=CUSTOM_SCALARS))
249
+ hook_context = HookContext(operation)
250
+ strategy = strategy_factory(
251
+ client_schema, fields=[definition.field_name], custom_scalars=CUSTOM_SCALARS, print_ast=_noop # type: ignore
252
+ )
253
+ strategy = apply_to_all_dispatchers(operation, hook_context, hooks, strategy, "body").map(graphql.print_ast)
254
+ body = draw(strategy)
228
255
  instance = GraphQLCase(body=body, operation=operation, data_generation_method=data_generation_method) # type: ignore
229
256
  context = auths.AuthContext(
230
257
  operation=operation,
@@ -232,3 +259,7 @@ def get_case_strategy(
232
259
  )
233
260
  auths.set_on_case(instance, context, auth_storage)
234
261
  return instance
262
+
263
+
264
+ def _noop(node: graphql.Node) -> graphql.Node:
265
+ return node
@@ -14,8 +14,8 @@ from requests.structures import CaseInsensitiveDict
14
14
 
15
15
  from ... import auths, serializers, utils
16
16
  from ...constants import DataGenerationMethod
17
- from ...exceptions import InvalidSchema, SerializationNotPossible
18
- from ...hooks import GLOBAL_HOOK_DISPATCHER, HookContext, HookDispatcher
17
+ from ...exceptions import OperationSchemaError, SerializationNotPossible
18
+ from ...hooks import HookContext, HookDispatcher, apply_to_all_dispatchers
19
19
  from ...models import APIOperation, Case, cant_serialize
20
20
  from ...types import NotSet
21
21
  from ...utils import NOT_SET, compose, fast_deepcopy, skip
@@ -172,9 +172,9 @@ def get_case_strategy(
172
172
  # Other media types are possible - avoid choosing this media type in the future
173
173
  cant_serialize(parameter.media_type)
174
174
  media_type = draw(st.sampled_from(possible_media_types))
175
- body_ = ValueContainer(value=draw(strategy), generator=body_generator)
175
+ body_ = ValueContainer(value=draw(strategy), location="body", generator=body_generator)
176
176
  else:
177
- body_ = ValueContainer(value=body, generator=None)
177
+ body_ = ValueContainer(value=body, location="body", generator=None)
178
178
  else:
179
179
  media_types = operation.get_request_payload_content_types() or ["application/json"]
180
180
  # Take the first available media type.
@@ -183,10 +183,10 @@ def get_case_strategy(
183
183
  # - On Open API 3.0, media types are explicit, and each example has it.
184
184
  # We can pass `OpenAPIBody.media_type` here from the examples handling code.
185
185
  media_type = media_types[0]
186
- body_ = ValueContainer(value=body, generator=None)
186
+ body_ = ValueContainer(value=body, location="body", generator=None)
187
187
 
188
188
  if operation.schema.validate_schema and operation.method.upper() == "GET" and operation.body:
189
- raise InvalidSchema("Body parameters are defined for GET request.")
189
+ raise OperationSchemaError("Body parameters are defined for GET request.")
190
190
  # If we need to generate negative cases but no generated values were negated, then skip the whole test
191
191
  if generator.is_negative and not any_negated_values([query_, cookies_, headers_, path_parameters_, body_]):
192
192
  skip(operation.verbose_name)
@@ -265,12 +265,13 @@ class ValueContainer:
265
265
  """Container for a value generated by a data generator or explicitly provided."""
266
266
 
267
267
  value: Any
268
+ location: str
268
269
  generator: Optional[DataGenerationMethod]
269
270
 
270
271
  @property
271
272
  def is_generated(self) -> bool:
272
273
  """If value was generated."""
273
- return self.value is not None and self.generator is not None
274
+ return self.generator is not None and (self.location == "body" or self.value is not None)
274
275
 
275
276
 
276
277
  def any_negated_values(values: List[ValueContainer]) -> bool:
@@ -307,7 +308,7 @@ def generate_parameter(
307
308
  # When we pass `explicit`, then its parts are excluded from generation of the final value
308
309
  # If the final value is the same, then other parameters were generated at all
309
310
  used_generator = None
310
- return ValueContainer(value=value, generator=used_generator)
311
+ return ValueContainer(value=value, location=location, generator=used_generator)
311
312
 
312
313
 
313
314
  def can_negate_path_parameters(operation: APIOperation) -> bool:
@@ -499,22 +500,9 @@ def apply_hooks(
499
500
  strategy: st.SearchStrategy,
500
501
  location: str,
501
502
  ) -> st.SearchStrategy:
502
- """Apply all `before_generate_` hooks related to the given location."""
503
- strategy = _apply_hooks(context, GLOBAL_HOOK_DISPATCHER, strategy, location)
504
- strategy = _apply_hooks(context, operation.schema.hooks, strategy, location)
505
- if hooks is not None:
506
- strategy = _apply_hooks(context, hooks, strategy, location)
507
- return strategy
508
-
509
-
510
- def _apply_hooks(
511
- context: HookContext, hooks: HookDispatcher, strategy: st.SearchStrategy, location: str
512
- ) -> st.SearchStrategy:
513
- """Apply all `before_generate_` hooks related to the given location & dispatcher."""
503
+ """Apply all hooks related to the given location."""
514
504
  container = LOCATION_TO_CONTAINER[location]
515
- for hook in hooks.get_all_by_name(f"before_generate_{container}"):
516
- strategy = hook(context, strategy)
517
- return strategy
505
+ return apply_to_all_dispatchers(operation, context, hooks, strategy, container)
518
506
 
519
507
 
520
508
  def clear_cache() -> None: