schemathesis 4.0.0a2__py3-none-any.whl → 4.0.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. schemathesis/cli/__init__.py +15 -4
  2. schemathesis/cli/commands/run/__init__.py +148 -94
  3. schemathesis/cli/commands/run/context.py +72 -2
  4. schemathesis/cli/commands/run/events.py +22 -2
  5. schemathesis/cli/commands/run/executor.py +35 -12
  6. schemathesis/cli/commands/run/filters.py +1 -0
  7. schemathesis/cli/commands/run/handlers/cassettes.py +27 -46
  8. schemathesis/cli/commands/run/handlers/junitxml.py +1 -1
  9. schemathesis/cli/commands/run/handlers/output.py +180 -87
  10. schemathesis/cli/commands/run/hypothesis.py +30 -19
  11. schemathesis/cli/commands/run/reports.py +72 -0
  12. schemathesis/cli/commands/run/validation.py +18 -12
  13. schemathesis/cli/ext/groups.py +42 -13
  14. schemathesis/cli/ext/options.py +15 -8
  15. schemathesis/core/errors.py +85 -9
  16. schemathesis/core/failures.py +2 -1
  17. schemathesis/core/transforms.py +1 -1
  18. schemathesis/engine/core.py +1 -1
  19. schemathesis/engine/errors.py +17 -6
  20. schemathesis/engine/phases/stateful/__init__.py +1 -0
  21. schemathesis/engine/phases/stateful/_executor.py +9 -12
  22. schemathesis/engine/phases/unit/__init__.py +2 -3
  23. schemathesis/engine/phases/unit/_executor.py +16 -13
  24. schemathesis/engine/recorder.py +22 -21
  25. schemathesis/errors.py +23 -13
  26. schemathesis/filters.py +8 -0
  27. schemathesis/generation/coverage.py +10 -5
  28. schemathesis/generation/hypothesis/builder.py +15 -12
  29. schemathesis/generation/stateful/state_machine.py +57 -12
  30. schemathesis/pytest/lazy.py +2 -3
  31. schemathesis/pytest/plugin.py +2 -3
  32. schemathesis/schemas.py +1 -1
  33. schemathesis/specs/openapi/checks.py +77 -37
  34. schemathesis/specs/openapi/expressions/__init__.py +22 -6
  35. schemathesis/specs/openapi/expressions/nodes.py +15 -21
  36. schemathesis/specs/openapi/expressions/parser.py +1 -1
  37. schemathesis/specs/openapi/parameters.py +0 -2
  38. schemathesis/specs/openapi/patterns.py +170 -2
  39. schemathesis/specs/openapi/schemas.py +67 -39
  40. schemathesis/specs/openapi/stateful/__init__.py +207 -84
  41. schemathesis/specs/openapi/stateful/control.py +87 -0
  42. schemathesis/specs/openapi/{links.py → stateful/links.py} +72 -14
  43. {schemathesis-4.0.0a2.dist-info → schemathesis-4.0.0a4.dist-info}/METADATA +1 -1
  44. {schemathesis-4.0.0a2.dist-info → schemathesis-4.0.0a4.dist-info}/RECORD +47 -45
  45. {schemathesis-4.0.0a2.dist-info → schemathesis-4.0.0a4.dist-info}/WHEEL +0 -0
  46. {schemathesis-4.0.0a2.dist-info → schemathesis-4.0.0a4.dist-info}/entry_points.txt +0 -0
  47. {schemathesis-4.0.0a2.dist-info → schemathesis-4.0.0a4.dist-info}/licenses/LICENSE +0 -0
@@ -41,11 +41,12 @@ from schemathesis.generation.case import Case
41
41
  from schemathesis.generation.meta import CaseMetadata
42
42
  from schemathesis.generation.overrides import Override, OverrideMark, check_no_override_mark
43
43
  from schemathesis.openapi.checks import JsonSchemaError, MissingContentType
44
+ from schemathesis.specs.openapi.stateful import links
44
45
 
45
46
  from ...generation import GenerationConfig, GenerationMode
46
47
  from ...hooks import HookContext, HookDispatcher
47
48
  from ...schemas import APIOperation, APIOperationMap, ApiStatistic, BaseSchema, OperationDefinition
48
- from . import links, serialization
49
+ from . import serialization
49
50
  from ._cache import OperationCache
50
51
  from ._hypothesis import openapi_cases
51
52
  from .converter import to_json_schema, to_json_schema_recursive
@@ -66,8 +67,8 @@ from .stateful import create_state_machine
66
67
  if TYPE_CHECKING:
67
68
  from hypothesis.strategies import SearchStrategy
68
69
 
69
- from ...auths import AuthStorage
70
- from ...stateful.state_machine import APIStateMachine
70
+ from schemathesis.auths import AuthStorage
71
+ from schemathesis.generation.stateful import APIStateMachine
71
72
 
72
73
  HTTP_METHODS = frozenset({"get", "put", "post", "delete", "options", "head", "patch", "trace"})
73
74
  SCHEMA_ERROR_MESSAGE = "Ensure that the definition complies with the OpenAPI specification"
@@ -155,7 +156,6 @@ class BaseOpenAPISchema(BaseSchema):
155
156
  return True
156
157
  if self.filter_set.is_empty():
157
158
  return False
158
- path = self.get_full_path(path)
159
159
  # Attribute assignment is way faster than creating a new namespace every time
160
160
  operation = _ctx_cache.operation
161
161
  operation.method = method
@@ -174,30 +174,64 @@ class BaseOpenAPISchema(BaseSchema):
174
174
  return statistic
175
175
 
176
176
  resolve = self.resolver.resolve
177
+ resolve_path_item = self._resolve_path_item
177
178
  should_skip = self._should_skip
178
179
  links_field = self.links_field
179
180
 
181
+ # For operationId lookup
182
+ selected_operations_by_id: set[str] = set()
183
+ # Tuples of (method, path)
184
+ selected_operations_by_path: set[tuple[str, str]] = set()
185
+ collected_links: list[dict] = []
186
+
180
187
  for path, path_item in paths.items():
181
188
  try:
182
- if "$ref" in path_item:
183
- _, path_item = resolve(path_item["$ref"])
184
- for method, definition in path_item.items():
185
- if method not in HTTP_METHODS:
186
- continue
187
- statistic.operations.total += 1
188
- is_selected = not should_skip(path, method, definition)
189
- if is_selected:
190
- statistic.operations.selected += 1
191
- for response in definition.get("responses", {}).values():
192
- if "$ref" in response:
193
- _, response = resolve(response["$ref"])
194
- defined_links = response.get(links_field)
195
- if defined_links is not None:
196
- statistic.links.total += len(defined_links)
197
- if is_selected:
198
- statistic.links.selected = len(defined_links)
189
+ scope, path_item = resolve_path_item(path_item)
190
+ self.resolver.push_scope(scope)
191
+ try:
192
+ for method, definition in path_item.items():
193
+ if method not in HTTP_METHODS:
194
+ continue
195
+ statistic.operations.total += 1
196
+ is_selected = not should_skip(path, method, definition)
197
+ if is_selected:
198
+ statistic.operations.selected += 1
199
+ # Store both identifiers
200
+ if "operationId" in definition:
201
+ selected_operations_by_id.add(definition["operationId"])
202
+ selected_operations_by_path.add((method, path))
203
+ for response in definition.get("responses", {}).values():
204
+ if "$ref" in response:
205
+ _, response = resolve(response["$ref"])
206
+ defined_links = response.get(links_field)
207
+ if defined_links is not None:
208
+ statistic.links.total += len(defined_links)
209
+ if is_selected:
210
+ collected_links.extend(defined_links.values())
211
+ finally:
212
+ self.resolver.pop_scope()
199
213
  except SCHEMA_PARSING_ERRORS:
200
214
  continue
215
+
216
+ def is_link_selected(link: dict) -> bool:
217
+ if "$ref" in link:
218
+ _, link = resolve(link["$ref"])
219
+
220
+ if "operationId" in link:
221
+ return link["operationId"] in selected_operations_by_id
222
+ else:
223
+ try:
224
+ scope, _ = resolve(link["operationRef"])
225
+ path, method = scope.rsplit("/", maxsplit=2)[-2:]
226
+ path = path.replace("~1", "/").replace("~0", "~")
227
+ return (method, path) in selected_operations_by_path
228
+ except Exception:
229
+ return False
230
+
231
+ for link in collected_links:
232
+ if is_link_selected(link):
233
+ statistic.links.selected += 1
234
+
201
235
  return statistic
202
236
 
203
237
  def _operation_iter(self) -> Generator[dict[str, Any], None, None]:
@@ -331,28 +365,24 @@ class BaseOpenAPISchema(BaseSchema):
331
365
  def _into_err(self, error: Exception, path: str | None, method: str | None) -> Err[InvalidSchema]:
332
366
  __tracebackhide__ = True
333
367
  try:
334
- full_path = self.get_full_path(path) if isinstance(path, str) else None
335
- self._raise_invalid_schema(error, full_path, path, method)
368
+ self._raise_invalid_schema(error, path, method)
336
369
  except InvalidSchema as exc:
337
370
  return Err(exc)
338
371
 
339
372
  def _raise_invalid_schema(
340
373
  self,
341
374
  error: Exception,
342
- full_path: str | None = None,
343
375
  path: str | None = None,
344
376
  method: str | None = None,
345
377
  ) -> NoReturn:
346
378
  __tracebackhide__ = True
347
379
  if isinstance(error, RefResolutionError):
348
- raise InvalidSchema.from_reference_resolution_error(
349
- error, path=path, method=method, full_path=full_path
350
- ) from None
380
+ raise InvalidSchema.from_reference_resolution_error(error, path=path, method=method) from None
351
381
  try:
352
382
  self.validate()
353
383
  except jsonschema.ValidationError as exc:
354
- raise InvalidSchema.from_jsonschema_error(exc, path=path, method=method, full_path=full_path) from None
355
- raise InvalidSchema(SCHEMA_ERROR_MESSAGE, path=path, method=method, full_path=full_path) from error
384
+ raise InvalidSchema.from_jsonschema_error(exc, path=path, method=method) from None
385
+ raise InvalidSchema(SCHEMA_ERROR_MESSAGE, path=path, method=method) from error
356
386
 
357
387
  def validate(self) -> None:
358
388
  with suppress(TypeError):
@@ -550,8 +580,7 @@ class BaseOpenAPISchema(BaseSchema):
550
580
  responses = operation.definition.raw["responses"]
551
581
  except KeyError as exc:
552
582
  path = operation.path
553
- full_path = self.get_full_path(path) if isinstance(path, str) else None
554
- self._raise_invalid_schema(exc, full_path, path, operation.method)
583
+ self._raise_invalid_schema(exc, path, operation.method)
555
584
  status_code = str(response.status_code)
556
585
  if status_code in responses:
557
586
  return self.resolver.resolve_in_scope(responses[status_code], operation.definition.scope)
@@ -569,13 +598,7 @@ class BaseOpenAPISchema(BaseSchema):
569
598
  return scopes, definitions.get("headers")
570
599
 
571
600
  def as_state_machine(self) -> type[APIStateMachine]:
572
- try:
573
- return create_state_machine(self)
574
- except OperationNotFound as exc:
575
- raise LoaderError(
576
- kind=LoaderErrorKind.OPEN_API_INVALID_SCHEMA,
577
- message=f"Invalid Open API link definition: Operation `{exc.item}` not found",
578
- ) from exc
601
+ return create_state_machine(self)
579
602
 
