schemathesis 3.26.1__py3-none-any.whl → 3.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -21,6 +21,8 @@ from jsonschema.exceptions import SchemaError as JsonSchemaError
21
21
  from jsonschema.exceptions import ValidationError
22
22
  from requests.auth import HTTPDigestAuth, _basic_auth_str
23
23
 
24
+ from schemathesis.transports import RequestsTransport
25
+
24
26
  from ... import failures, hooks
25
27
  from ..._compat import MultipleFailures
26
28
  from ..._hypothesis import (
@@ -849,7 +851,7 @@ def _network_test(
849
851
  response = case.call(**kwargs)
850
852
  except CheckFailed as exc:
851
853
  check_name = "request_timeout"
852
- requests_kwargs = case.as_requests_kwargs(base_url=case.get_full_base_url(), headers=headers)
854
+ requests_kwargs = RequestsTransport().serialize_case(case, base_url=case.get_full_base_url(), headers=headers)
853
855
  request = requests.Request(**requests_kwargs).prepare()
854
856
  elapsed = cast(float, timeout) # It is defined and not empty, since the exception happened
855
857
  check_result = result.add_failure(
@@ -939,14 +941,14 @@ def _wsgi_test(
939
941
  feedback: Feedback,
940
942
  max_response_time: int | None,
941
943
  ) -> WSGIResponse:
944
+ from ...transports.responses import WSGIResponse
945
+
942
946
  with catching_logs(LogCaptureHandler(), level=logging.DEBUG) as recorded:
943
- start = time.monotonic()
944
947
  hook_context = HookContext(operation=case.operation)
945
- kwargs = {"headers": headers}
948
+ kwargs: dict[str, Any] = {"headers": headers}
946
949
  hooks.dispatch("process_call_kwargs", hook_context, case, kwargs)
947
- response = case.call_wsgi(**kwargs)
948
- elapsed = time.monotonic() - start
949
- context = TargetContext(case=case, response=response, response_time=elapsed)
950
+ response = cast(WSGIResponse, case.call(**kwargs))
951
+ context = TargetContext(case=case, response=response, response_time=response.elapsed.total_seconds())
950
952
  run_targets(targets, context)
951
953
  result.logs.extend(recorded.records)
952
954
  status = Status.success
@@ -967,7 +969,7 @@ def _wsgi_test(
967
969
  finally:
968
970
  feedback.add_test_case(case, response)
969
971
  if store_interactions:
970
- result.store_wsgi_response(case, response, headers, elapsed, status, check_results)
972
+ result.store_wsgi_response(case, response, headers, response.elapsed.total_seconds(), status, check_results)
971
973
  return response
972
974
 
973
975
 
@@ -1037,7 +1039,7 @@ def _asgi_test(
1037
1039
  hook_context = HookContext(operation=case.operation)
1038
1040
  kwargs: dict[str, Any] = {"headers": headers}
1039
1041
  hooks.dispatch("process_call_kwargs", hook_context, case, kwargs)
1040
- response = case.call_asgi(**kwargs)
1042
+ response = case.call(**kwargs)
1041
1043
  context = TargetContext(case=case, response=response, response_time=response.elapsed.total_seconds())
1042
1044
  run_targets(targets, context)
1043
1045
  status = Status.success
@@ -4,28 +4,30 @@ They all consist of primitive types and don't have references to schemas, app, e
4
4
  """
5
5
 
6
6
  from __future__ import annotations
7
+
7
8
  import logging
8
9
  import re
10
+ import textwrap
9
11
  from dataclasses import dataclass, field
10
- from typing import Any, TYPE_CHECKING, cast
12
+ from typing import TYPE_CHECKING, Any, cast
11
13
 
12
- from ..transports import serialize_payload
13
14
  from ..code_samples import get_excluded_headers
14
15
  from ..exceptions import (
16
+ BodyInGetRequestError,
17
+ DeadlineExceeded,
15
18
  FailureContext,
16
19
  InternalError,
17
- make_unique_by_key,
18
- format_exception,
19
- extract_requests_exception_details,
20
- RuntimeErrorType,
21
- DeadlineExceeded,
22
- OperationSchemaError,
23
- BodyInGetRequestError,
24
20
  InvalidRegularExpression,
21
+ OperationSchemaError,
22
+ RuntimeErrorType,
25
23
  SerializationError,
26
24
  UnboundPrefixError,
25
+ extract_requests_exception_details,
26
+ format_exception,
27
+ make_unique_by_key,
27
28
  )
28
29
  from ..models import Case, Check, Interaction, Request, Response, Status, TestResult
30
+ from ..transports import serialize_payload
29
31
 
30
32
  if TYPE_CHECKING:
31
33
  import hypothesis.errors
@@ -108,6 +110,7 @@ class SerializedCheck:
108
110
  @classmethod
109
111
  def from_check(cls, check: Check) -> SerializedCheck:
110
112
  import requests
113
+
111
114
  from ..transports.responses import WSGIResponse
112
115
 
113
116
  if check.response is not None:
@@ -140,6 +143,25 @@ class SerializedCheck:
140
143
  history=history,
141
144
  )
142
145
 
146
+ @property
147
+ def title(self) -> str:
148
+ if self.context is not None:
149
+ return self.context.title
150
+ return f"Custom check failed: `{self.name}`"
151
+
152
+ @property
153
+ def formatted_message(self) -> str | None:
154
+ if self.context is not None:
155
+ if self.context.message:
156
+ message = self.context.message
157
+ else:
158
+ message = None
159
+ else:
160
+ message = self.message
161
+ if message is not None:
162
+ message = textwrap.indent(message, prefix=" ")
163
+ return message
164
+
143
165
 
144
166
  def _get_headers(headers: dict[str, Any] | CaseInsensitiveDict) -> dict[str, str]:
145
167
  return {key: value[0] for key, value in headers.items() if key not in get_excluded_headers()}
@@ -203,8 +225,8 @@ class SerializedError:
203
225
 
204
226
  @classmethod
205
227
  def from_exception(cls, exception: Exception) -> SerializedError:
206
- import requests
207
228
  import hypothesis.errors
229
+ import requests
208
230
  from hypothesis import HealthCheck
209
231
 
210
232
  title = "Runtime Error"
schemathesis/schemas.py CHANGED
@@ -8,11 +8,13 @@ They give only static definitions of paths.
8
8
  """
9
9
 
10
10
  from __future__ import annotations
11
+
11
12
  from collections.abc import Mapping, MutableMapping
12
13
  from contextlib import nullcontext
13
14
  from dataclasses import dataclass, field
14
15
  from functools import lru_cache
15
16
  from typing import (
17
+ TYPE_CHECKING,
16
18
  Any,
17
19
  Callable,
18
20
  ContextManager,
@@ -22,7 +24,6 @@ from typing import (
22
24
  NoReturn,
23
25
  Sequence,
24
26
  TypeVar,
25
- TYPE_CHECKING,
26
27
  )
27
28
  from urllib.parse import quote, unquote, urljoin, urlparse, urlsplit, urlunsplit
28
29
 
@@ -30,22 +31,23 @@ import hypothesis
30
31
  from hypothesis.strategies import SearchStrategy
31
32
  from pyrate_limiter import Limiter
32
33
 
33
- from .constants import NOT_SET
34
+ from ._dependency_versions import IS_PYRATE_LIMITER_ABOVE_3
34
35
  from ._hypothesis import create_test
35
36
  from .auths import AuthStorage
36
37
  from .code_samples import CodeSampleStyle
38
+ from .constants import NOT_SET
39
+ from .exceptions import OperationSchemaError, UsageError
37
40
  from .generation import (
38
41
  DEFAULT_DATA_GENERATION_METHODS,
39
42
  DataGenerationMethod,
40
43
  DataGenerationMethodInput,
41
44
  GenerationConfig,
42
45
  )
43
- from .exceptions import OperationSchemaError, UsageError
44
46
  from .hooks import HookContext, HookDispatcher, HookScope, dispatch
45
- from .internal.result import Result, Ok
47
+ from .internal.result import Ok, Result
46
48
  from .models import APIOperation, Case
47
- from .stateful.state_machine import APIStateMachine
48
49
  from .stateful import Stateful, StatefulTest
50
+ from .stateful.state_machine import APIStateMachine
49
51
  from .types import (
50
52
  Body,
51
53
  Cookies,
@@ -57,9 +59,10 @@ from .types import (
57
59
  PathParameters,
58
60
  Query,
59
61
  )
60
- from .utils import PARAMETRIZE_MARKER, GivenInput, given_proxy, combine_strategies
62
+ from .utils import PARAMETRIZE_MARKER, GivenInput, combine_strategies, given_proxy
61
63
 
62
64
  if TYPE_CHECKING:
65
+ from .transports import Transport
63
66
  from .transports.responses import GenericResponse
64
67
 
65
68
 
@@ -74,6 +77,7 @@ def get_full_path(base_path: str, path: str) -> str:
74
77
  @dataclass(eq=False)
75
78
  class BaseSchema(Mapping):
76
79
  raw_schema: dict[str, Any]
80
+ transport: Transport
77
81
  location: str | None = None
78
82
  base_url: str | None = None
79
83
  method: Filter | None = None
@@ -345,6 +349,7 @@ class BaseSchema(Mapping):
345
349
  code_sample_style=code_sample_style, # type: ignore
346
350
  rate_limiter=rate_limiter, # type: ignore
347
351
  sanitize_output=sanitize_output, # type: ignore
352
+ transport=self.transport,
348
353
  )
349
354
 
350
355
  def get_local_hook_dispatcher(self) -> HookDispatcher | None:
@@ -423,7 +428,10 @@ class BaseSchema(Mapping):
423
428
  """Limit the rate of sending generated requests."""
424
429
  label = urlparse(self.base_url).netloc
425
430
  if self.rate_limiter is not None:
426
- return self.rate_limiter.ratelimit(label, delay=True, max_delay=0)
431
+ if IS_PYRATE_LIMITER_ABOVE_3:
432
+ self.rate_limiter.try_acquire(label)
433
+ else:
434
+ return self.rate_limiter.ratelimit(label, delay=True, max_delay=0)
427
435
  return nullcontext()
428
436
 
429
437
  def _get_payload_schema(self, definition: dict[str, Any], media_type: str) -> dict[str, Any] | None:
@@ -234,6 +234,7 @@ def from_dict(
234
234
  :return: GraphQLSchema
235
235
  """
236
236
  from .schemas import GraphQLSchema
237
+ from ... import transports
237
238
 
238
239
  _code_sample_style = CodeSampleStyle.from_str(code_sample_style)
239
240
  hook_context = HookContext()
@@ -252,6 +253,7 @@ def from_dict(
252
253
  code_sample_style=_code_sample_style,
253
254
  rate_limiter=rate_limiter,
254
255
  sanitize_output=sanitize_output,
256
+ transport=transports.get(app),
255
257
  ) # type: ignore
256
258
  dispatch("after_load_schema", hook_context, instance)
257
259
  return instance
@@ -20,7 +20,6 @@ from typing import (
20
20
  from urllib.parse import urlsplit, urlunsplit
21
21
 
22
22
  import graphql
23
- import requests
24
23
  from graphql import GraphQLNamedType
25
24
  from hypothesis import strategies as st
26
25
  from hypothesis.strategies import SearchStrategy
@@ -60,39 +59,15 @@ class RootType(enum.Enum):
60
59
 
61
60
  @dataclass(repr=False)
62
61
  class GraphQLCase(Case):
63
- def as_requests_kwargs(self, base_url: str | None = None, headers: dict[str, str] | None = None) -> dict[str, Any]:
64
- final_headers = self._get_headers(headers)
62
+ def _get_url(self, base_url: str | None) -> str:
65
63
  base_url = self._get_base_url(base_url)
66
64
  # Replace the path, in case if the user provided any path parameters via hooks
67
65
  parts = list(urlsplit(base_url))
68
66
  parts[2] = self.formatted_path
69
- kwargs: dict[str, Any] = {
70
- "method": self.method,
71
- "url": urlunsplit(parts),
72
- "headers": final_headers,
73
- "cookies": self.cookies,
74
- "params": self.query,
75
- }
76
- # There is no direct way to have bytes here, but it is a useful pattern to support.
77
- # It also unifies GraphQLCase with its Open API counterpart where bytes may come from external examples
78
- if isinstance(self.body, bytes):
79
- kwargs["data"] = self.body
80
- # Assume that the payload is JSON, not raw GraphQL queries
81
- kwargs["headers"].setdefault("Content-Type", "application/json")
82
- else:
83
- kwargs["json"] = {"query": self.body}
84
- return kwargs
85
-
86
- def as_werkzeug_kwargs(self, headers: dict[str, str] | None = None) -> dict[str, Any]:
87
- final_headers = self._get_headers(headers)
88
- return {
89
- "method": self.method,
90
- "path": self.operation.schema.get_full_path(self.formatted_path),
91
- # Convert to a regular dictionary, as we use `CaseInsensitiveDict` which is not supported by Werkzeug
92
- "headers": dict(final_headers),
93
- "query_string": self.query,
94
- "json": {"query": self.body},
95
- }
67
+ return urlunsplit(parts)
68
+
69
+ def _get_body(self) -> Body | NotSet:
70
+ return self.body if isinstance(self.body, (NotSet, bytes)) else {"query": self.body}
96
71
 
97
72
  def validate_response(
98
73
  self,
@@ -107,15 +82,6 @@ class GraphQLCase(Case):
107
82
  checks = tuple(check for check in checks if check not in excluded_checks)
108
83
  return super().validate_response(response, checks, code_sample_style=code_sample_style)
109
84
 
110
- def call_asgi(
111
- self,
112
- app: Any = None,
113
- base_url: str | None = None,
114
- headers: dict[str, str] | None = None,
115
- **kwargs: Any,
116
- ) -> requests.Response:
117
- return super().call_asgi(app=app, base_url=base_url, headers=headers, **kwargs)
118
-
119
85
 
120
86
  C = TypeVar("C", bound=Case)
121
87
 
@@ -287,7 +253,7 @@ class GraphQLSchema(BaseSchema):
287
253
  cookies=cookies,
288
254
  query=query,
289
255
  body=body,
290
- media_type=media_type,
256
+ media_type=media_type or "application/json",
291
257
  generation_time=0.0,
292
258
  )
293
259
 
@@ -373,6 +339,7 @@ def get_case_strategy(
373
339
  operation=operation,
374
340
  data_generation_method=data_generation_method,
375
341
  generation_time=time.monotonic() - start,
342
+ media_type="application/json",
376
343
  ) # type: ignore
377
344
  context = auths.AuthContext(
378
345
  operation=operation,
@@ -43,7 +43,9 @@ StrategyFactory = Callable[[Dict[str, Any], str, str, Optional[str], GenerationC
43
43
 
44
44
 
45
45
  def header_values(blacklist_characters: str = "\n\r") -> st.SearchStrategy[str]:
46
- return st.text(alphabet=st.characters(min_codepoint=0, max_codepoint=255, blacklist_characters="\n\r"))
46
+ return st.text(
47
+ alphabet=st.characters(min_codepoint=0, max_codepoint=255, blacklist_characters=blacklist_characters)
48
+ )
47
49
 
48
50
 
49
51
  @lru_cache
@@ -304,6 +304,7 @@ def from_dict(
304
304
  :param dict raw_schema: A schema to load.
305
305
  """
306
306
  from .schemas import OpenApi30, SwaggerV20
307
+ from ... import transports
307
308
 
308
309
  if not isinstance(raw_schema, dict):
309
310
  raise SchemaError(SchemaErrorType.OPEN_API_INVALID_SCHEMA, SCHEMA_INVALID_ERROR)
@@ -338,6 +339,7 @@ def from_dict(
338
339
  location=location,
339
340
  rate_limiter=rate_limiter,
340
341
  sanitize_output=sanitize_output,
342
+ transport=transports.get(app),
341
343
  )
342
344
  dispatch("after_load_schema", hook_context, instance)
343
345
  return instance
@@ -379,6 +381,7 @@ def from_dict(
379
381
  location=location,
380
382
  rate_limiter=rate_limiter,
381
383
  sanitize_output=sanitize_output,
384
+ transport=transports.get(app),
382
385
  )
383
386
  dispatch("after_load_schema", hook_context, instance)
384
387
  return instance
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import time
4
4
  import re
5
5
  from dataclasses import dataclass
6
- from typing import TYPE_CHECKING, Any, Callable, ClassVar
6
+ from typing import TYPE_CHECKING, Any, ClassVar
7
7
 
8
8
  from hypothesis.errors import InvalidDefinition
9
9
  from hypothesis.stateful import RuleBasedStateMachine
@@ -189,13 +189,12 @@ class APIStateMachine(RuleBasedStateMachine):
189
189
  :return: Response from the application under test.
190
190
 
191
191
  Note that WSGI/ASGI applications are detected automatically in this method. Depending on the result of this
192
- detection the state machine will call ``call``, ``call_wsgi`` or ``call_asgi`` methods.
192
+ detection the state machine will call the ``call`` method.
193
193
 
194
194
  Usually, you don't need to override this method unless you are building a different state machine on top of this
195
195
  one and want to customize the transport layer itself.
196
196
  """
197
- method = self._get_call_method(case)
198
- return method(**kwargs)
197
+ return case.call(**kwargs)
199
198
 
200
199
  def get_call_kwargs(self, case: Case) -> dict[str, Any]:
201
200
  """Create custom keyword arguments that will be passed to the :meth:`Case.call` method.
@@ -214,15 +213,6 @@ class APIStateMachine(RuleBasedStateMachine):
214
213
  """
215
214
  return {}
216
215
 
217
- def _get_call_method(self, case: Case) -> Callable:
218
- if case.app is not None:
219
- from starlette.applications import Starlette
220
-
221
- if isinstance(case.app, Starlette):
222
- return case.call_asgi
223
- return case.call_wsgi
224
- return case.call
225
-
226
216
  def validate_response(
227
217
  self, response: GenericResponse, case: Case, additional_checks: tuple[CheckFunction, ...] = ()
228
218
  ) -> None:
@@ -34,8 +34,8 @@ def invalid_rate(value: str) -> UsageError:
34
34
 
35
35
 
36
36
  def build_limiter(rate: str) -> Limiter:
37
- from pyrate_limiter import Limiter, RequestRate
37
+ from ._rate_limiter import Limiter, Rate
38
38
 
39
39
  limit, interval = parse_units(rate)
40
- rate = RequestRate(limit, interval)
40
+ rate = Rate(limit, interval)
41
41
  return Limiter(rate)