schemathesis 4.0.26__py3-none-any.whl → 4.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -217,6 +217,12 @@ DEFAULT_PHASES = ["examples", "coverage", "fuzzing", "stateful"]
217
217
  type=str,
218
218
  callback=validation.validate_rate_limit,
219
219
  )
220
+ @grouped_option(
221
+ "--max-redirects",
222
+ help="Maximum number of redirects to follow for each request",
223
+ type=click.IntRange(min=0),
224
+ show_default=True,
225
+ )
220
226
  @grouped_option(
221
227
  "--request-timeout",
222
228
  help="Timeout limit, in seconds, for each network request during tests",
@@ -448,6 +454,7 @@ def run(
448
454
  suppress_health_check: list[HealthCheck] | None,
449
455
  warnings: bool | list[SchemathesisWarning] | None,
450
456
  rate_limit: str | None = None,
457
+ max_redirects: int | None = None,
451
458
  request_timeout: int | None = None,
452
459
  request_tls_verify: bool | None = None,
453
460
  request_cert: str | None = None,
@@ -528,6 +535,7 @@ def run(
528
535
  workers=workers,
529
536
  continue_on_failure=continue_on_failure,
530
537
  rate_limit=rate_limit,
538
+ max_redirects=max_redirects,
531
539
  request_timeout=request_timeout,
532
540
  tls_verify=request_tls_verify,
533
541
  request_cert=request_cert,
@@ -454,6 +454,7 @@ def har_writer(path: Path, config: SchemathesisConfig, queue: Queue) -> None:
454
454
  )
455
455
  elif isinstance(item, Finalize):
456
456
  break
457
+ har.flush()
457
458
 
458
459
 
459
460
  HARFILE_NO_RESPONSE = harfile.Response(
@@ -592,6 +592,7 @@ class StatefulProgressManager:
592
592
  console: Console
593
593
  title: str
594
594
  links_selected: int
595
+ links_inferred: int
595
596
  links_total: int
596
597
  start_time: float
597
598
 
@@ -616,6 +617,7 @@ class StatefulProgressManager:
616
617
  "console",
617
618
  "title",
618
619
  "links_selected",
620
+ "links_inferred",
619
621
  "links_total",
620
622
  "start_time",
621
623
  "title_progress",
@@ -631,13 +633,16 @@ class StatefulProgressManager:
631
633
  "is_interrupted",
632
634
  )
633
635
 
634
- def __init__(self, *, console: Console, title: str, links_selected: int, links_total: int) -> None:
636
+ def __init__(
637
+ self, *, console: Console, title: str, links_selected: int, links_inferred: int, links_total: int
638
+ ) -> None:
635
639
  from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
636
640
  from rich.style import Style
637
641
 
638
642
  self.console = console
639
643
  self.title = title
640
644
  self.links_selected = links_selected
645
+ self.links_inferred = links_inferred
641
646
  self.links_total = links_total
642
647
  self.start_time = time.monotonic()
643
648
 
@@ -686,9 +691,10 @@ class StatefulProgressManager:
686
691
 
687
692
  # Initialize progress displays
688
693
  self.title_task_id = self.title_progress.add_task("Stateful")
689
- self.progress_task_id = self.progress_bar.add_task(
690
- "", scenarios=0, links=f"0 covered / {self.links_selected} selected / {self.links_total} total links"
691
- )
694
+ links = f"0 covered / {self.links_selected} selected / {self.links_total} total"
695
+ if self.links_inferred:
696
+ links += f" ({self.links_inferred} inferred)"
697
+ self.progress_task_id = self.progress_bar.add_task("", scenarios=0, links=links)
692
698
 
693
699
  # Create live display
694
700
  group = Group(
@@ -720,11 +726,10 @@ class StatefulProgressManager:
720
726
  def _update_progress_display(self) -> None:
721
727
  """Update the progress display."""
722
728
  assert self.progress_task_id is not None
723
- self.progress_bar.update(
724
- self.progress_task_id,
725
- scenarios=self.scenarios,
726
- links=f"{len(self.links_covered)} covered / {self.links_selected} selected / {self.links_total} total links",
727
- )
729
+ links = f"{len(self.links_covered)} covered / {self.links_selected} selected / {self.links_total} total"
730
+ if self.links_inferred:
731
+ links += f" ({self.links_inferred} inferred)"
732
+ self.progress_bar.update(self.progress_task_id, scenarios=self.scenarios, links=links)
728
733
 
729
734
  def _get_stats_message(self) -> str:
730
735
  """Get formatted stats message."""
@@ -888,7 +893,7 @@ class OutputHandler(EventHandler):
888
893
  elif phase.name in [PhaseName.EXAMPLES, PhaseName.COVERAGE, PhaseName.FUZZING] and phase.is_enabled:
889
894
  self._start_unit_tests(phase.name)
890
895
  elif phase.name == PhaseName.STATEFUL_TESTING and phase.is_enabled and phase.skip_reason is None:
891
- self._start_stateful_tests()
896
+ self._start_stateful_tests(event)
892
897
 
893
898
  def _start_probing(self) -> None:
894
899
  self.probing_manager = ProbingProgressManager(console=self.console)
@@ -904,13 +909,18 @@ class OutputHandler(EventHandler):
904
909
  )
905
910
  self.unit_tests_manager.start()
906
911
 
907
- def _start_stateful_tests(self) -> None:
912
+ def _start_stateful_tests(self, event: events.PhaseStarted) -> None:
908
913
  assert self.statistic is not None
914
+ assert event.payload is not None
915
+ # Total number of links - original ones + inferred during tests
916
+ links_selected = self.statistic.links.selected + event.payload.inferred_links
917
+ links_total = self.statistic.links.total + event.payload.inferred_links
909
918
  self.stateful_tests_manager = StatefulProgressManager(
910
919
  console=self.console,
911
920
  title="Stateful",
912
- links_selected=self.statistic.links.selected,
913
- links_total=self.statistic.links.total,
921
+ links_selected=links_selected,
922
+ links_inferred=event.payload.inferred_links,
923
+ links_total=links_total,
914
924
  )
915
925
  self.stateful_tests_manager.start()
916
926
 
@@ -990,10 +1000,10 @@ class OutputHandler(EventHandler):
990
1000
  table.add_column("Field", style=Style(color="bright_white", bold=True))
991
1001
  table.add_column("Value", style="cyan")
992
1002
  table.add_row("Scenarios:", f"{self.stateful_tests_manager.scenarios}")
993
- table.add_row(
994
- "API Links:",
995
- f"{len(self.stateful_tests_manager.links_covered)} covered / {self.stateful_tests_manager.links_selected} selected / {self.stateful_tests_manager.links_total} total",
996
- )
1003
+ message = f"{len(self.stateful_tests_manager.links_covered)} covered / {self.stateful_tests_manager.links_selected} selected / {self.stateful_tests_manager.links_total} total"
1004
+ if self.stateful_tests_manager.links_inferred:
1005
+ message += f" ({self.stateful_tests_manager.links_inferred} inferred)"
1006
+ table.add_row("API Links:", message)
997
1007
 
998
1008
  self.console.print()
999
1009
  self.console.print(Padding(table, BLOCK_PADDING))
@@ -199,6 +199,7 @@ class OperationConfig(DiffBase):
199
199
  continue_on_failure: bool | None
200
200
  tls_verify: bool | str | None
201
201
  rate_limit: Limiter | None
202
+ max_redirects: int | None
202
203
  request_timeout: float | int | None
203
204
  request_cert: str | None
204
205
  request_cert_key: str | None
@@ -218,6 +219,7 @@ class OperationConfig(DiffBase):
218
219
  "tls_verify",
219
220
  "rate_limit",
220
221
  "_rate_limit",
222
+ "max_redirects",
221
223
  "request_timeout",
222
224
  "request_cert",
223
225
  "request_cert_key",
@@ -239,6 +241,7 @@ class OperationConfig(DiffBase):
239
241
  continue_on_failure: bool | None = None,
240
242
  tls_verify: bool | str | None = None,
241
243
  rate_limit: str | None = None,
244
+ max_redirects: int | None = None,
242
245
  request_timeout: float | int | None = None,
243
246
  request_cert: str | None = None,
244
247
  request_cert_key: str | None = None,
@@ -260,6 +263,7 @@ class OperationConfig(DiffBase):
260
263
  else:
261
264
  self.rate_limit = rate_limit
262
265
  self._rate_limit = rate_limit
266
+ self.max_redirects = max_redirects
263
267
  self.request_timeout = request_timeout
264
268
  self.request_cert = request_cert
265
269
  self.request_cert_key = request_cert_key
@@ -308,6 +312,7 @@ class OperationConfig(DiffBase):
308
312
  continue_on_failure=data.get("continue-on-failure", None),
309
313
  tls_verify=resolve(data.get("tls-verify")),
310
314
  rate_limit=resolve(data.get("rate-limit")),
315
+ max_redirects=data.get("max-redirects"),
311
316
  request_timeout=data.get("request-timeout"),
312
317
  request_cert=resolve(data.get("request-cert")),
313
318
  request_cert_key=resolve(data.get("request-cert-key")),
@@ -109,14 +109,40 @@ class CoveragePhaseConfig(DiffBase):
109
109
  )
110
110
 
111
111
 
112
+ @dataclass(repr=False)
113
+ class InferenceConfig(DiffBase):
114
+ algorithms: list[str]
115
+
116
+ __slots__ = ("algorithms",)
117
+
118
+ def __init__(
119
+ self,
120
+ *,
121
+ algorithms: list[str] | None = None,
122
+ ) -> None:
123
+ self.algorithms = algorithms if algorithms is not None else ["location-headers"]
124
+
125
+ @classmethod
126
+ def from_dict(cls, data: dict[str, Any]) -> InferenceConfig:
127
+ return cls(
128
+ algorithms=data.get("algorithms", ["location-headers"]),
129
+ )
130
+
131
+ @property
132
+ def is_enabled(self) -> bool:
133
+ """Inference is enabled if any algorithms are configured."""
134
+ return bool(self.algorithms)
135
+
136
+
112
137
  @dataclass(repr=False)
113
138
  class StatefulPhaseConfig(DiffBase):
114
139
  enabled: bool
115
140
  generation: GenerationConfig
116
141
  checks: ChecksConfig
117
142
  max_steps: int
143
+ inference: InferenceConfig
118
144
 
119
- __slots__ = ("enabled", "generation", "checks", "max_steps")
145
+ __slots__ = ("enabled", "generation", "checks", "max_steps", "inference")
120
146
 
121
147
  def __init__(
122
148
  self,
@@ -125,11 +151,13 @@ class StatefulPhaseConfig(DiffBase):
125
151
  generation: GenerationConfig | None = None,
126
152
  checks: ChecksConfig | None = None,
127
153
  max_steps: int | None = None,
154
+ inference: InferenceConfig | None = None,
128
155
  ) -> None:
129
156
  self.enabled = enabled
130
157
  self.max_steps = max_steps or DEFAULT_STATEFUL_STEP_COUNT
131
158
  self.generation = generation or GenerationConfig()
132
159
  self.checks = checks or ChecksConfig()
160
+ self.inference = inference or InferenceConfig()
133
161
 
134
162
  @classmethod
135
163
  def from_dict(cls, data: dict[str, Any]) -> StatefulPhaseConfig:
@@ -138,6 +166,7 @@ class StatefulPhaseConfig(DiffBase):
138
166
  max_steps=data.get("max-steps"),
139
167
  generation=GenerationConfig.from_dict(data.get("generation", {})),
140
168
  checks=ChecksConfig.from_dict(data.get("checks", {})),
169
+ inference=InferenceConfig.from_dict(data.get("inference", {})),
141
170
  )