580
603
  def add_link(
581
604
  self,
@@ -626,7 +649,12 @@ class BaseOpenAPISchema(BaseSchema):
626
649
  def get_links(self, operation: APIOperation) -> dict[str, dict[str, Any]]:
627
650
  result: dict[str, dict[str, Any]] = defaultdict(dict)
628
651
  for status_code, link in links.get_all_links(operation):
629
- result[status_code][link.name] = link
652
+ if isinstance(link, Ok):
653
+ name = link.ok().name
654
+ else:
655
+ name = link.err().name
656
+ result[status_code][name] = link
657
+
630
658
  return result
631
659
 
632
660
  def get_tags(self, operation: APIOperation) -> list[str] | None:
@@ -1,32 +1,40 @@
1
1
  from __future__ import annotations
2
2
 
3
- from collections import defaultdict
3
+ from dataclasses import dataclass
4
4
  from functools import lru_cache
5
5
  from typing import TYPE_CHECKING, Any, Callable, Iterator
6
6
 
7
7
  from hypothesis import strategies as st
8
8
  from hypothesis.stateful import Bundle, Rule, precondition, rule
9
9
 
10
+ from schemathesis.core.errors import InvalidStateMachine
10
11
  from schemathesis.core.result import Ok
12
+ from schemathesis.core.transforms import UNRESOLVABLE
13
+ from schemathesis.engine.recorder import ScenarioRecorder
14
+ from schemathesis.generation import GenerationMode
11
15
  from schemathesis.generation.case import Case
12
16
  from schemathesis.generation.hypothesis import strategies
13
17
  from schemathesis.generation.stateful.state_machine import APIStateMachine, StepInput, StepOutput, _normalize_name
14
18
  from schemathesis.schemas import APIOperation
15
-
16
- from ....generation import GenerationMode
17
- from ..links import OpenApiLink, get_all_links
18
- from ..utils import expand_status_code
19
+ from schemathesis.specs.openapi.stateful.control import TransitionController
20
+ from schemathesis.specs.openapi.stateful.links import OpenApiLink, get_all_links
21
+ from schemathesis.specs.openapi.utils import expand_status_code
19
22
 
20
23
  if TYPE_CHECKING:
21
24
  from schemathesis.generation.stateful.state_machine import StepOutput
22
-
23
- from ..schemas import BaseOpenAPISchema
25
+ from schemathesis.specs.openapi.schemas import BaseOpenAPISchema
24
26
 
25
27
  FilterFunction = Callable[["StepOutput"], bool]
26
28
 
27
29
 
28
30
  class OpenAPIStateMachine(APIStateMachine):
29
31
  _response_matchers: dict[str, Callable[[StepOutput], str | None]]
32
+ _transitions: ApiTransitions
33
+
34
+ def __init__(self) -> None:
35
+ self.recorder = ScenarioRecorder(label="Stateful tests")
36
+ self.control = TransitionController(self._transitions)
37
+ super().__init__()
30
38
 
31
39
  def _get_target_for_result(self, result: StepOutput) -> str | None:
32
40
  matcher = self._response_matchers.get(result.case.operation.label)
@@ -36,81 +44,155 @@ class OpenAPIStateMachine(APIStateMachine):
36
44
 
37
45
 
38
46
  # The proportion of negative tests generated for "root" transitions
39
- NEGATIVE_TEST_CASES_THRESHOLD = 20
47
+ NEGATIVE_TEST_CASES_THRESHOLD = 10
40
48
 
41
49
 
42
- def create_state_machine(schema: BaseOpenAPISchema) -> type[APIStateMachine]:
43
- """Create a state machine class.
50
+ @dataclass
51
+ class OperationTransitions:
52
+ """Transitions for a single operation."""
44
53
 
45
- It aims to avoid making calls that are not likely to lead to a stateful call later. For example:
46
- 1. POST /users/
47
- 2. GET /users/{id}/
54
+ __slots__ = ("incoming", "outgoing")
48
55
 
49
- This state machine won't make calls to (2) without having a proper response from (1) first.
50
- """
56
+ def __init__(self) -> None:
57
+ self.incoming: list[OpenApiLink] = []
58
+ self.outgoing: list[OpenApiLink] = []
59
+
60
+
61
+ @dataclass
62
+ class ApiTransitions:
63
+ """Stores all transitions grouped by operation."""
64
+
65
+ __slots__ = ("operations",)
66
+
67
+ def __init__(self) -> None:
68
+ # operation label -> its transitions
69
+ self.operations: dict[str, OperationTransitions] = {}
70
+
71
+ def add_outgoing(self, source: str, link: OpenApiLink) -> None:
72
+ """Record an outgoing transition from source operation."""
73
+ self.operations.setdefault(source, OperationTransitions()).outgoing.append(link)
74
+ self.operations.setdefault(link.target.label, OperationTransitions()).incoming.append(link)
75
+
76
+
77
@dataclass
class RootTransitions:
    """Entry-point candidates for the state machine, split by how likely they are to succeed."""

    __slots__ = ("reliable", "fallback")

    def __init__(self) -> None:
        reliable: set[str] = set()
        fallback: set[str] = set()
        # Labels of operations expected to succeed and provide data for follow-up transitions
        self.reliable = reliable
        # Labels of operations that may work as entry points, but with lower confidence
        self.fallback = fallback
88
+
89
+
90
def collect_transitions(operations: list[APIOperation]) -> ApiTransitions:
    """Build the transition index for the given operations.

    Links whose target is not one of ``operations`` are ignored. Links that failed to resolve are
    gathered across all operations and reported together.

    :param operations: Selected API operations to scan for links.
    :return: Transitions between the selected operations.
    :raises InvalidStateMachine: If any link could not be resolved.
    """
    transitions = ApiTransitions()
    known_labels = {operation.label for operation in operations}
    failures = []

    for source in operations:
        for _, result in get_all_links(source):
            if not isinstance(result, Ok):
                failures.append(result.err())
                continue
            link = result.ok()
            # Only keep transitions that stay within the selected operations
            if link.target.label in known_labels:
                transitions.add_outgoing(source.label, link)

    if failures:
        raise InvalidStateMachine(failures)

    return transitions
108
+
109
+
110
+ def create_state_machine(schema: BaseOpenAPISchema) -> type[APIStateMachine]:
51
111
  operations = [result.ok() for result in schema.get_all_operations() if isinstance(result, Ok)]
52
112
  bundles = {}
53
- incoming_transitions = defaultdict(list)
113
+ transitions = collect_transitions(operations)
54
114
  _response_matchers: dict[str, Callable[[StepOutput], str | None]] = {}
55
- # Statistic structure follows the links and count for each response status code
115
+
116
+ # Create bundles and matchers
56
117
  for operation in operations:
57
118
  all_status_codes = tuple(operation.definition.raw["responses"])
58
119
  bundle_matchers = []
59
- for _, link in get_all_links(operation):
60
- bundle_name = f"{operation.label} -> {link.status_code}"
61
- bundles[bundle_name] = Bundle(bundle_name)
62
- incoming_transitions[link.target.label].append(link)
63
- bundle_matchers.append((bundle_name, make_response_filter(link.status_code, all_status_codes)))
120
+
121
+ if operation.label in transitions.operations:
122
+ # Use outgoing transitions
123
+ for link in transitions.operations[operation.label].outgoing:
124
+ bundle_name = f"{operation.label} -> {link.status_code}"
125
+ bundles[bundle_name] = Bundle(bundle_name)
126
+ bundle_matchers.append((bundle_name, make_response_filter(link.status_code, all_status_codes)))
127
+
64
128
  if bundle_matchers:
65
129
  _response_matchers[operation.label] = make_response_matcher(bundle_matchers)
130
+
66
131
  rules = {}
67
132
  catch_all = Bundle("catch_all")
68
133
 
134
+ # We want stateful testing to be effective and focus on meaningful transitions.
135
+ # An operation is considered as a "root" transition (entry point) if it satisfies certain criteria
136
+ # that indicate it's likely to succeed and provide data for other transitions.
137
+ # For example:
138
+ # - POST operations that create resources
139
+ # - GET operations without path parameters (e.g., GET /users/ to list all users)
140
+ #
141
+ # We avoid adding operations as roots if they:
142
+ # 1. Have incoming transitions that will provide proper data
143
+ # Example: If POST /users/ -> GET /users/{id} exists, we don't need
144
+ # to generate random user IDs for GET /users/{id}
145
+ # 2. Are unlikely to succeed with random data
146
+ # Example: GET /users/{id} with random ID is likely to return 404
147
+ #
148
+ # This way we:
149
+ # 1. Maximize the chance of successful transitions
150
+ # 2. Don't waste the test budget (limited number of steps) on likely-to-fail operations
151
+ # 3. Focus on transitions that are designed to work together via links
152
+
153
+ roots = classify_root_transitions(operations, transitions)
154
+
69
155
  for target in operations:
70
- incoming = incoming_transitions.get(target.label)
71
- if incoming is not None:
72
- for link in incoming:
73
- bundle_name = f"{link.source.label} -> {link.status_code}"
74
- name = _normalize_name(f"{link.status_code} -> {target.label}")
75
- rules[name] = precondition(ensure_non_empty_bundle(bundle_name))(
76
- transition(
77
- name=name,
78
- target=catch_all,
79
- input=bundles[bundle_name].flatmap(
80
- into_step_input(target=target, link=link, modes=schema.generation_config.modes)
81
- ),
156
+ if target.label in transitions.operations:
157
+ incoming = transitions.operations[target.label].incoming
158
+ if incoming:
159
+ for link in incoming:
160
+ bundle_name = f"{link.source.label} -> {link.status_code}"
161
+ name = _normalize_name(
162
+ f"{link.source.label} -> {link.status_code} -> {link.name} -> {target.label}"
82
163
  )
164
+ assert name not in rules, name
165
+ rules[name] = precondition(is_transition_allowed(bundle_name, link.source.label, target.label))(
166
+ transition(
167
+ name=name,
168
+ target=catch_all,
169
+ input=bundles[bundle_name].flatmap(
170
+ into_step_input(target=target, link=link, modes=schema.generation_config.modes)
171
+ ),
172
+ )
173
+ )
174
+ if target.label in roots.reliable or (not roots.reliable and target.label in roots.fallback):
175
+ name = _normalize_name(f"RANDOM -> {target.label}")
176
+ if len(schema.generation_config.modes) == 1:
177
+ case_strategy = target.as_strategy(generation_mode=schema.generation_config.modes[0])
178
+ else:
179
+ _strategies = {
180
+ method: target.as_strategy(generation_mode=method) for method in schema.generation_config.modes
181
+ }
182
+
183
+ @st.composite # type: ignore[misc]
184
+ def case_strategy_factory(
185
+ draw: st.DrawFn, strategies: dict[GenerationMode, st.SearchStrategy] = _strategies
186
+ ) -> Case:
187
+ if draw(st.integers(min_value=0, max_value=99)) < NEGATIVE_TEST_CASES_THRESHOLD:
188
+ return draw(strategies[GenerationMode.NEGATIVE])
189
+ return draw(strategies[GenerationMode.POSITIVE])
190
+
191
+ case_strategy = case_strategy_factory()
192
+
193
+ rules[name] = precondition(is_root_allowed(target.label))(
194
+ transition(name=name, target=catch_all, input=case_strategy.map(StepInput.initial))
83
195
  )
84
- elif any(
85
- incoming.source.label == target.label
86
- for transitions in incoming_transitions.values()
87
- for incoming in transitions
88
- ):
89
- # No incoming transitions, but has at least one outgoing transition
90
- # For example, POST /users/ -> GET /users/{id}/
91
- # The source operation has no prerequisite, but we need to allow this rule to be executed
92
- # in order to reach other transitions
93
- name = _normalize_name(f"{target.label} -> X")
94
- if len(schema.generation_config.modes) == 1:
95
- case_strategy = target.as_strategy(generation_mode=schema.generation_config.modes[0])
96
- else:
97
- _strategies = {
98
- method: target.as_strategy(generation_mode=method) for method in schema.generation_config.modes
99
- }
100
-
101
- @st.composite # type: ignore[misc]
102
- def case_strategy_factory(
103
- draw: st.DrawFn, strategies: dict[GenerationMode, st.SearchStrategy] = _strategies
104
- ) -> Case:
105
- if draw(st.integers(min_value=0, max_value=99)) < NEGATIVE_TEST_CASES_THRESHOLD:
106
- return draw(strategies[GenerationMode.NEGATIVE])
107
- return draw(strategies[GenerationMode.POSITIVE])
108
-
109
- case_strategy = case_strategy_factory()
110
-
111
- rules[name] = precondition(ensure_links_followed)(
112
- transition(name=name, target=catch_all, input=case_strategy.map(StepInput.initial))
113
- )
114
196
 
115
197
  return type(
116
198
  "APIWorkflow",
@@ -119,69 +201,110 @@ def create_state_machine(schema: BaseOpenAPISchema) -> type[APIStateMachine]:
119
201
  "schema": schema,
120
202
  "bundles": bundles,
121
203
  "_response_matchers": _response_matchers,
204
+ "_transitions": transitions,
122
205
  **rules,
123
206
  },
124
207
  )
125
208
 
126
209
 
210
def classify_root_transitions(operations: list[APIOperation], transitions: ApiTransitions) -> RootTransitions:
    """Split operations that have outgoing links into reliable and fallback entry points.

    Operations without any outgoing transition are never considered as roots.
    """
    roots = RootTransitions()

    for operation in operations:
        entry = transitions.operations.get(operation.label)
        # Only operations that lead somewhere can start a scenario
        if entry is None or not entry.outgoing:
            continue
        bucket = roots.reliable if is_likely_root_transition(operation, entry) else roots.fallback
        bucket.add(operation.label)

    return roots
226
+
227
+
228
def is_likely_root_transition(operation: APIOperation, transitions: OperationTransitions) -> bool:
    """Tell whether ``operation`` is likely to succeed when called with freshly generated data.

    Two shapes qualify:
    - POST with a request body (presumably creates a resource);
    - GET without path parameters (presumably lists a collection).

    ``transitions`` is currently unused; it is kept for the call signature.
    """
    method = operation.method
    creates_resource = method == "post" and bool(operation.body)
    lists_collection = method == "get" and not operation.path_parameters
    return creates_resource or lists_collection
239
+
240
+
127
241
  def into_step_input(
128
242
  target: APIOperation, link: OpenApiLink, modes: list[GenerationMode]
129
243
  ) -> Callable[[StepOutput], st.SearchStrategy[StepInput]]:
130
244
  def builder(_output: StepOutput) -> st.SearchStrategy[StepInput]:
131
245
  @st.composite # type: ignore[misc]
132
246
  def inner(draw: st.DrawFn, output: StepOutput) -> StepInput:
133
- transition_data = link.extract(output)
247
+ transition = link.extract(output)
134
248
 
135
249
  kwargs: dict[str, Any] = {
136
250
  container: {
137
251
  name: extracted.value.ok()
138
252
  for name, extracted in data.items()
139
- if isinstance(extracted.value, Ok) and extracted.value.ok() is not None
253
+ if isinstance(extracted.value, Ok) and extracted.value.ok() not in (None, UNRESOLVABLE)
140
254
  }
141
- for container, data in transition_data.parameters.items()
255
+ for container, data in transition.parameters.items()
142
256
  }
257
+
143
258
  if (
144
- transition_data.request_body is not None
145
- and isinstance(transition_data.request_body.value, Ok)
259
+ transition.request_body is not None
260
+ and isinstance(transition.request_body.value, Ok)
261
+ and transition.request_body.value.ok() is not UNRESOLVABLE
146
262
  and not link.merge_body
147
263
  ):
148
- kwargs["body"] = transition_data.request_body.value.ok()
264
+ kwargs["body"] = transition.request_body.value.ok()
149
265
  cases = strategies.combine([target.as_strategy(generation_mode=mode, **kwargs) for mode in modes])
150
266
  case = draw(cases)
151
267
  if (
152
- transition_data.request_body is not None
153
- and isinstance(transition_data.request_body.value, Ok)
268
+ transition.request_body is not None
269
+ and isinstance(transition.request_body.value, Ok)
270
+ and transition.request_body.value.ok() is not UNRESOLVABLE
154
271
  and link.merge_body
155
272
  ):
156
- new = transition_data.request_body.value.ok()
273
+ new = transition.request_body.value.ok()
157
274
  if isinstance(case.body, dict) and isinstance(new, dict):
158
275
  case.body = {**case.body, **new}
159
276
  else:
160
277
  case.body = new
161
- return StepInput(case=case, transition=transition_data)
278
+ return StepInput(case=case, transition=transition)
162
279
 
163
280
  return inner(output=_output)
164
281
 
165
282
  return builder
166
283
 
167
284
 
168
- def ensure_non_empty_bundle(bundle_name: str) -> Callable[[APIStateMachine], bool]:
169
- def inner(machine: APIStateMachine) -> bool:
170
- return bool(machine.bundles.get(bundle_name))
285
+ def is_transition_allowed(bundle_name: str, source: str, target: str) -> Callable[[OpenAPIStateMachine], bool]:
286
+ def inner(machine: OpenAPIStateMachine) -> bool:
287
+ return bool(machine.bundles.get(bundle_name)) and machine.control.allow_transition(source, target)
171
288
 
172
289
  return inner
173
290
 
174
291
 
175
- def ensure_links_followed(machine: APIStateMachine) -> bool:
176
- # If there are responses that have links to follow, reject any rule without incoming transitions
177
- for bundle in machine.bundles.values():
178
- if bundle:
179
- return False
180
- return True
292
+ def is_root_allowed(label: str) -> Callable[[OpenAPIStateMachine], bool]:
293
+ def inner(machine: OpenAPIStateMachine) -> bool:
294
+ return machine.control.allow_root_transition(label, machine.bundles)
295
+
296
+ return inner
181
297
 
182
298
 
183
299
  def transition(*, name: str, target: Bundle, input: st.SearchStrategy[StepInput]) -> Callable[[Callable], Rule]:
184
- def step_function(self: APIStateMachine, input: StepInput) -> StepOutput | None:
300
+ def step_function(self: OpenAPIStateMachine, input: StepInput) -> StepOutput | None:
301
+ if input.transition is not None:
302
+ self.recorder.record_case(
303
+ parent_id=input.transition.parent_id, transition=input.transition, case=input.case
304
+ )
305
+ else:
306
+ self.recorder.record_case(parent_id=None, transition=None, case=input.case)
307
+ self.control.record_step(input, self.recorder)
185
308
  return APIStateMachine._step(self, input=input)
186
309
 
187
310
  step_function.__name__ = name
@@ -0,0 +1,87 @@
1
+ from __future__ import annotations
2
+
3
+ from collections import Counter
4
+ from dataclasses import dataclass
5
+ from typing import TYPE_CHECKING
6
+
7
+ from schemathesis.engine.recorder import ScenarioRecorder
8
+ from schemathesis.generation.stateful.state_machine import DEFAULT_STATEFUL_STEP_COUNT
9
+
10
+ if TYPE_CHECKING:
11
+ from requests.structures import CaseInsensitiveDict
12
+
13
+ from schemathesis.generation.stateful.state_machine import StepInput
14
+ from schemathesis.specs.openapi.stateful import ApiTransitions
15
+
16
+
17
# Lower bound for the per-source budget of derived API calls (used with `max()` below).
# Two calls are enough to catch "double-click" type issues, e.g. the same request sent twice.
MAX_OPERATIONS_PER_SOURCE_CAP = 2
# Number of recorded cases of a root source after which its root transition is throttled
# (think of it as the number of concurrently active users in the system)
MAX_ROOT_SOURCES = 2
21
+
22
+
23
def _get_max_operations_per_source(transitions: ApiTransitions) -> int:
    """Compute the global limit of derived calls per source, to diversify the exercised API calls.

    The default stateful step count is split evenly across every operation registered in
    ``transitions.operations``. The result never drops below ``MAX_OPERATIONS_PER_SOURCE_CAP``.
    That constant is also the result when there are no transitions at all.

    NOTE(review): ``transitions.operations`` holds both link sources and pure link targets,
    so this divides by more than just the number of sources. Confirm that this is intended.
    """
    participants = len(transitions.operations)
    if not participants:
        return MAX_OPERATIONS_PER_SOURCE_CAP
    evenly_split = DEFAULT_STATEFUL_STEP_COUNT // participants
    return max(MAX_OPERATIONS_PER_SOURCE_CAP, evenly_split)
32
+
33
+
34
@dataclass
class TransitionController:
    """Control which transitions can be executed in a state machine.

    For every source operation, it tracks how many calls to each target operation were derived
    from each recorded source case. This tally is used to throttle repeated transitions, so the
    limited step budget is spread across different API calls.
    """

    __slots__ = ("transitions", "max_operations_per_source", "statistic")

    def __init__(self, transitions: ApiTransitions) -> None:
        # Incoming & outgoing transitions available in the state machine
        self.transitions = transitions
        # Budget of derived calls per (source operation, target operation) pair
        self.max_operations_per_source = _get_max_operations_per_source(transitions)
        # source operation label -> source case id -> number of derived calls per target operation label
        self.statistic: dict[str, dict[str, Counter[str]]] = {}

    def record_step(self, input: StepInput, recorder: ScenarioRecorder) -> None:
        """Record an API call in the per-source statistic.

        :param input: The step about to be executed.
        :param recorder: Recorder holding previously recorded cases. It is used to find the
            parent case of a link-driven step.
        """
        case = input.case
        label = case.operation.label

        operation_transitions = self.transitions.operations.get(label)
        if operation_transitions is not None and operation_transitions.outgoing:
            # This API operation has outgoing transitions, hence record it as a source
            self.statistic.setdefault(label, {})[case.id] = Counter()

        if input.transition is not None:
            # Attribute this call to its immediate parent case, if that case was recorded as a source
            parent = recorder.cases[input.transition.parent_id].value
            derived = self.statistic.get(parent.operation.label, {}).get(parent.id)
            # `is not None`: an empty Counter is falsy but still a valid entry
            if derived is not None:
                derived[label] += 1

    def allow_root_transition(self, source: str, bundles: dict[str, CaseInsensitiveDict]) -> bool:
        """Decide whether the root transition for ``source`` may run now.

        Roots are allowed until ``MAX_ROOT_SOURCES`` cases of ``source`` have been recorded. After
        that, a root is allowed only if every incoming transition is unusable. A transition is
        unusable when its source operation has no collected bundle values, or when it has
        exhausted its budget.

        :param source: Label of the root operation.
        :param bundles: Bundle name -> values collected by the state machine.
        """
        if len(self.statistic.get(source, {})) < MAX_ROOT_SOURCES:
            return True

        # Bundle names are built as f"{operation.label} -> {status_code}". Split on the LAST
        # separator so that operation labels containing "->" are recovered intact.
        history = set()
        for name, values in bundles.items():
            if not values:
                continue
            label, separator, _ = name.rpartition(" -> ")
            history.add(label if separator else name)

        # If all non-root operations are blocked, then allow root ones to make progress
        return all(
            link.source.label not in history or not self.allow_transition(link.source.label, link.target.label)
            for operation_transitions in self.transitions.operations.values()
            for link in operation_transitions.incoming
        )

    def allow_transition(self, source: str, target: str) -> bool:
        """Decide whether a transition from ``source`` to ``target`` may run now.

        The number of ``target`` calls derived from all recorded ``source`` cases must stay
        below ``max_operations_per_source``.
        """
        recorded = self.statistic.get(source, {})
        total = sum(counter.get(target, 0) for counter in recorded.values())
        return total < self.max_operations_per_source