schemathesis 3.33.2__py3-none-any.whl → 3.34.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (33)
  1. schemathesis/auths.py +71 -13
  2. schemathesis/checks.py +2 -0
  3. schemathesis/cli/__init__.py +10 -0
  4. schemathesis/cli/callbacks.py +3 -6
  5. schemathesis/cli/junitxml.py +20 -17
  6. schemathesis/cli/sanitization.py +5 -0
  7. schemathesis/exceptions.py +8 -0
  8. schemathesis/failures.py +20 -0
  9. schemathesis/generation/__init__.py +2 -0
  10. schemathesis/hooks.py +81 -8
  11. schemathesis/internal/transformation.py +10 -0
  12. schemathesis/models.py +12 -26
  13. schemathesis/runner/events.py +1 -0
  14. schemathesis/runner/impl/core.py +13 -1
  15. schemathesis/sanitization.py +1 -0
  16. schemathesis/schemas.py +12 -2
  17. schemathesis/service/serialization.py +1 -0
  18. schemathesis/specs/graphql/schemas.py +4 -0
  19. schemathesis/specs/openapi/checks.py +249 -12
  20. schemathesis/specs/openapi/examples.py +18 -1
  21. schemathesis/specs/openapi/links.py +45 -14
  22. schemathesis/specs/openapi/schemas.py +33 -17
  23. schemathesis/specs/openapi/stateful/__init__.py +18 -7
  24. schemathesis/stateful/__init__.py +20 -16
  25. schemathesis/stateful/config.py +16 -4
  26. schemathesis/stateful/runner.py +1 -1
  27. schemathesis/stateful/state_machine.py +20 -1
  28. schemathesis/transports/__init__.py +9 -1
  29. {schemathesis-3.33.2.dist-info → schemathesis-3.34.0.dist-info}/METADATA +11 -3
  30. {schemathesis-3.33.2.dist-info → schemathesis-3.34.0.dist-info}/RECORD +33 -33
  31. {schemathesis-3.33.2.dist-info → schemathesis-3.34.0.dist-info}/WHEEL +0 -0
  32. {schemathesis-3.33.2.dist-info → schemathesis-3.34.0.dist-info}/entry_points.txt +0 -0
  33. {schemathesis-3.33.2.dist-info → schemathesis-3.34.0.dist-info}/licenses/LICENSE +0 -0
schemathesis/auths.py CHANGED
@@ -14,6 +14,7 @@ from typing import (
     Generic,
     Protocol,
     TypeVar,
+    Union,
     overload,
     runtime_checkable,
 )
@@ -44,6 +45,9 @@ class AuthContext:
     app: Any | None


+CacheKeyFunction = Callable[["Case", "AuthContext"], Union[str, int]]
+
+
 @runtime_checkable
 class AuthProvider(Generic[Auth], Protocol):
     """Get authentication data for an API and set it on the generated test cases."""
@@ -99,16 +103,24 @@ class CachingAuthProvider(Generic[Auth]):

     def get(self, case: Case, context: AuthContext) -> Auth | None:
         """Get cached auth value."""
-        if self.cache_entry is None or self.timer() >= self.cache_entry.expires:
+        cache_entry = self._get_cache_entry(case, context)
+        if cache_entry is None or self.timer() >= cache_entry.expires:
             with self._refresh_lock:
-                if not (self.cache_entry is None or self.timer() >= self.cache_entry.expires):
+                cache_entry = self._get_cache_entry(case, context)
+                if not (cache_entry is None or self.timer() >= cache_entry.expires):
                     # Another thread updated the cache
-                    return self.cache_entry.data
+                    return cache_entry.data
                 # We know that optional auth is possible only inside a higher-level wrapper
                 data: Auth = _provider_get(self.provider, case, context)  # type: ignore[assignment]
-                self.cache_entry = CacheEntry(data=data, expires=self.timer() + self.refresh_interval)
+                self._set_cache_entry(data, case, context)
                 return data
-        return self.cache_entry.data
+        return cache_entry.data
+
+    def _get_cache_entry(self, case: Case, context: AuthContext) -> CacheEntry[Auth] | None:
+        return self.cache_entry
+
+    def _set_cache_entry(self, data: Auth, case: Case, context: AuthContext) -> None:
+        self.cache_entry = CacheEntry(data=data, expires=self.timer() + self.refresh_interval)

     def set(self, case: Case, data: Auth, context: AuthContext) -> None:
         """Set auth data on the `Case` instance.
@@ -118,6 +130,25 @@ class CachingAuthProvider(Generic[Auth]):
         self.provider.set(case, data, context)


+def _noop_key_function(case: Case, context: AuthContext) -> str:
+    # Never used
+    raise NotImplementedError
+
+
+@dataclass
+class KeyedCachingAuthProvider(CachingAuthProvider[Auth]):
+    cache_by_key: CacheKeyFunction = _noop_key_function
+    cache_entries: dict[str | int, CacheEntry[Auth] | None] = field(default_factory=dict)
+
+    def _get_cache_entry(self, case: Case, context: AuthContext) -> CacheEntry[Auth] | None:
+        key = self.cache_by_key(case, context)
+        return self.cache_entries.get(key)
+
+    def _set_cache_entry(self, data: Auth, case: Case, context: AuthContext) -> None:
+        key = self.cache_by_key(case, context)
+        self.cache_entries[key] = CacheEntry(data=data, expires=self.timer() + self.refresh_interval)
+
+
 class FilterableRegisterAuth(Protocol):
     """Protocol that adds filters to the return value of `register`."""

@@ -246,6 +277,7 @@ class AuthStorage(Generic[Auth]):
         self,
         *,
         refresh_interval: int | None = DEFAULT_REFRESH_INTERVAL,
+        cache_by_key: CacheKeyFunction | None = None,
     ) -> FilterableRegisterAuth:
         pass

@@ -255,6 +287,7 @@ class AuthStorage(Generic[Auth]):
         provider_class: type[AuthProvider],
         *,
         refresh_interval: int | None = DEFAULT_REFRESH_INTERVAL,
+        cache_by_key: CacheKeyFunction | None = None,
     ) -> FilterableApplyAuth:
         pass

@@ -263,10 +296,11 @@ class AuthStorage(Generic[Auth]):
         provider_class: type[AuthProvider] | None = None,
         *,
         refresh_interval: int | None = DEFAULT_REFRESH_INTERVAL,
+        cache_by_key: CacheKeyFunction | None = None,
     ) -> FilterableRegisterAuth | FilterableApplyAuth:
         if provider_class is not None:
-            return self.apply(provider_class, refresh_interval=refresh_interval)
-        return self.register(refresh_interval=refresh_interval)
+            return self.apply(provider_class, refresh_interval=refresh_interval, cache_by_key=cache_by_key)
+        return self.register(refresh_interval=refresh_interval, cache_by_key=cache_by_key)

     def set_from_requests(self, auth: requests.auth.AuthBase) -> FilterableRequestsAuth:
         """Use `requests` auth instance as an auth provider."""
@@ -286,6 +320,7 @@ class AuthStorage(Generic[Auth]):
         *,
         provider_class: type[AuthProvider],
         refresh_interval: int | None = DEFAULT_REFRESH_INTERVAL,
+        cache_by_key: CacheKeyFunction | None = None,
         filter_set: FilterSet,
     ) -> None:
         if not issubclass(provider_class, AuthProvider):
@@ -295,16 +330,27 @@ class AuthStorage(Generic[Auth]):
             )
         provider: AuthProvider
         # Apply caching if desired
+        instance = provider_class()
         if refresh_interval is not None:
-            provider = CachingAuthProvider(provider_class(), refresh_interval=refresh_interval)
+            if cache_by_key is None:
+                provider = CachingAuthProvider(instance, refresh_interval=refresh_interval)
+            else:
+                provider = KeyedCachingAuthProvider(
+                    instance, refresh_interval=refresh_interval, cache_by_key=cache_by_key
+                )
         else:
-            provider = provider_class()
+            provider = instance
         # Store filters if any
         if not filter_set.is_empty():
             provider = SelectiveAuthProvider(provider, filter_set)
         self.providers.append(provider)

-    def register(self, *, refresh_interval: int | None = DEFAULT_REFRESH_INTERVAL) -> FilterableRegisterAuth:
+    def register(
+        self,
+        *,
+        refresh_interval: int | None = DEFAULT_REFRESH_INTERVAL,
+        cache_by_key: CacheKeyFunction | None = None,
+    ) -> FilterableRegisterAuth:
         """Register a new auth provider.

         .. code-block:: python
@@ -326,7 +372,12 @@ class AuthStorage(Generic[Auth]):
         filter_set = FilterSet()

         def wrapper(provider_class: type[AuthProvider]) -> type[AuthProvider]:
-            self._set_provider(provider_class=provider_class, refresh_interval=refresh_interval, filter_set=filter_set)
+            self._set_provider(
+                provider_class=provider_class,
+                refresh_interval=refresh_interval,
+                filter_set=filter_set,
+                cache_by_key=cache_by_key,
+            )
             return provider_class

         attach_filter_chain(wrapper, "apply_to", filter_set.include)
@@ -342,7 +393,11 @@ class AuthStorage(Generic[Auth]):
         self.providers = []

     def apply(
-        self, provider_class: type[AuthProvider], *, refresh_interval: int | None = DEFAULT_REFRESH_INTERVAL
+        self,
+        provider_class: type[AuthProvider],
+        *,
+        refresh_interval: int | None = DEFAULT_REFRESH_INTERVAL,
+        cache_by_key: CacheKeyFunction | None = None,
     ) -> FilterableApplyAuth:
         """Register auth provider only on one test function.

@@ -366,7 +421,10 @@ class AuthStorage(Generic[Auth]):
         def wrapper(test: GenericTest) -> GenericTest:
             auth_storage = self.add_auth_storage(test)
             auth_storage._set_provider(
-                provider_class=provider_class, refresh_interval=refresh_interval, filter_set=filter_set
+                provider_class=provider_class,
+                refresh_interval=refresh_interval,
+                filter_set=filter_set,
+                cache_by_key=cache_by_key,
             )
             return test

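
In practice, the new `cache_by_key` argument lets a cached auth provider keep one token per key instead of a single shared token. A minimal sketch of how this could be used; the `tenant_id` path parameter and the `fetch_token_for` helper are illustrative placeholders, not part of this diff:

import schemathesis
from schemathesis import Case
from schemathesis.auths import AuthContext


def fetch_token_for(tenant: str) -> str:
    # Placeholder for a real token request against your auth endpoint
    return f"token-for-{tenant}"


def by_tenant(case: Case, context: AuthContext) -> str:
    # Key the cache on a hypothetical `tenant_id` path parameter,
    # so every tenant gets its own cached token
    return (case.path_parameters or {}).get("tenant_id", "default")


@schemathesis.auth(cache_by_key=by_tenant)
class TokenAuth:
    def get(self, case: Case, context: AuthContext) -> str:
        return fetch_token_for(by_tenant(case, context))

    def set(self, case: Case, data: str, context: AuthContext) -> None:
        case.headers = case.headers or {}
        case.headers["Authorization"] = f"Bearer {data}"

With a key function supplied, `KeyedCachingAuthProvider` stores a separate `CacheEntry` per key and refreshes each one independently once `refresh_interval` elapses.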
schemathesis/checks.py CHANGED
@@ -7,6 +7,7 @@ from . import failures
 from .exceptions import get_response_parsing_error, get_server_error
 from .specs.openapi.checks import (
     content_type_conformance,
+    ignored_auth,
     negative_data_rejection,
     response_headers_conformance,
     response_schema_conformance,
@@ -50,6 +51,7 @@ OPTIONAL_CHECKS = (
     response_headers_conformance,
     response_schema_conformance,
     negative_data_rejection,
+    ignored_auth,
 )
 ALL_CHECKS: tuple[CheckFunction, ...] = DEFAULT_CHECKS + OPTIONAL_CHECKS

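
`ignored_auth` joins the optional checks, so it does not run by default. In Python tests it can be added on top of the defaults roughly like this (the schema URL is a placeholder):

import schemathesis
from schemathesis.checks import DEFAULT_CHECKS
from schemathesis.specs.openapi.checks import ignored_auth

schema = schemathesis.from_uri("http://localhost:8000/openapi.json")


@schema.parametrize()
def test_api(case):
    # Default checks plus the new optional check that verifies the API
    # actually rejects requests lacking the declared authentication
    case.call_and_validate(checks=DEFAULT_CHECKS + (ignored_auth,))

On the CLI it can be enabled by name like the other optional checks.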
schemathesis/cli/__init__.py CHANGED
@@ -795,6 +795,14 @@ The report data, consisting of a tar gz file with multiple JSON files, is subjec
     show_default=True,
     callback=callbacks.convert_boolean_string,
 )
+@click.option(
+    "--generation-graphql-allow-null",
+    help="Whether `null` values should be used for optional arguments in GraphQL queries.",
+    type=str,
+    default="true",
+    show_default=True,
+    callback=callbacks.convert_boolean_string,
+)
 @click.option(
     "--schemathesis-io-token",
     help="Schemathesis.io authentication token.",
@@ -911,6 +919,7 @@ def run(
     no_color: bool = False,
     report_value: str | None = None,
     generation_allow_x00: bool = True,
+    generation_graphql_allow_null: bool = True,
     generation_with_security_parameters: bool = True,
     generation_codec: str = "utf-8",
     schemathesis_io_token: str | None = None,
@@ -953,6 +962,7 @@

     generation_config = generation.GenerationConfig(
         allow_x00=generation_allow_x00,
+        graphql_allow_null=generation_graphql_allow_null,
         codec=generation_codec,
         with_security_parameters=generation_with_security_parameters,
     )
schemathesis/cli/callbacks.py CHANGED
@@ -15,9 +15,10 @@ from click.types import LazyFile  # type: ignore

 from .. import exceptions, experimental, throttling
 from ..code_samples import CodeSampleStyle
-from ..constants import FALSE_VALUES, TRUE_VALUES
+from ..constants import TRUE_VALUES
 from ..exceptions import extract_nth_traceback
 from ..generation import DataGenerationMethod
+from ..internal.transformation import convert_boolean_string as _convert_boolean_string
 from ..internal.validation import file_exists, is_filename, is_illegal_surrogate
 from ..loaders import load_app
 from ..service.hosts import get_temporary_hosts_file
@@ -378,11 +379,7 @@ def convert_hosts_file(ctx: click.core.Context, param: click.core.Parameter, val


 def convert_boolean_string(ctx: click.core.Context, param: click.core.Parameter, value: str) -> str | bool:
-    if value.lower() in TRUE_VALUES:
-        return True
-    if value.lower() in FALSE_VALUES:
-        return False
-    return value
+    return _convert_boolean_string(value)


 def convert_report(ctx: click.core.Context, param: click.core.Option, value: LazyFile) -> LazyFile:
schemathesis/cli/junitxml.py CHANGED
@@ -27,28 +27,31 @@ class JunitXMLHandler(EventHandler):
     test_cases: list = field(default_factory=list)

     def handle_event(self, context: ExecutionContext, event: events.ExecutionEvent) -> None:
-        if isinstance(event, events.AfterExecution):
-            test_case = TestCase(
-                f"{event.result.method} {event.result.path}",
-                elapsed_sec=event.elapsed_time,
-                allow_multiple_subelements=True,
-            )
-            if event.status == Status.failure:
-                for idx, (code_sample, group) in enumerate(
-                    group_by_case(event.result.checks, context.code_sample_style), 1
-                ):
-                    checks = sorted(group, key=lambda c: c.name != "not_a_server_error")
-                    test_case.add_failure_info(message=build_failure_message(context, idx, code_sample, checks))
-            elif event.status == Status.error:
-                test_case.add_error_info(message=build_error_message(context, event.result.errors[-1]))
-            elif event.status == Status.skip:
-                test_case.add_skipped_info(message=event.result.skip_reason)
+        if isinstance(event, (events.AfterExecution, events.AfterStatefulExecution)):
+            event_: events.AfterExecution | events.AfterStatefulExecution = event
+            if isinstance(event_, events.AfterExecution):
+                name = f"{event_.result.method} {event_.result.path}"
+            else:
+                name = event_.result.verbose_name
+            test_case = TestCase(name, elapsed_sec=event_.elapsed_time, allow_multiple_subelements=True)
+            if event_.status == Status.failure:
+                _add_failure(test_case, event_.result.checks, context)
+            elif event_.status == Status.error:
+                test_case.add_error_info(message=build_error_message(context, event_.result.errors[-1]))
+            elif event_.status == Status.skip:
+                test_case.add_skipped_info(message=event_.result.skip_reason)
             self.test_cases.append(test_case)
-        if isinstance(event, events.Finished):
+        elif isinstance(event, events.Finished):
             test_suites = [TestSuite("schemathesis", test_cases=self.test_cases, hostname=platform.node())]
             to_xml_report_file(file_descriptor=self.file_handle, test_suites=test_suites, prettyprint=True)


+def _add_failure(test_case: TestCase, checks: list[SerializedCheck], context: ExecutionContext) -> None:
+    for idx, (code_sample, group) in enumerate(group_by_case(checks, context.code_sample_style), 1):
+        checks = sorted(group, key=lambda c: c.name != "not_a_server_error")
+        test_case.add_failure_info(message=build_failure_message(context, idx, code_sample, checks))
+
+
 def build_failure_message(context: ExecutionContext, idx: int, code_sample: str, checks: list[SerializedCheck]) -> str:
     from ..transports.responses import get_reason

schemathesis/cli/sanitization.py CHANGED
@@ -19,3 +19,8 @@ class SanitizationHandler(EventHandler):
                 sanitize_serialized_check(check)
             for interaction in event.result.interactions:
                 sanitize_serialized_interaction(interaction)
+        elif isinstance(event, events.AfterStatefulExecution):
+            for check in event.result.checks:
+                sanitize_serialized_check(check)
+            for interaction in event.result.interactions:
+                sanitize_serialized_interaction(interaction)
schemathesis/exceptions.py CHANGED
@@ -151,6 +151,14 @@ def get_use_after_free_error(free: str) -> type[CheckFailed]:
     return _get_hashed_exception("UseAfterFreeError", free)


+def get_ensure_resource_availability_error(operation: str) -> type[CheckFailed]:
+    return _get_hashed_exception("EnsureResourceAvailabilityError", operation)
+
+
+def get_ignored_auth_error(operation: str) -> type[CheckFailed]:
+    return _get_hashed_exception("IgnoredAuthError", operation)
+
+
 def get_timeout_error(prefix: str, deadline: float | int) -> type[CheckFailed]:
     """Request took too long."""
     return _get_hashed_exception(f"TimeoutError{prefix}", str(deadline))
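
Both new helpers follow the existing `_get_hashed_exception` pattern: each operation label yields its own `CheckFailed` subclass, so identical failures can be deduplicated per operation. A small illustrative sketch (the operation label is a placeholder):

from schemathesis.exceptions import get_ignored_auth_error

# A distinct CheckFailed subclass, named after a hash of the operation label
exc_class = get_ignored_auth_error("GET /users")
assert issubclass(exc_class, Exception)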
schemathesis/failures.py CHANGED
@@ -151,6 +151,26 @@ class UseAfterFree(FailureContext):
     type: str = "use_after_free"


+@dataclass(repr=False)
+class EnsureResourceAvailability(FailureContext):
+    """Resource is not available immediately after creation."""
+
+    message: str
+    created_with: str
+    not_available_with: str
+    title: str = "Resource is not available after creation"
+    type: str = "ensure_resource_availability"
+
+
+@dataclass(repr=False)
+class IgnoredAuth(FailureContext):
+    """The API operation does not check the specified authentication."""
+
+    message: str
+    title: str = "Authentication declared but not enforced for this operation"
+    type: str = "ignored_auth"
+
+
 @dataclass(repr=False)
 class UndefinedStatusCode(FailureContext):
     """Response has a status code that is not defined in the schema."""
schemathesis/generation/__init__.py CHANGED
@@ -75,6 +75,8 @@ class GenerationConfig:

     # Allow generating `\x00` bytes in strings
     allow_x00: bool = True
+    # Allowing using `null` for optional arguments in GraphQL queries
+    graphql_allow_null: bool = True
     # Generate strings using the given codec
     codec: str | None = "utf-8"
     # Whether to generate security parameters
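
The same toggle backs the new `--generation-graphql-allow-null` CLI option. A sketch of disabling it from Python, assuming the GraphQL loader accepts `generation_config` the same way the OpenAPI loaders do (the URL is a placeholder):

import schemathesis
from schemathesis.generation import GenerationConfig

# Never use explicit `null` for optional GraphQL arguments during generation
config = GenerationConfig(graphql_allow_null=False)

schema = schemathesis.graphql.from_url(
    "http://localhost:8000/graphql",
    generation_config=config,  # assumption: mirrors the OpenAPI loaders' keyword
)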
schemathesis/hooks.py CHANGED
@@ -8,6 +8,7 @@ from enum import Enum, unique
 from functools import partial
 from typing import TYPE_CHECKING, Any, Callable, ClassVar, DefaultDict, cast

+from .filters import FilterSet, attach_filter_chain
 from .internal.deprecation import deprecated_property
 from .types import GenericTest

@@ -47,6 +48,58 @@ class HookContext:
         return self.operation


+def to_filterable_hook(dispatcher: HookDispatcher) -> Callable:
+    filter_used = False
+    filter_set = FilterSet()
+
+    def register(hook: str | Callable) -> Callable:
+        nonlocal filter_set
+
+        if filter_used:
+            validate_filterable_hook(hook)
+
+        if isinstance(hook, str):
+
+            def decorator(func: Callable) -> Callable:
+                hook_name = cast(str, hook)
+                if filter_used:
+                    validate_filterable_hook(hook)
+                func.filter_set = filter_set  # type: ignore[attr-defined]
+                return dispatcher.register_hook_with_name(func, hook_name)
+
+            init_filter_set(decorator)
+            return decorator
+
+        hook.filter_set = filter_set  # type: ignore[attr-defined]
+        init_filter_set(register)
+        return dispatcher.register_hook_with_name(hook, hook.__name__)
+
+    def init_filter_set(target: Callable) -> FilterSet:
+        nonlocal filter_used
+
+        filter_used = False
+        filter_set = FilterSet()
+
+        def include(*args: Any, **kwargs: Any) -> None:
+            nonlocal filter_used
+
+            filter_used = True
+            filter_set.include(*args, **kwargs)
+
+        def exclude(*args: Any, **kwargs: Any) -> None:
+            nonlocal filter_used
+
+            filter_used = True
+            filter_set.exclude(*args, **kwargs)
+
+        attach_filter_chain(target, "apply_to", include)
+        attach_filter_chain(target, "skip_for", exclude)
+        return filter_set
+
+    filter_set = init_filter_set(register)
+    return register
+
+
 @dataclass
 class HookDispatcher:
     """Generic hook dispatcher.
@@ -58,6 +111,9 @@ class HookDispatcher:
     _hooks: DefaultDict[str, list[Callable]] = field(default_factory=lambda: defaultdict(list))
     _specs: ClassVar[dict[str, RegisteredHook]] = {}

+    def __post_init__(self) -> None:
+        self.register = to_filterable_hook(self)  # type: ignore[method-assign]
+
     def register(self, hook: str | Callable) -> Callable:
         """Register a new hook.

@@ -80,14 +136,7 @@ class HookDispatcher:
             def hook(context, strategy):
                 ...
         """
-        if isinstance(hook, str):
-
-            def decorator(func: Callable) -> Callable:
-                hook_name = cast(str, hook)
-                return self.register_hook_with_name(func, hook_name)
-
-            return decorator
-        return self.register_hook_with_name(hook, hook.__name__)
+        raise NotImplementedError

     def merge(self, other: HookDispatcher) -> HookDispatcher:
         """Merge two dispatches together.
@@ -192,14 +241,22 @@
         self, strategy: st.SearchStrategy, container: str, context: HookContext
     ) -> st.SearchStrategy:
         for hook in self.get_all_by_name(f"before_generate_{container}"):
+            if _should_skip_hook(hook, context):
+                continue
             strategy = hook(context, strategy)
         for hook in self.get_all_by_name(f"filter_{container}"):
+            if _should_skip_hook(hook, context):
+                continue
             hook = partial(hook, context)
             strategy = strategy.filter(hook)
         for hook in self.get_all_by_name(f"map_{container}"):
+            if _should_skip_hook(hook, context):
+                continue
             hook = partial(hook, context)
             strategy = strategy.map(hook)
         for hook in self.get_all_by_name(f"flatmap_{container}"):
+            if _should_skip_hook(hook, context):
+                continue
             hook = partial(hook, context)
             strategy = strategy.flatmap(hook)
         return strategy
@@ -207,6 +264,8 @@
     def dispatch(self, name: str, context: HookContext, *args: Any, **kwargs: Any) -> None:
         """Run all hooks for the given name."""
         for hook in self.get_all_by_name(name):
+            if _should_skip_hook(hook, context):
+                continue
             hook(context, *args, **kwargs)

     def unregister(self, hook: Callable) -> None:
@@ -226,6 +285,11 @@
         self._hooks = defaultdict(list)


+def _should_skip_hook(hook: Callable, context: HookContext) -> bool:
+    filter_set = getattr(hook, "filter_set", None)
+    return filter_set is not None and not filter_set.match(context)
+
+
 def apply_to_all_dispatchers(
     operation: APIOperation,
     context: HookContext,
@@ -248,6 +312,15 @@ def should_skip_operation(dispatcher: HookDispatcher, context: HookContext) -> b
     return False


+def validate_filterable_hook(hook: str | Callable) -> None:
+    if callable(hook):
+        name = hook.__name__
+    else:
+        name = hook
+    if name in ("before_process_path", "before_load_schema", "after_load_schema", "after_init_cli_run_handlers"):
+        raise ValueError(f"Filters are not applicable to this hook: `{name}`")
+
+
 all_scopes = HookDispatcher.register_spec(list(HookScope))

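
`to_filterable_hook` attaches the same `apply_to` / `skip_for` chain that auth providers already use, and `_should_skip_hook` consults the stored `filter_set` at dispatch time; `validate_filterable_hook` rejects filters on hooks that do not run in the context of a single operation. A rough usage sketch, assuming the global `schemathesis.hook` entry point picks up the new filterable register and that `method` is an accepted filter argument (both are assumptions based on this diff and Schemathesis' general filter API):

import schemathesis


@schemathesis.hook.apply_to(method="GET")
def map_query(context, query):
    # With the filter above, this hook only runs for GET operations
    if query is not None:
        query["debug"] = "true"  # hypothetical extra query parameter
    return query

`skip_for` works the same way for exclusion, e.g. skipping a hook for a specific path.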
schemathesis/internal/transformation.py CHANGED
@@ -2,6 +2,8 @@ from __future__ import annotations

 from typing import Any

+from ..constants import FALSE_VALUES, TRUE_VALUES
+

 def merge_recursively(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
     """Merge two dictionaries recursively."""
@@ -14,3 +16,11 @@ def merge_recursively(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
         else:
             a[key] = b[key]
     return a
+
+
+def convert_boolean_string(value: str) -> str | bool:
+    if value.lower() in TRUE_VALUES:
+        return True
+    if value.lower() in FALSE_VALUES:
+        return False
+    return value
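
`convert_boolean_string` now lives here, and the CLI callback above simply delegates to it. Its behaviour for reference, assuming "true" and "false" are among the recognized TRUE_VALUES/FALSE_VALUES from schemathesis.constants, as the boolean CLI flags that use this callback suggest:

from schemathesis.internal.transformation import convert_boolean_string

assert convert_boolean_string("true") is True
assert convert_boolean_string("FALSE") is False
assert convert_boolean_string("maybe") == "maybe"  # unrecognized strings pass through unchanged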
schemathesis/models.py CHANGED
@@ -24,7 +24,7 @@ from typing import (
     TypeVar,
     cast,
 )
-from urllib.parse import quote, unquote, urljoin, urlparse, urlsplit, urlunsplit
+from urllib.parse import quote, unquote, urljoin, urlsplit, urlunsplit

 from . import serializers
 from ._dependency_versions import IS_WERKZEUG_ABOVE_3
@@ -78,9 +78,15 @@ class CaseSource:
     case: Case
     response: GenericResponse
     elapsed: float
+    overrides_all_parameters: bool

     def partial_deepcopy(self) -> CaseSource:
-        return self.__class__(case=self.case.partial_deepcopy(), response=self.response, elapsed=self.elapsed)
+        return self.__class__(
+            case=self.case.partial_deepcopy(),
+            response=self.response,
+            elapsed=self.elapsed,
+            overrides_all_parameters=self.overrides_all_parameters,
+        )


 def cant_serialize(media_type: str) -> NoReturn:  # type: ignore
@@ -198,8 +204,10 @@ class Case:
     def app(self) -> Any:
         return self.operation.app

-    def set_source(self, response: GenericResponse, case: Case, elapsed: float) -> None:
-        self.source = CaseSource(case=case, response=response, elapsed=elapsed)
+    def set_source(self, response: GenericResponse, case: Case, elapsed: float, overrides_all_parameters: bool) -> None:
+        self.source = CaseSource(
+            case=case, response=response, elapsed=elapsed, overrides_all_parameters=overrides_all_parameters
+        )

     @property
     def formatted_path(self) -> str:
@@ -511,28 +519,6 @@ class Case:
     )


-def _merge_dict_to(data: dict[str, Any], data_key: str, new: dict[str, Any]) -> None:
-    original = data[data_key] or {}
-    for key, value in new.items():
-        original[key] = value
-    data[data_key] = original
-
-
-def validate_vanilla_requests_kwargs(data: dict[str, Any]) -> None:
-    """Check arguments for `requests.Session.request`.
-
-    Some arguments can be valid for cases like ASGI integration, but at the same time they won't work for the regular
-    `requests` calls. In such cases we need to avoid an obscure error message, that comes from `requests`.
-    """
-    url = data["url"]
-    if not urlparse(url).netloc:
-        raise RuntimeError(
-            "The URL should be absolute, so Schemathesis knows where to send the data. \n"
-            f"If you use the ASGI integration, please supply your test client "
-            f"as the `session` argument to `call`.\nURL: {url}"
-        )
-
-
 @contextmanager
 def cookie_handler(client: werkzeug.Client, cookies: Cookies | None) -> Generator[None, None, None]:
     """Set cookies required for a call."""
schemathesis/runner/events.py CHANGED
@@ -325,6 +325,7 @@ class AfterStatefulExecution(ExecutionEvent):

     status: Status
     result: SerializedTestResult
+    elapsed_time: float
     data_generation_method: list[DataGenerationMethod]
     thread_id: int = field(default_factory=threading.get_ident)

schemathesis/runner/impl/core.py CHANGED
@@ -299,12 +299,19 @@ class BaseRunner:
        def on_step_finished(event: stateful_events.StepFinished) -> None:
            return None

+        test_start_time: float | None = None
+        test_elapsed_time: float | None = None
+
        for stateful_event in runner.execute():
            if isinstance(stateful_event, stateful_events.SuiteFinished):
                if stateful_event.failures and status != Status.error:
                    status = Status.failure
                for failure in stateful_event.failures:
                    result.checks.append(failure)
+            elif isinstance(stateful_event, stateful_events.RunStarted):
+                test_start_time = stateful_event.timestamp
+            elif isinstance(stateful_event, stateful_events.RunFinished):
+                test_elapsed_time = stateful_event.timestamp - cast(float, test_start_time)
            elif isinstance(stateful_event, stateful_events.StepFinished):
                on_step_finished(stateful_event)
            elif isinstance(stateful_event, stateful_events.Errored):
@@ -315,6 +322,7 @@
        yield events.AfterStatefulExecution(
            status=status,
            result=SerializedTestResult.from_test_result(result),
+            elapsed_time=cast(float, test_elapsed_time),
            data_generation_method=self.schema.data_generation_methods,
        )

@@ -605,8 +613,12 @@ def run_test(
        status = Status.error
        try:
            operation.schema.validate()
+            msg = "Unexpected error during testing of this API operation"
+            exc_msg = str(exc)
+            if exc_msg:
+                msg += f": {exc_msg}"
            try:
-                raise InternalError(f"Unexpected error during testing of this API operation: {exc}") from exc
+                raise InternalError(msg) from exc
            except InternalError as exc:
                error = exc
        except ValidationError as exc:
schemathesis/sanitization.py CHANGED
@@ -238,6 +238,7 @@ def sanitize_serialized_check(check: SerializedCheck, *, config: Config | None =


 def sanitize_serialized_case(case: SerializedCase, *, config: Config | None = None) -> None:
+    case.url = sanitize_url(case.url, config=config)
     for value in (case.path_parameters, case.headers, case.cookies, case.query, case.extra_headers):
         if value is not None:
             sanitize_value(value, config=config)