142
171
 
143
172
 
@@ -173,11 +202,20 @@ class PhasesConfig(DiffBase):
173
202
 
174
203
  @classmethod
175
204
  def from_dict(cls, data: dict[str, Any]) -> PhasesConfig:
205
+ # Use the outer "enabled" value as default for all phases.
206
+ default_enabled = data.get("enabled", None)
207
+
208
+ def merge(sub: dict[str, Any]) -> dict[str, Any]:
209
+ # Merge the default enabled flag with the sub-dict; the sub-dict takes precedence.
210
+ if default_enabled is not None:
211
+ return {"enabled": default_enabled, **sub}
212
+ return sub
213
+
176
214
  return cls(
177
- examples=ExamplesPhaseConfig.from_dict(data.get("examples", {})),
178
- coverage=CoveragePhaseConfig.from_dict(data.get("coverage", {})),
179
- fuzzing=PhaseConfig.from_dict(data.get("fuzzing", {})),
180
- stateful=StatefulPhaseConfig.from_dict(data.get("stateful", {})),
215
+ examples=ExamplesPhaseConfig.from_dict(merge(data.get("examples", {}))),
216
+ coverage=CoveragePhaseConfig.from_dict(merge(data.get("coverage", {}))),
217
+ fuzzing=PhaseConfig.from_dict(merge(data.get("fuzzing", {}))),
218
+ stateful=StatefulPhaseConfig.from_dict(merge(data.get("stateful", {}))),
181
219
  )
