schemathesis 4.0.0a11__py3-none-any.whl → 4.0.0a12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. schemathesis/__init__.py +28 -25
  2. schemathesis/auths.py +65 -24
  3. schemathesis/checks.py +60 -36
  4. schemathesis/cli/commands/run/__init__.py +23 -21
  5. schemathesis/cli/commands/run/context.py +6 -1
  6. schemathesis/cli/commands/run/events.py +7 -1
  7. schemathesis/cli/commands/run/executor.py +12 -7
  8. schemathesis/cli/commands/run/handlers/output.py +175 -80
  9. schemathesis/cli/commands/run/validation.py +21 -6
  10. schemathesis/config/__init__.py +2 -1
  11. schemathesis/config/_generation.py +12 -13
  12. schemathesis/config/_operations.py +14 -0
  13. schemathesis/config/_phases.py +41 -5
  14. schemathesis/config/_projects.py +28 -0
  15. schemathesis/config/_report.py +6 -2
  16. schemathesis/config/_warnings.py +25 -0
  17. schemathesis/config/schema.json +49 -1
  18. schemathesis/core/errors.py +5 -2
  19. schemathesis/core/transport.py +36 -1
  20. schemathesis/engine/context.py +1 -0
  21. schemathesis/engine/errors.py +60 -1
  22. schemathesis/engine/events.py +10 -2
  23. schemathesis/engine/phases/probes.py +3 -0
  24. schemathesis/engine/phases/stateful/__init__.py +2 -1
  25. schemathesis/engine/phases/stateful/_executor.py +38 -5
  26. schemathesis/engine/phases/stateful/context.py +2 -2
  27. schemathesis/engine/phases/unit/_executor.py +36 -7
  28. schemathesis/generation/__init__.py +0 -3
  29. schemathesis/generation/case.py +1 -0
  30. schemathesis/generation/coverage.py +1 -1
  31. schemathesis/generation/hypothesis/builder.py +31 -7
  32. schemathesis/generation/metrics.py +93 -0
  33. schemathesis/generation/modes.py +0 -8
  34. schemathesis/generation/stateful/__init__.py +4 -0
  35. schemathesis/generation/stateful/state_machine.py +1 -0
  36. schemathesis/graphql/loaders.py +138 -4
  37. schemathesis/hooks.py +62 -35
  38. schemathesis/openapi/loaders.py +120 -4
  39. schemathesis/pytest/loaders.py +24 -0
  40. schemathesis/pytest/plugin.py +22 -0
  41. schemathesis/schemas.py +9 -6
  42. schemathesis/specs/graphql/scalars.py +37 -3
  43. schemathesis/specs/graphql/schemas.py +12 -3
  44. schemathesis/specs/openapi/_hypothesis.py +14 -20
  45. schemathesis/specs/openapi/checks.py +21 -18
  46. schemathesis/specs/openapi/formats.py +30 -3
  47. schemathesis/specs/openapi/media_types.py +44 -1
  48. schemathesis/specs/openapi/schemas.py +8 -2
  49. schemathesis/specs/openapi/stateful/__init__.py +2 -1
  50. schemathesis/transport/__init__.py +54 -16
  51. schemathesis/transport/prepare.py +31 -7
  52. schemathesis/transport/requests.py +9 -8
  53. schemathesis/transport/wsgi.py +8 -8
  54. {schemathesis-4.0.0a11.dist-info → schemathesis-4.0.0a12.dist-info}/METADATA +44 -90
  55. {schemathesis-4.0.0a11.dist-info → schemathesis-4.0.0a12.dist-info}/RECORD +58 -60
  56. schemathesis/contrib/__init__.py +0 -9
  57. schemathesis/contrib/openapi/__init__.py +0 -9
  58. schemathesis/contrib/openapi/fill_missing_examples.py +0 -20
  59. schemathesis/generation/targets.py +0 -69
  60. {schemathesis-4.0.0a11.dist-info → schemathesis-4.0.0a12.dist-info}/WHEEL +0 -0
  61. {schemathesis-4.0.0a11.dist-info → schemathesis-4.0.0a12.dist-info}/entry_points.txt +0 -0
  62. {schemathesis-4.0.0a11.dist-info → schemathesis-4.0.0a12.dist-info}/licenses/LICENSE +0 -0
@@ -115,6 +115,15 @@ class GraphQLSchema(BaseSchema):
115
115
  return map
116
116
  raise KeyError(key)
117
117
 
118
+ def find_operation_by_label(self, label: str) -> APIOperation | None:
119
+ if label.startswith(("Query.", "Mutation.")):
120
+ ty, field = label.split(".", maxsplit=1)
121
+ try:
122
+ return self[ty][field]
123
+ except KeyError:
124
+ return None
125
+ return None
126
+
118
127
  def on_missing_operation(self, item: str, exc: KeyError) -> NoReturn:
119
128
  raw_schema = self.raw_schema["__schema"]
120
129
  type_names = [type_def["name"] for type_def in raw_schema.get("types", [])]
@@ -223,7 +232,7 @@ class GraphQLSchema(BaseSchema):
223
232
  operation: APIOperation,
224
233
  hooks: HookDispatcher | None = None,
225
234
  auth_storage: AuthStorage | None = None,
226
- generation_mode: GenerationMode = GenerationMode.default(),
235
+ generation_mode: GenerationMode = GenerationMode.POSITIVE,
227
236
  **kwargs: Any,
228
237
  ) -> SearchStrategy:
229
238
  return graphql_cases(
@@ -321,7 +330,7 @@ def graphql_cases(
321
330
  operation: APIOperation,
322
331
  hooks: HookDispatcher | None = None,
323
332
  auth_storage: auths.AuthStorage | None = None,
324
- generation_mode: GenerationMode = GenerationMode.default(),
333
+ generation_mode: GenerationMode = GenerationMode.POSITIVE,
325
334
  path_parameters: NotSet | dict[str, Any] = NOT_SET,
326
335
  headers: NotSet | dict[str, Any] = NOT_SET,
327
336
  cookies: NotSet | dict[str, Any] = NOT_SET,
@@ -336,7 +345,7 @@ def graphql_cases(
336
345
  RootType.QUERY: gql_st.queries,
337
346
  RootType.MUTATION: gql_st.mutations,
338
347
  }[definition.root_type]
339
- hook_context = HookContext(operation)
348
+ hook_context = HookContext(operation=operation)
340
349
  custom_scalars = {**get_extra_scalar_strategies(), **CUSTOM_SCALARS}
341
350
  generation = operation.schema.config.generation_for(operation=operation, phase="fuzzing")
342
351
  strategy = strategy_factory(
@@ -52,7 +52,7 @@ def openapi_cases(
52
52
  operation: APIOperation,
53
53
  hooks: HookDispatcher | None = None,
54
54
  auth_storage: auths.AuthStorage | None = None,
55
- generation_mode: GenerationMode = GenerationMode.default(),
55
+ generation_mode: GenerationMode = GenerationMode.POSITIVE,
56
56
  path_parameters: NotSet | dict[str, Any] = NOT_SET,
57
57
  headers: NotSet | dict[str, Any] = NOT_SET,
58
58
  cookies: NotSet | dict[str, Any] = NOT_SET,
@@ -80,18 +80,14 @@ def openapi_cases(
80
80
  phase_name = "stateful" if __is_stateful_phase else phase.value
81
81
  generation_config = operation.schema.config.generation_for(operation=operation, phase=phase_name)
82
82
 
83
- context = HookContext(operation)
83
+ ctx = HookContext(operation=operation)
84
84
 
85
85
  path_parameters_ = generate_parameter(
86
- "path", path_parameters, operation, draw, context, hooks, generation_mode, generation_config
86
+ "path", path_parameters, operation, draw, ctx, hooks, generation_mode, generation_config
87
87
  )
88
- headers_ = generate_parameter(
89
- "header", headers, operation, draw, context, hooks, generation_mode, generation_config
90
- )
91
- cookies_ = generate_parameter(
92
- "cookie", cookies, operation, draw, context, hooks, generation_mode, generation_config
93
- )
94
- query_ = generate_parameter("query", query, operation, draw, context, hooks, generation_mode, generation_config)
88
+ headers_ = generate_parameter("header", headers, operation, draw, ctx, hooks, generation_mode, generation_config)
89
+ cookies_ = generate_parameter("cookie", cookies, operation, draw, ctx, hooks, generation_mode, generation_config)
90
+ query_ = generate_parameter("query", query, operation, draw, ctx, hooks, generation_mode, generation_config)
95
91
 
96
92
  if body is NOT_SET:
97
93
  if operation.body:
@@ -108,7 +104,7 @@ def openapi_cases(
108
104
  candidates = operation.body.items
109
105
  parameter = draw(st.sampled_from(candidates))
110
106
  strategy = _get_body_strategy(parameter, strategy_factory, operation, generation_config)
111
- strategy = apply_hooks(operation, context, hooks, strategy, "body")
107
+ strategy = apply_hooks(operation, ctx, hooks, strategy, "body")
112
108
  # Parameter may have a wildcard media type. In this case, choose any supported one
113
109
  possible_media_types = sorted(
114
110
  operation.schema.transport.get_matching_media_types(parameter.media_type), key=lambda x: x[0]
@@ -219,7 +215,7 @@ def get_parameters_value(
219
215
  location: str,
220
216
  draw: Callable,
221
217
  operation: APIOperation,
222
- context: HookContext,
218
+ ctx: HookContext,
223
219
  hooks: HookDispatcher | None,
224
220
  strategy_factory: StrategyFactory,
225
221
  generation_config: GenerationConfig,
@@ -231,10 +227,10 @@ def get_parameters_value(
231
227
  """
232
228
  if isinstance(value, NotSet) or not value:
233
229
  strategy = get_parameters_strategy(operation, strategy_factory, location, generation_config)
234
- strategy = apply_hooks(operation, context, hooks, strategy, location)
230
+ strategy = apply_hooks(operation, ctx, hooks, strategy, location)
235
231
  return draw(strategy)
236
232
  strategy = get_parameters_strategy(operation, strategy_factory, location, generation_config, exclude=value.keys())
237
- strategy = apply_hooks(operation, context, hooks, strategy, location)
233
+ strategy = apply_hooks(operation, ctx, hooks, strategy, location)
238
234
  new = draw(strategy)
239
235
  if new is not None:
240
236
  copied = deepclone(value)
@@ -272,7 +268,7 @@ def generate_parameter(
272
268
  explicit: NotSet | dict[str, Any],
273
269
  operation: APIOperation,
274
270
  draw: Callable,
275
- context: HookContext,
271
+ ctx: HookContext,
276
272
  hooks: HookDispatcher | None,
277
273
  generator: GenerationMode,
278
274
  generation_config: GenerationConfig,
@@ -291,9 +287,7 @@ def generate_parameter(
291
287
  generator = GenerationMode.POSITIVE
292
288
  else:
293
289
  strategy_factory = GENERATOR_MODE_TO_STRATEGY_FACTORY[generator]
294
- value = get_parameters_value(
295
- explicit, location, draw, operation, context, hooks, strategy_factory, generation_config
296
- )
290
+ value = get_parameters_value(explicit, location, draw, operation, ctx, hooks, strategy_factory, generation_config)
297
291
  used_generator: GenerationMode | None = generator
298
292
  if value == explicit:
299
293
  # When we pass `explicit`, then its parts are excluded from generation of the final value
@@ -494,11 +488,11 @@ def quote_all(parameters: dict[str, Any]) -> dict[str, Any]:
494
488
 
495
489
  def apply_hooks(
496
490
  operation: APIOperation,
497
- context: HookContext,
491
+ ctx: HookContext,
498
492
  hooks: HookDispatcher | None,
499
493
  strategy: st.SearchStrategy,
500
494
  location: str,
501
495
  ) -> st.SearchStrategy:
502
496
  """Apply all hooks related to the given location."""
503
497
  container = LOCATION_TO_CONTAINER[location]
504
- return apply_to_all_dispatchers(operation, context, hooks, strategy, container)
498
+ return apply_to_all_dispatchers(operation, ctx, hooks, strategy, container)
@@ -352,12 +352,12 @@ def use_after_free(ctx: CheckContext, response: Response, case: Case) -> bool |
352
352
  if response.status_code == 404 or response.status_code >= 500:
353
353
  return None
354
354
 
355
- for related_case in ctx.find_related(case_id=case.id):
356
- parent = ctx.find_parent(case_id=related_case.id)
355
+ for related_case in ctx._find_related(case_id=case.id):
356
+ parent = ctx._find_parent(case_id=related_case.id)
357
357
  if not parent:
358
358
  continue
359
359
 
360
- parent_response = ctx.find_response(case_id=parent.id)
360
+ parent_response = ctx._find_response(case_id=parent.id)
361
361
 
362
362
  if (
363
363
  related_case.operation.method.lower() == "delete"
@@ -395,10 +395,10 @@ def ensure_resource_availability(ctx: CheckContext, response: Response, case: Ca
395
395
  if not (400 <= response.status_code < 500):
396
396
  return None
397
397
 
398
- parent = ctx.find_parent(case_id=case.id)
398
+ parent = ctx._find_parent(case_id=case.id)
399
399
  if parent is None:
400
400
  return None
401
- parent_response = ctx.find_response(case_id=parent.id)
401
+ parent_response = ctx._find_response(case_id=parent.id)
402
402
  if parent_response is None:
403
403
  return None
404
404
 
@@ -424,8 +424,8 @@ def ensure_resource_availability(ctx: CheckContext, response: Response, case: Ca
424
424
  return None
425
425
 
426
426
  # Look for any successful DELETE operations on this resource
427
- for related_case in ctx.find_related(case_id=case.id):
428
- related_response = ctx.find_response(case_id=related_case.id)
427
+ for related_case in ctx._find_related(case_id=case.id):
428
+ related_response = ctx._find_response(case_id=related_case.id)
429
429
  if (
430
430
  related_case.operation.method.upper() == "DELETE"
431
431
  and related_response is not None
@@ -478,25 +478,25 @@ def ignored_auth(ctx: CheckContext, response: Response, case: Case) -> bool | No
478
478
  # Auth is explicitly set, it is expected to be valid
479
479
  # Check if invalid auth will give an error
480
480
  no_auth_case = remove_auth(case, security_parameters)
481
- kwargs = ctx.transport_kwargs or {}
481
+ kwargs = ctx._transport_kwargs or {}
482
482
  kwargs.copy()
483
483
  if "headers" in kwargs:
484
484
  headers = kwargs["headers"].copy()
485
485
  _remove_auth_from_explicit_headers(headers, security_parameters)
486
486
  kwargs["headers"] = headers
487
487
  kwargs.pop("session", None)
488
- ctx.record_case(parent_id=case.id, case=no_auth_case)
488
+ ctx._record_case(parent_id=case.id, case=no_auth_case)
489
489
  no_auth_response = case.operation.schema.transport.send(no_auth_case, **kwargs)
490
- ctx.record_response(case_id=no_auth_case.id, response=no_auth_response)
490
+ ctx._record_response(case_id=no_auth_case.id, response=no_auth_response)
491
491
  if no_auth_response.status_code != 401:
492
492
  _raise_no_auth_error(no_auth_response, no_auth_case, "that requires authentication")
493
493
  # Try to set invalid auth and check if it succeeds
494
494
  for parameter in security_parameters:
495
495
  invalid_auth_case = remove_auth(case, security_parameters)
496
496
  _set_auth_for_case(invalid_auth_case, parameter)
497
- ctx.record_case(parent_id=case.id, case=invalid_auth_case)
497
+ ctx._record_case(parent_id=case.id, case=invalid_auth_case)
498
498
  invalid_auth_response = case.operation.schema.transport.send(invalid_auth_case, **kwargs)
499
- ctx.record_response(case_id=invalid_auth_case.id, response=invalid_auth_response)
499
+ ctx._record_response(case_id=invalid_auth_case.id, response=invalid_auth_response)
500
500
  if invalid_auth_response.status_code != 401:
501
501
  _raise_no_auth_error(invalid_auth_response, invalid_auth_case, "with any auth")
502
502
  elif auth == AuthKind.GENERATED:
@@ -540,7 +540,7 @@ def _contains_auth(
540
540
  from requests.cookies import RequestsCookieJar
541
541
 
542
542
  # If auth comes from explicit `auth` option or a custom auth, it is always explicit
543
- if ctx.auth is not None or case._has_explicit_auth:
543
+ if ctx._auth is not None or case._has_explicit_auth:
544
544
  return AuthKind.EXPLICIT
545
545
  parsed = urlparse(request.url)
546
546
  query = parse_qs(parsed.query) # type: ignore
@@ -563,19 +563,19 @@ def _contains_auth(
563
563
  for parameter in security_parameters:
564
564
  name = parameter["name"]
565
565
  if has_header(parameter):
566
- if (ctx.headers is not None and name in ctx.headers) or (ctx.override and name in ctx.override.headers):
566
+ if (ctx._headers is not None and name in ctx._headers) or (ctx._override and name in ctx._override.headers):
567
567
  return AuthKind.EXPLICIT
568
568
  return AuthKind.GENERATED
569
569
  if has_cookie(parameter):
570
- if ctx.headers is not None and "Cookie" in ctx.headers:
571
- cookies = cast(RequestsCookieJar, ctx.headers["Cookie"]) # type: ignore
570
+ if ctx._headers is not None and "Cookie" in ctx._headers:
571
+ cookies = cast(RequestsCookieJar, ctx._headers["Cookie"]) # type: ignore
572
572
  if name in cookies:
573
573
  return AuthKind.EXPLICIT
574
- if ctx.override and name in ctx.override.cookies:
574
+ if ctx._override and name in ctx._override.cookies:
575
575
  return AuthKind.EXPLICIT
576
576
  return AuthKind.GENERATED
577
577
  if has_query(parameter):
578
- if ctx.override and name in ctx.override.query:
578
+ if ctx._override and name in ctx._override.query:
579
579
  return AuthKind.EXPLICIT
580
580
  return AuthKind.GENERATED
581
581
 
@@ -628,6 +628,9 @@ def _set_auth_for_case(case: Case, parameter: SecurityParameter) -> None:
628
628
  ):
629
629
  if parameter["in"] == location:
630
630
  container = getattr(case, attr_name, {})
631
+ # Could happen in the negative testing mode
632
+ if not isinstance(container, dict):
633
+ container = {}
631
634
  container[name] = "SCHEMATHESIS-INVALID-VALUE"
632
635
  setattr(case, attr_name, container)
633
636
 
@@ -15,10 +15,37 @@ STRING_FORMATS: dict[str, st.SearchStrategy] = {}
15
15
 
16
16
 
17
17
  def register_string_format(name: str, strategy: st.SearchStrategy) -> None:
18
- """Register a new strategy for generating data for specific string "format".
18
+ r"""Register a custom Hypothesis strategy for generating string format data.
19
+
20
+ Args:
21
+ name: String format name that matches the "format" keyword in your API schema
22
+ strategy: Hypothesis strategy to generate values for this format
23
+
24
+ Example:
25
+ ```python
26
+ import schemathesis
27
+ from hypothesis import strategies as st
28
+
29
+ # Register phone number format
30
+ phone_strategy = st.from_regex(r"\+1-\d{3}-\d{3}-\d{4}")
31
+ schemathesis.openapi.format("phone", phone_strategy)
32
+
33
+ # Register email with specific domain
34
+ email_strategy = st.from_regex(r"[a-z]+@company\.com")
35
+ schemathesis.openapi.format("company-email", email_strategy)
36
+ ```
37
+
38
+ Schema usage:
39
+ ```yaml
40
+ properties:
41
+ phone:
42
+ type: string
43
+ format: phone # Uses your phone_strategy
44
+ contact_email:
45
+ type: string
46
+ format: company-email # Uses your email_strategy
47
+ ```
19
48
 
20
- :param str name: Format name. It should correspond the one used in the API schema as the "format" keyword value.
21
- :param strategy: Hypothesis strategy you'd like to use to generate values for this format.
22
49
  """
23
50
  from hypothesis.strategies import SearchStrategy
24
51
 
@@ -15,7 +15,50 @@ MEDIA_TYPES: dict[str, st.SearchStrategy[bytes]] = {}
15
15
 
16
16
 
17
17
  def register_media_type(name: str, strategy: st.SearchStrategy[bytes], *, aliases: Collection[str] = ()) -> None:
18
- """Register a strategy for the given media type."""
18
+ r"""Register a custom Hypothesis strategy for generating media type content.
19
+
20
+ Args:
21
+ name: Media type name that matches your OpenAPI requestBody content type
22
+ strategy: Hypothesis strategy that generates bytes for this media type
23
+ aliases: Additional media type names that use the same strategy
24
+
25
+ Example:
26
+ ```python
27
+ import schemathesis
28
+ from hypothesis import strategies as st
29
+
30
+ # Register PDF file strategy
31
+ pdf_strategy = st.sampled_from([
32
+ b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\n%%EOF",
33
+ b"%PDF-1.5\n%\xe2\xe3\xcf\xd3\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\n%%EOF"
34
+ ])
35
+ schemathesis.openapi.media_type("application/pdf", pdf_strategy)
36
+
37
+ # Dynamic content generation
38
+ @st.composite
39
+ def xml_content(draw):
40
+ tag = draw(st.text(min_size=3, max_size=10))
41
+ content = draw(st.text(min_size=1, max_size=50))
42
+ return f"<?xml version='1.0'?><{tag}>{content}</{tag}>".encode()
43
+
44
+ schemathesis.openapi.media_type("application/xml", xml_content())
45
+ ```
46
+
47
+ Schema usage:
48
+ ```yaml
49
+ requestBody:
50
+ content:
51
+ application/pdf: # Uses your PDF strategy
52
+ schema:
53
+ type: string
54
+ format: binary
55
+ application/xml: # Uses your XML strategy
56
+ schema:
57
+ type: string
58
+ format: binary
59
+ ```
60
+
61
+ """
19
62
 
20
63
  @REQUESTS_TRANSPORT.serializer(name, *aliases)
21
64
  @ASGI_TRANSPORT.serializer(name, *aliases)
@@ -120,13 +120,18 @@ class BaseOpenAPISchema(BaseSchema):
120
120
  if map is not None:
121
121
  return map
122
122
  path_item = self.raw_schema.get("paths", {})[path]
123
- scope, path_item = self._resolve_path_item(path_item)
123
+ with in_scope(self.resolver, self.location or ""):
124
+ scope, path_item = self._resolve_path_item(path_item)
124
125
  self.dispatch_hook("before_process_path", HookContext(), path, path_item)
125
126
  map = APIOperationMap(self, {})
126
127
  map._data = MethodMap(map, scope, path, CaseInsensitiveDict(path_item))
127
128
  cache.insert_map(path, map)
128
129
  return map
129
130
 
131
+ def find_operation_by_label(self, label: str) -> APIOperation | None:
132
+ method, path = label.split(" ", maxsplit=1)
133
+ return self[path][method]
134
+
130
135
  def on_missing_operation(self, item: str, exc: KeyError) -> NoReturn:
131
136
  matches = get_close_matches(item, list(self))
132
137
  self._on_missing_operation(item, exc, matches)
@@ -536,7 +541,7 @@ class BaseOpenAPISchema(BaseSchema):
536
541
  operation: APIOperation,
537
542
  hooks: HookDispatcher | None = None,
538
543
  auth_storage: AuthStorage | None = None,
539
- generation_mode: GenerationMode = GenerationMode.default(),
544
+ generation_mode: GenerationMode = GenerationMode.POSITIVE,
540
545
  **kwargs: Any,
541
546
  ) -> SearchStrategy:
542
547
  return openapi_cases(
@@ -658,6 +663,7 @@ class BaseOpenAPISchema(BaseSchema):
658
663
  return jsonschema.Draft4Validator
659
664
 
660
665
  def validate_response(self, operation: APIOperation, response: Response) -> bool | None:
666
+ __tracebackhide__ = True
661
667
  responses = {str(key): value for key, value in operation.definition.raw.get("responses", {}).items()}
662
668
  status_code = str(response.status_code)
663
669
  if status_code in responses:
@@ -14,6 +14,7 @@ from schemathesis.engine.recorder import ScenarioRecorder
14
14
  from schemathesis.generation import GenerationMode
15
15
  from schemathesis.generation.case import Case
16
16
  from schemathesis.generation.hypothesis import strategies
17
+ from schemathesis.generation.stateful import STATEFUL_TESTS_LABEL
17
18
  from schemathesis.generation.stateful.state_machine import APIStateMachine, StepInput, StepOutput, _normalize_name
18
19
  from schemathesis.schemas import APIOperation
19
20
  from schemathesis.specs.openapi.stateful.control import TransitionController
@@ -32,7 +33,7 @@ class OpenAPIStateMachine(APIStateMachine):
32
33
  _transitions: ApiTransitions
33
34
 
34
35
  def __init__(self) -> None:
35
- self.recorder = ScenarioRecorder(label="Stateful tests")
36
+ self.recorder = ScenarioRecorder(label=STATEFUL_TESTS_LABEL)
36
37
  self.control = TransitionController(self._transitions)
37
38
  super().__init__()
38
39
 
@@ -2,11 +2,15 @@ from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
4
  from inspect import iscoroutinefunction
5
- from typing import Any, Callable, Generic, Iterator, TypeVar
5
+ from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, TypeVar, Union
6
6
 
7
7
  from schemathesis.core import media_types
8
8
  from schemathesis.core.errors import SerializationNotPossible
9
9
 
10
+ if TYPE_CHECKING:
11
+ from schemathesis.core.transport import Response
12
+ from schemathesis.generation.case import Case
13
+
10
14
 
11
15
  def get(app: Any) -> BaseTransport:
12
16
  """Get transport to send the data to the application."""
@@ -23,41 +27,43 @@ def get(app: Any) -> BaseTransport:
23
27
  return WSGI_TRANSPORT
24
28
 
25
29
 
26
- C = TypeVar("C", contravariant=True)
27
- R = TypeVar("R", covariant=True)
28
30
  S = TypeVar("S", contravariant=True)
29
31
 
30
32
 
31
33
  @dataclass
32
- class SerializationContext(Generic[C]):
33
- """Generic context for serialization process."""
34
+ class SerializationContext:
35
+ """Context object passed to serializer functions.
36
+
37
+ It provides access to the generated test case and any related metadata.
38
+ """
34
39
 
35
- case: C
40
+ case: Case
41
+ """The generated test case."""
36
42
 
37
43
  __slots__ = ("case",)
38
44
 
39
45
 
40
- Serializer = Callable[[SerializationContext[C], Any], Any]
46
+ Serializer = Callable[[SerializationContext, Any], Any]
41
47
 
42
48
 
43
- class BaseTransport(Generic[C, R, S]):
49
+ class BaseTransport(Generic[S]):
44
50
  """Base implementation with serializer registration."""
45
51
 
46
52
  def __init__(self) -> None:
47
- self._serializers: dict[str, Serializer[C]] = {}
53
+ self._serializers: dict[str, Serializer] = {}
48
54
 
49
- def serialize_case(self, case: C, **kwargs: Any) -> dict[str, Any]:
55
+ def serialize_case(self, case: Case, **kwargs: Any) -> dict[str, Any]:
50
56
  """Prepare the case for sending."""
51
57
  raise NotImplementedError
52
58
 
53
- def send(self, case: C, *, session: S | None = None, **kwargs: Any) -> R:
59
+ def send(self, case: Case, *, session: S | None = None, **kwargs: Any) -> Response:
54
60
  """Send the case using this transport."""
55
61
  raise NotImplementedError
56
62
 
57
- def serializer(self, *media_types: str) -> Callable[[Serializer[C]], Serializer[C]]:
63
+ def serializer(self, *media_types: str) -> Callable[[Serializer], Serializer]:
58
64
  """Register a serializer for given media types."""
59
65
 
60
- def decorator(func: Serializer[C]) -> Serializer[C]:
66
+ def decorator(func: Serializer) -> Serializer:
61
67
  for media_type in media_types:
62
68
  self._serializers[media_type] = func
63
69
  return func
@@ -71,10 +77,10 @@ class BaseTransport(Generic[C, R, S]):
71
77
  def _copy_serializers_from(self, transport: BaseTransport) -> None:
72
78
  self._serializers.update(transport._serializers)
73
79
 
74
- def get_first_matching_media_type(self, media_type: str) -> tuple[str, Serializer[C]] | None:
80
+ def get_first_matching_media_type(self, media_type: str) -> tuple[str, Serializer] | None:
75
81
  return next(self.get_matching_media_types(media_type), None)
76
82
 
77
- def get_matching_media_types(self, media_type: str) -> Iterator[tuple[str, Serializer[C]]]:
83
+ def get_matching_media_types(self, media_type: str) -> Iterator[tuple[str, Serializer]]:
78
84
  """Get all registered media types matching the given media type."""
79
85
  if media_type == "*/*":
80
86
  # Shortcut to avoid comparing all values
@@ -96,9 +102,41 @@ class BaseTransport(Generic[C, R, S]):
96
102
  if main in ("*", target_main) and sub in ("*", target_sub):
97
103
  yield registered_media_type, serializer
98
104
 
99
- def _get_serializer(self, input_media_type: str) -> Serializer[C]:
105
+ def _get_serializer(self, input_media_type: str) -> Serializer:
100
106
  pair = self.get_first_matching_media_type(input_media_type)
101
107
  if pair is None:
102
108
  # This media type is set manually. Otherwise, it should have been rejected during the data generation
103
109
  raise SerializationNotPossible.for_media_type(input_media_type)
104
110
  return pair[1]
111
+
112
+
113
+ _Serializer = Callable[[SerializationContext, Any], Union[bytes, None]]
114
+
115
+
116
+ def serializer(*media_types: str) -> Callable[[_Serializer], None]:
117
+ """Register a serializer for specified media types on HTTP, ASGI, and WSGI transports.
118
+
119
+ Args:
120
+ *media_types: One or more MIME types (e.g., "application/json") this serializer handles.
121
+
122
+ Returns:
123
+ A decorator that wraps a function taking `(ctx: SerializationContext, value: Any)` and returning `bytes` for serialized body and `None` for omitting request body.
124
+
125
+ """
126
+
127
+ def register(func: _Serializer) -> None:
128
+ from schemathesis.transport.asgi import ASGI_TRANSPORT
129
+ from schemathesis.transport.requests import REQUESTS_TRANSPORT
130
+ from schemathesis.transport.wsgi import WSGI_TRANSPORT
131
+
132
+ @ASGI_TRANSPORT.serializer(*media_types)
133
+ @REQUESTS_TRANSPORT.serializer(*media_types)
134
+ @WSGI_TRANSPORT.serializer(*media_types)
135
+ def inner(ctx: SerializationContext, value: Any) -> dict[str, bytes]:
136
+ result = {}
137
+ serialized = func(ctx, value)
138
+ if serialized is not None:
139
+ result["data"] = serialized
140
+ return result
141
+
142
+ return register
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from functools import lru_cache
3
4
  from typing import TYPE_CHECKING, Any, Mapping, cast
4
5
  from urllib.parse import quote, unquote, urljoin, urlsplit, urlunsplit
5
6
 
@@ -8,6 +9,7 @@ from schemathesis.core import SCHEMATHESIS_TEST_CASE_HEADER, NotSet
8
9
  from schemathesis.core.errors import InvalidSchema
9
10
  from schemathesis.core.output.sanitization import sanitize_url, sanitize_value
10
11
  from schemathesis.core.transport import USER_AGENT
12
+ from schemathesis.generation.meta import CoveragePhaseData
11
13
 
12
14
  if TYPE_CHECKING:
13
15
  from requests import PreparedRequest
@@ -16,15 +18,37 @@ if TYPE_CHECKING:
16
18
  from schemathesis.generation.case import Case
17
19
 
18
20
 
19
- def prepare_headers(case: Case, headers: dict[str, str] | None = None) -> CaseInsensitiveDict:
20
- from requests.structures import CaseInsensitiveDict
21
+ @lru_cache()
22
+ def get_default_headers() -> CaseInsensitiveDict:
23
+ from requests.utils import default_headers
24
+
25
+ headers = default_headers()
26
+ headers["User-Agent"] = USER_AGENT
27
+ return headers
21
28
 
22
- final_headers = case.headers.copy() if case.headers is not None else CaseInsensitiveDict()
29
+
30
+ def prepare_headers(case: Case, headers: dict[str, str] | None = None) -> CaseInsensitiveDict:
31
+ default_headers = get_default_headers().copy()
32
+ if case.headers:
33
+ default_headers.update(case.headers)
34
+ default_headers.setdefault(SCHEMATHESIS_TEST_CASE_HEADER, case.id)
23
35
  if headers:
24
- final_headers.update(headers)
25
- final_headers.setdefault("User-Agent", USER_AGENT)
26
- final_headers.setdefault(SCHEMATHESIS_TEST_CASE_HEADER, case.id)
27
- return final_headers
36
+ default_headers.update(headers)
37
+ for header in get_exclude_headers(case):
38
+ default_headers.pop(header, None)
39
+ return default_headers
40
+
41
+
42
+ def get_exclude_headers(case: Case) -> list[str]:
43
+ if (
44
+ case.meta is not None
45
+ and isinstance(case.meta.phase.data, CoveragePhaseData)
46
+ and case.meta.phase.data.description.startswith("Missing")
47
+ and case.meta.phase.data.description.endswith("at header")
48
+ and case.meta.phase.data.parameter is not None
49
+ ):
50
+ return [case.meta.phase.data.parameter]
51
+ return []
28
52
 
29
53
 
30
54
  def prepare_url(case: Case, base_url: str | None) -> str:
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
21
21
  from schemathesis.generation.case import Case
22
22
 
23
23
 
24
- class RequestsTransport(BaseTransport["Case", Response, "requests.Session"]):
24
+ class RequestsTransport(BaseTransport["requests.Session"]):
25
25
  def serialize_case(self, case: Case, **kwargs: Any) -> dict[str, Any]:
26
26
  base_url = kwargs.get("base_url")
27
27
  headers = kwargs.get("headers")
@@ -92,6 +92,7 @@ class RequestsTransport(BaseTransport["Case", Response, "requests.Session"]):
92
92
  if session is None:
93
93
  validate_vanilla_requests_kwargs(data)
94
94
  session = requests.Session()
95
+ session.headers = {}
95
96
  close_session = True
96
97
  else:
97
98
  close_session = False
@@ -135,14 +136,14 @@ REQUESTS_TRANSPORT = RequestsTransport()
135
136
 
136
137
 
137
138
  @REQUESTS_TRANSPORT.serializer("application/json", "text/json")
138
- def json_serializer(ctx: SerializationContext[Case], value: Any) -> dict[str, Any]:
139
+ def json_serializer(ctx: SerializationContext, value: Any) -> dict[str, Any]:
139
140
  return serialize_json(value)
140
141
 
141
142
 
142
143
  @REQUESTS_TRANSPORT.serializer(
143
144
  "text/yaml", "text/x-yaml", "text/vnd.yaml", "text/yml", "application/yaml", "application/x-yaml"
144
145
  )
145
- def yaml_serializer(ctx: SerializationContext[Case], value: Any) -> dict[str, Any]:
146
+ def yaml_serializer(ctx: SerializationContext, value: Any) -> dict[str, Any]:
146
147
  return serialize_yaml(value)
147
148
 
148
149
 
@@ -188,7 +189,7 @@ def _encode_multipart(value: Any, boundary: str) -> bytes:
188
189
 
189
190
 
190
191
  @REQUESTS_TRANSPORT.serializer("multipart/form-data", "multipart/mixed")
191
- def multipart_serializer(ctx: SerializationContext[Case], value: Any) -> dict[str, Any]:
192
+ def multipart_serializer(ctx: SerializationContext, value: Any) -> dict[str, Any]:
192
193
  if isinstance(value, bytes):
193
194
  return {"data": value}
194
195
  if isinstance(value, dict):
@@ -204,7 +205,7 @@ def multipart_serializer(ctx: SerializationContext[Case], value: Any) -> dict[st
204
205
 
205
206
 
206
207
  @REQUESTS_TRANSPORT.serializer("application/xml", "text/xml")
207
- def xml_serializer(ctx: SerializationContext[Case], value: Any) -> dict[str, Any]:
208
+ def xml_serializer(ctx: SerializationContext, value: Any) -> dict[str, Any]:
208
209
  media_type = ctx.case.media_type
209
210
 
210
211
  assert media_type is not None
@@ -216,17 +217,17 @@ def xml_serializer(ctx: SerializationContext[Case], value: Any) -> dict[str, Any
216
217
 
217
218
 
218
219
  @REQUESTS_TRANSPORT.serializer("application/x-www-form-urlencoded")
219
- def urlencoded_serializer(ctx: SerializationContext[Case], value: Any) -> dict[str, Any]:
220
+ def urlencoded_serializer(ctx: SerializationContext, value: Any) -> dict[str, Any]:
220
221
  return {"data": value}
221
222
 
222
223
 
223
224
  @REQUESTS_TRANSPORT.serializer("text/plain")
224
- def text_serializer(ctx: SerializationContext[Case], value: Any) -> dict[str, Any]:
225
+ def text_serializer(ctx: SerializationContext, value: Any) -> dict[str, Any]:
225
226
  if isinstance(value, bytes):
226
227
  return {"data": value}
227
228
  return {"data": str(value).encode("utf8")}
228
229
 
229
230
 
230
231
  @REQUESTS_TRANSPORT.serializer("application/octet-stream")
231
- def binary_serializer(ctx: SerializationContext[Case], value: Any) -> dict[str, Any]:
232
+ def binary_serializer(ctx: SerializationContext, value: Any) -> dict[str, Any]:
232
233
  return {"data": serialize_binary(value)}