182
220
 
183
221
  def update(self, *, phases: list[str]) -> None:
@@ -54,6 +54,7 @@ class ProjectConfig(DiffBase):
54
54
  continue_on_failure: bool | None
55
55
  tls_verify: bool | str | None
56
56
  rate_limit: Limiter | None
57
+ max_redirects: int | None
57
58
  request_timeout: float | int | None
58
59
  request_cert: str | None
59
60
  request_cert_key: str | None
@@ -76,6 +77,7 @@ class ProjectConfig(DiffBase):
76
77
  "tls_verify",
77
78
  "rate_limit",
78
79
  "_rate_limit",
80
+ "max_redirects",
79
81
  "request_timeout",
80
82
  "request_cert",
81
83
  "request_cert_key",
@@ -100,6 +102,7 @@ class ProjectConfig(DiffBase):
100
102
  continue_on_failure: bool | None = None,
101
103
  tls_verify: bool | str = True,
102
104
  rate_limit: str | None = None,
105
+ max_redirects: int | None = None,
103
106
  request_timeout: float | int | None = None,
104
107
  request_cert: str | None = None,
105
108
  request_cert_key: str | None = None,
@@ -133,6 +136,7 @@ class ProjectConfig(DiffBase):
133
136
  else:
134
137
  self.rate_limit = rate_limit
135
138
  self._rate_limit = rate_limit
139
+ self.max_redirects = max_redirects
136
140
  self.request_timeout = request_timeout
137
141
  self.request_cert = request_cert
138
142
  self.request_cert_key = request_cert_key
@@ -157,6 +161,7 @@ class ProjectConfig(DiffBase):
157
161
  continue_on_failure=data.get("continue-on-failure", None),
158
162
  tls_verify=resolve(data.get("tls-verify", True)),
159
163
  rate_limit=resolve(data.get("rate-limit")),
164
+ max_redirects=data.get("max-redirects"),
160
165
  request_timeout=data.get("request-timeout"),
161
166
  request_cert=resolve(data.get("request-cert")),
162
167
  request_cert_key=resolve(data.get("request-cert-key")),
@@ -188,6 +193,7 @@ class ProjectConfig(DiffBase):
188
193
  workers: int | Literal["auto"] | None = None,
189
194
  continue_on_failure: bool | None = None,
190
195
  rate_limit: str | None = None,
196
+ max_redirects: int | None = None,
191
197
  request_timeout: float | int | None = None,
192
198
  tls_verify: bool | str | None = None,
193
199
  request_cert: str | None = None,
@@ -221,6 +227,9 @@ class ProjectConfig(DiffBase):
221
227
  if rate_limit is not None:
222
228
  self.rate_limit = build_limiter(rate_limit)
223
229
 
230
+ if max_redirects is not None:
231
+ self.max_redirects = max_redirects
232
+
224
233
  if request_timeout is not None:
225
234
  self.request_timeout = request_timeout
226
235
 
@@ -264,6 +273,15 @@ class ProjectConfig(DiffBase):
264
273
  headers.update(config.headers)
265
274
  return headers
266
275
 
276
+ def max_redirects_for(self, *, operation: APIOperation | None = None) -> int | None:
277
+ if operation is not None:
278
+ config = self.operations.get_for_operation(operation=operation)
279
+ if config.max_redirects is not None:
280
+ return config.max_redirects
281
+ if self.max_redirects is not None:
282
+ return self.max_redirects
283
+ return None
284
+
267
285
  def request_timeout_for(self, *, operation: APIOperation | None = None) -> float | int | None:
268
286
  if operation is not None:
269
287
  config = self.operations.get_for_operation(operation=operation)
@@ -119,6 +119,10 @@
119
119
  "rate-limit": {
120
120
  "type": "string"
121
121
  },
122
+ "max-redirects": {
123
+ "type": "integer",
124
+ "minimum": 0
125
+ },
122
126
  "request-timeout": {
123
127
  "type": "number",
124
128
  "minimum": 0
@@ -251,6 +255,9 @@
251
255
  "type": "object",
252
256
  "additionalProperties": false,
253
257
  "properties": {
258
+ "enabled": {
259
+ "type": "boolean"
260
+ },
254
261
  "examples": {
255
262
  "$ref": "#/$defs/ExamplesPhaseConfig"
256
263
  },
@@ -309,6 +316,22 @@
309
316
  "type": "integer",
310
317
  "minimum": 2
311
318
  },
319
+ "inference": {
320
+ "type": "object",
321
+ "additionalProperties": false,
322
+ "properties": {
323
+ "algorithms": {
324
+ "type": "array",
325
+ "items": {
326
+ "type": "string",
327
+ "enum": [
328
+ "location-headers"
329
+ ]
330
+ },
331
+ "uniqueItems": true
332
+ }
333
+ }
334
+ },
312
335
  "generation": {
313
336
  "$ref": "#/$defs/GenerationConfig"
314
337
  },
@@ -523,6 +546,10 @@
523
546
  "rate-limit": {
524
547
  "type": "string"
525
548
  },
549
+ "max-redirects": {
550
+ "type": "integer",
551
+ "minimum": 0
552
+ },
526
553
  "request-timeout": {
527
554
  "type": "number",
528
555
  "minimum": 0
@@ -590,6 +617,10 @@
590
617
  "rate-limit": {
591
618
  "type": "string"
592
619
  },
620
+ "max-redirects": {
621
+ "type": "integer",
622
+ "minimum": 0
623
+ },
593
624
  "request-timeout": {
594
625
  "type": "number",
595
626
  "minimum": 0
@@ -6,16 +6,18 @@ from typing import TYPE_CHECKING, Any
6
6
 
7
7
  from schemathesis.config import ProjectConfig
8
8
  from schemathesis.core import NOT_SET, NotSet
9
+ from schemathesis.engine.control import ExecutionControl
10
+ from schemathesis.engine.observations import Observations
9
11
  from schemathesis.generation.case import Case
10
12
  from schemathesis.schemas import APIOperation, BaseSchema
11
13
 
12
- from .control import ExecutionControl
13
-
14
14
  if TYPE_CHECKING:
15
15
  import threading
16
16
 
17
17
  import requests
18
18
 
19
+ from schemathesis.engine.recorder import ScenarioRecorder
20
+
19
21
 
20
22
  @dataclass
21
23
  class EngineContext:
@@ -25,20 +27,31 @@ class EngineContext:
25
27
  control: ExecutionControl
26
28
  outcome_cache: dict[int, BaseException | None]
27
29
  start_time: float
28
-
29
- __slots__ = ("schema", "control", "outcome_cache", "start_time", "_session", "_transport_kwargs_cache")
30
+ observations: Observations | None
31
+
32
+ __slots__ = (
33
+ "schema",
34
+ "control",
35
+ "outcome_cache",
36
+ "start_time",
37
+ "observations",
38
+ "_session",
39
+ "_transport_kwargs_cache",
40
+ )
30
41
 
31
42
  def __init__(
32
43
  self,
33
44
  *,
34
45
  schema: BaseSchema,
35
46
  stop_event: threading.Event,
47
+ observations: Observations | None = None,
36
48
  session: requests.Session | None = None,
37
49
  ) -> None:
38
50
  self.schema = schema
39
51
  self.control = ExecutionControl(stop_event=stop_event, max_failures=schema.config.max_failures)
40
52
  self.outcome_cache = {}
41
53
  self.start_time = time.monotonic()
54
+ self.observations = observations
42
55
  self._session = session
43
56
  self._transport_kwargs_cache: dict[str | None, dict[str, Any]] = {}
44
57
 
@@ -65,6 +78,27 @@ class EngineContext:
65
78
  def has_reached_the_failure_limit(self) -> bool:
66
79
  return self.control.has_reached_the_failure_limit
67
80
 
81
+ def record_observations(self, recorder: ScenarioRecorder) -> None:
82
+ """Add new observations from a scenario."""
83
+ if self.observations is not None:
84
+ self.observations.extract_observations_from(recorder)
85
+
86
+ def inject_links(self) -> int:
87
+ """Inject inferred OpenAPI links into API operations based on collected observations."""
88
+ injected = 0
89
+ if self.observations is not None and self.observations.location_headers:
90
+ from schemathesis.specs.openapi.schemas import BaseOpenAPISchema
91
+ from schemathesis.specs.openapi.stateful.inference import LinkInferencer
92
+
93
+ assert isinstance(self.schema, BaseOpenAPISchema)
94
+
95
+ # Generate links from collected Location headers
96
+ inferencer = LinkInferencer.from_schema(self.schema)
97
+ for operation, entries in self.observations.location_headers.items():
98
+ injected += inferencer.inject_links(operation.definition.raw, entries)
99
+
100
+ return injected
101
+
68
102
  def stop(self) -> None:
69
103
  self.control.stop()
70
104
 
@@ -107,6 +141,7 @@ class EngineContext:
107
141
  kwargs: dict[str, Any] = {
108
142
  "session": self.get_session(operation=operation),
109
143
  "headers": config.headers_for(operation=operation),
144
+ "max_redirects": config.max_redirects_for(operation=operation),
110
145
  "timeout": config.request_timeout_for(operation=operation),
111
146
  "verify": config.tls_verify_for(operation=operation),
112
147
  "cert": config.request_cert_for(operation=operation),
@@ -2,15 +2,15 @@ from __future__ import annotations
2
2
 
3
3
  import threading
4
4
  from dataclasses import dataclass
5
- from typing import Sequence
6
5
 
7
- from schemathesis.auths import unregister as unregister_auth
6
+ from schemathesis import auths
8
7
  from schemathesis.core import SpecificationFeature
9
8
  from schemathesis.engine import Status, events, phases
9
+ from schemathesis.engine.observations import Observations
10
10
  from schemathesis.schemas import BaseSchema
11
11
 
12
12
  from .context import EngineContext
13
- from .events import EventGenerator
13
+ from .events import EventGenerator, StatefulPhasePayload
14
14
  from .phases import Phase, PhaseName, PhaseSkipReason
15
15
 
16
16
 
@@ -24,10 +24,20 @@ class Engine:
24
24
  """Execute all test phases."""
25
25
  # Unregister auth if explicitly provided
26
26
  if self.schema.config.auth.is_defined:
27
- unregister_auth()
27
+ auths.unregister()
28
28
 
29
- ctx = EngineContext(schema=self.schema, stop_event=threading.Event())
30
29
  plan = self._create_execution_plan()
30
+
31
+ observations = None
32
+ for phase in plan.phases:
33
+ if (
34
+ phase.name == PhaseName.STATEFUL_TESTING
35
+ and phase.skip_reason in (None, PhaseSkipReason.NOT_APPLICABLE)
36
+ and self.schema.config.phases.stateful.inference.is_enabled
37
+ ):
38
+ observations = Observations()
39
+
40
+ ctx = EngineContext(schema=self.schema, stop_event=threading.Event(), observations=observations)
31
41
  return EventStream(plan.execute(ctx), ctx.control.stop_event)
32
42
 
33
43
  def _create_execution_plan(self) -> ExecutionPlan:
@@ -103,7 +113,7 @@ class Engine:
103
113
  class ExecutionPlan:
104
114
  """Manages test execution phases."""
105
115
 
106
- phases: Sequence[Phase]
116
+ phases: list[Phase]
107
117
 
108
118
  __slots__ = ("phases",)
109
119
 
@@ -120,9 +130,8 @@ class ExecutionPlan:
120
130
 
121
131
  # Run main phases
122
132
  for phase in self.phases:
123
- if engine.has_reached_the_failure_limit:
124
- phase.skip_reason = PhaseSkipReason.FAILURE_LIMIT_REACHED
125
- yield events.PhaseStarted(phase=phase)
133
+ payload = self._adapt_execution(engine, phase)
134
+ yield events.PhaseStarted(phase=phase, payload=payload)
126
135
  if phase.should_execute(engine):
127
136
  yield from phases.execute(engine, phase)
128
137
  else:
@@ -143,6 +152,18 @@ class ExecutionPlan:
143
152
  """Finish the test run."""
144
153
  yield events.EngineFinished(running_time=ctx.running_time)
145
154
 
155
+ def _adapt_execution(self, engine: EngineContext, phase: Phase) -> StatefulPhasePayload | None:
156
+ if engine.has_reached_the_failure_limit:
157
+ phase.skip_reason = PhaseSkipReason.FAILURE_LIMIT_REACHED
158
+ # Phase can be enabled if certain conditions are met
159
+ if phase.name == PhaseName.STATEFUL_TESTING:
160
+ inferred = engine.inject_links()
161
+ # Enable stateful testing if we successfully generated any links
162
+ if inferred:
163
+ phase.enable()
164
+ return StatefulPhasePayload(inferred_links=inferred)
165
+ return None
166
+
146
167
 
147
168
  @dataclass
148
169
  class EventStream:
@@ -45,16 +45,26 @@ class PhaseEvent(EngineEvent):
45
45
  phase: Phase
46
46
 
47
47
 
48
+ @dataclass
49
+ class StatefulPhasePayload:
50
+ inferred_links: int
51
+
52
+ __slots__ = ("inferred_links",)
53
+
54
+
48
55
  @dataclass
49
56
  class PhaseStarted(PhaseEvent):
50
57
  """Start of an execution phase."""
51
58
 
52
- __slots__ = ("id", "timestamp", "phase")
59
+ payload: StatefulPhasePayload | None
60
+
61
+ __slots__ = ("id", "timestamp", "phase", "payload")
53
62
 
54
- def __init__(self, *, phase: Phase) -> None:
63
+ def __init__(self, *, phase: Phase, payload: StatefulPhasePayload | None) -> None:
55
64
  self.id = uuid.uuid4()
56
65
  self.timestamp = time.time()
57
66
  self.phase = phase
67
+ self.payload = payload
58
68
 
59
69
 
60
70
  @dataclass
@@ -0,0 +1,42 @@
1
+ from dataclasses import dataclass
2
+
3
+ from schemathesis.engine.recorder import ScenarioRecorder
4
+ from schemathesis.schemas import APIOperation
5
+
6
+
7
+ @dataclass
8
+ class LocationHeaderEntry:
9
+ """Value of `Location` coming from API response with a given status code."""
10
+
11
+ status_code: int
12
+ value: str
13
+
14
+ __slots__ = ("status_code", "value")
15
+
16
+
17
+ @dataclass
18
+ class Observations:
19
+ """Repository for observations collected during test execution."""
20
+
21
+ location_headers: dict[APIOperation, list[LocationHeaderEntry]]
22
+
23
+ __slots__ = ("location_headers",)
24
+
25
+ def __init__(self) -> None:
26
+ self.location_headers = {}
27
+
28
+ def extract_observations_from(self, recorder: ScenarioRecorder) -> None:
29
+ """Extract observations from completed test scenario."""
30
+ for id, interaction in recorder.interactions.items():
31
+ response = interaction.response
32
+ if response is not None:
33
+ location = response.headers.get("location")
34
+ if location:
35
+ # Group location headers by the operation that produced them
36
+ entries = self.location_headers.setdefault(recorder.cases[id].value.operation, [])
37
+ entries.append(
38
+ LocationHeaderEntry(
39
+ status_code=response.status_code,
40
+ value=location[0],
41
+ )
42
+ )
@@ -77,6 +77,11 @@ class Phase:
77
77
  """Determine if phase should run based on context & configuration."""
78
78
  return self.is_enabled and not ctx.has_to_stop
79
79
 
80
+ def enable(self) -> None:
81
+ """Enable this test phase."""
82
+ self.is_enabled = True
83
+ self.skip_reason = None
84
+
80
85
 
81
86
  def execute(ctx: EngineContext, phase: Phase) -> EventGenerator:
82
87
  from urllib3.exceptions import InsecureRequestWarning