schemathesis 3.18.5__py3-none-any.whl → 3.19.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. schemathesis/__init__.py +1 -3
  2. schemathesis/auths.py +218 -43
  3. schemathesis/cli/__init__.py +37 -20
  4. schemathesis/cli/callbacks.py +13 -1
  5. schemathesis/cli/cassettes.py +18 -18
  6. schemathesis/cli/context.py +25 -24
  7. schemathesis/cli/debug.py +3 -3
  8. schemathesis/cli/junitxml.py +4 -4
  9. schemathesis/cli/options.py +1 -1
  10. schemathesis/cli/output/default.py +2 -0
  11. schemathesis/constants.py +3 -3
  12. schemathesis/exceptions.py +9 -9
  13. schemathesis/extra/pytest_plugin.py +1 -1
  14. schemathesis/failures.py +65 -66
  15. schemathesis/filters.py +269 -0
  16. schemathesis/hooks.py +11 -11
  17. schemathesis/lazy.py +21 -16
  18. schemathesis/models.py +149 -107
  19. schemathesis/parameters.py +12 -7
  20. schemathesis/runner/events.py +55 -55
  21. schemathesis/runner/impl/core.py +26 -26
  22. schemathesis/runner/impl/solo.py +6 -7
  23. schemathesis/runner/impl/threadpool.py +5 -5
  24. schemathesis/runner/serialization.py +50 -50
  25. schemathesis/schemas.py +38 -23
  26. schemathesis/serializers.py +3 -3
  27. schemathesis/service/ci.py +25 -25
  28. schemathesis/service/client.py +2 -2
  29. schemathesis/service/events.py +12 -13
  30. schemathesis/service/hosts.py +4 -4
  31. schemathesis/service/metadata.py +14 -15
  32. schemathesis/service/models.py +12 -13
  33. schemathesis/service/report.py +30 -31
  34. schemathesis/service/serialization.py +2 -4
  35. schemathesis/specs/graphql/loaders.py +21 -2
  36. schemathesis/specs/graphql/schemas.py +8 -8
  37. schemathesis/specs/openapi/expressions/context.py +4 -4
  38. schemathesis/specs/openapi/expressions/lexer.py +11 -12
  39. schemathesis/specs/openapi/expressions/nodes.py +16 -16
  40. schemathesis/specs/openapi/expressions/parser.py +1 -1
  41. schemathesis/specs/openapi/links.py +15 -17
  42. schemathesis/specs/openapi/loaders.py +29 -2
  43. schemathesis/specs/openapi/negative/__init__.py +5 -5
  44. schemathesis/specs/openapi/negative/mutations.py +6 -6
  45. schemathesis/specs/openapi/parameters.py +12 -13
  46. schemathesis/specs/openapi/references.py +2 -2
  47. schemathesis/specs/openapi/schemas.py +11 -15
  48. schemathesis/specs/openapi/security.py +12 -7
  49. schemathesis/specs/openapi/stateful/links.py +4 -4
  50. schemathesis/stateful.py +19 -19
  51. schemathesis/targets.py +5 -6
  52. schemathesis/throttling.py +34 -0
  53. schemathesis/types.py +11 -13
  54. schemathesis/utils.py +2 -2
  55. {schemathesis-3.18.5.dist-info → schemathesis-3.19.1.dist-info}/METADATA +4 -3
  56. schemathesis-3.19.1.dist-info/RECORD +107 -0
  57. schemathesis-3.18.5.dist-info/RECORD +0 -105
  58. {schemathesis-3.18.5.dist-info → schemathesis-3.19.1.dist-info}/WHEEL +0 -0
  59. {schemathesis-3.18.5.dist-info → schemathesis-3.19.1.dist-info}/entry_points.txt +0 -0
  60. {schemathesis-3.18.5.dist-info → schemathesis-3.19.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,45 +1,46 @@
1
1
  import os
2
2
  import shutil
3
+ from dataclasses import dataclass, field
3
4
  from queue import Queue
4
5
  from typing import List, Optional, Union
5
6
 
6
- import attr
7
7
  import hypothesis
8
8
 
9
9
  from ..constants import CodeSampleStyle
10
10
  from ..runner.serialization import SerializedTestResult
11
11
 
12
12
 
13
- @attr.s(slots=True) # pragma: no mutate
13
+ @dataclass
14
14
  class ServiceReportContext:
15
- queue: Queue = attr.ib() # pragma: no mutate
16
- service_base_url: str = attr.ib() # pragma: no mutate
15
+ queue: Queue
16
+ service_base_url: str
17
17
 
18
18
 
19
- @attr.s(slots=True) # pragma: no mutate
19
+ @dataclass
20
20
  class FileReportContext:
21
- queue: Queue = attr.ib() # pragma: no mutate
22
- filename: str = attr.ib(default=None) # pragma: no mutate
21
+ queue: Queue
22
+ filename: Optional[str] = None
23
23
 
24
24
 
25
- @attr.s(slots=True) # pragma: no mutate
25
+ @dataclass
26
26
  class ExecutionContext:
27
27
  """Storage for the current context of the execution."""
28
28
 
29
- hypothesis_settings: hypothesis.settings = attr.ib() # pragma: no mutate
30
- hypothesis_output: List[str] = attr.ib(factory=list) # pragma: no mutate
31
- workers_num: int = attr.ib(default=1) # pragma: no mutate
32
- show_errors_tracebacks: bool = attr.ib(default=False) # pragma: no mutate
33
- validate_schema: bool = attr.ib(default=True) # pragma: no mutate
34
- operations_processed: int = attr.ib(default=0) # pragma: no mutate
29
+ hypothesis_settings: hypothesis.settings
30
+ hypothesis_output: List[str] = field(default_factory=list)
31
+ workers_num: int = 1
32
+ rate_limit: Optional[str] = None
33
+ show_errors_tracebacks: bool = False
34
+ validate_schema: bool = True
35
+ operations_processed: int = 0
35
36
  # It is set in runtime, from a `Initialized` event
36
- operations_count: Optional[int] = attr.ib(default=None) # pragma: no mutate
37
- current_line_length: int = attr.ib(default=0) # pragma: no mutate
38
- terminal_size: os.terminal_size = attr.ib(factory=shutil.get_terminal_size) # pragma: no mutate
39
- results: List[SerializedTestResult] = attr.ib(factory=list) # pragma: no mutate
40
- cassette_path: Optional[str] = attr.ib(default=None) # pragma: no mutate
41
- junit_xml_file: Optional[str] = attr.ib(default=None) # pragma: no mutate
42
- is_interrupted: bool = attr.ib(default=False) # pragma: no mutate
43
- verbosity: int = attr.ib(default=0) # pragma: no mutate
44
- code_sample_style: CodeSampleStyle = attr.ib(default=CodeSampleStyle.default()) # pragma: no mutate
45
- report: Optional[Union[ServiceReportContext, FileReportContext]] = attr.ib(default=None) # pragma: no mutate
37
+ operations_count: Optional[int] = None
38
+ current_line_length: int = 0
39
+ terminal_size: os.terminal_size = field(default_factory=shutil.get_terminal_size)
40
+ results: List[SerializedTestResult] = field(default_factory=list)
41
+ cassette_path: Optional[str] = None
42
+ junit_xml_file: Optional[str] = None
43
+ is_interrupted: bool = False
44
+ verbosity: int = 0
45
+ code_sample_style: CodeSampleStyle = CodeSampleStyle.default()
46
+ report: Optional[Union[ServiceReportContext, FileReportContext]] = None
schemathesis/cli/debug.py CHANGED
@@ -1,15 +1,15 @@
1
1
  import json
2
+ from dataclasses import dataclass
2
3
 
3
- import attr
4
4
  from click.utils import LazyFile
5
5
 
6
6
  from ..runner import events
7
7
  from .handlers import EventHandler, ExecutionContext
8
8
 
9
9
 
10
- @attr.s(slots=True) # pragma: no mutate
10
+ @dataclass
11
11
  class DebugOutputHandler(EventHandler):
12
- file_handle: LazyFile = attr.ib() # pragma: no mutate
12
+ file_handle: LazyFile
13
13
 
14
14
  def handle_event(self, context: ExecutionContext, event: events.ExecutionEvent) -> None:
15
15
  stream = self.file_handle.open()
@@ -1,7 +1,7 @@
1
1
  import platform
2
+ from dataclasses import dataclass, field
2
3
  from typing import List
3
4
 
4
- import attr
5
5
  from click.utils import LazyFile
6
6
  from junit_xml import TestCase, TestSuite, to_xml_report_file
7
7
 
@@ -11,10 +11,10 @@ from ..runner.serialization import deduplicate_failures
11
11
  from .handlers import EventHandler, ExecutionContext
12
12
 
13
13
 
14
- @attr.s(slots=True) # pragma: no mutate
14
+ @dataclass
15
15
  class JunitXMLHandler(EventHandler):
16
- file_handle: LazyFile = attr.ib() # pragma: no mutate
17
- test_cases: List = attr.ib(factory=list) # pragma: no mutate
16
+ file_handle: LazyFile
17
+ test_cases: List = field(default_factory=list)
18
18
 
19
19
  def handle_event(self, context: ExecutionContext, event: events.ExecutionEvent) -> None:
20
20
  if isinstance(event, events.AfterExecution):
@@ -33,7 +33,7 @@ class BaseCsvChoice(click.Choice):
33
33
  class CsvEnumChoice(BaseCsvChoice):
34
34
  def __init__(self, choices: Type[Enum]):
35
35
  self.enum = choices
36
- super().__init__(tuple(choices.__members__))
36
+ super().__init__(tuple(el.name for el in choices))
37
37
 
38
38
  def convert( # type: ignore[return]
39
39
  self, value: str, param: Optional[click.core.Parameter], ctx: Optional[click.core.Context]
@@ -487,6 +487,8 @@ def handle_initialized(context: ExecutionContext, event: events.Initialized) ->
487
487
  click.secho(f"Base URL: {event.base_url}", bold=True)
488
488
  click.secho(f"Specification version: {event.specification_name}", bold=True)
489
489
  click.secho(f"Workers: {context.workers_num}", bold=True)
490
+ if context.rate_limit is not None:
491
+ click.secho(f"Rate limit: {context.rate_limit}", bold=True)
490
492
  click.secho(f"Collected API operations: {context.operations_count}", bold=True)
491
493
  if isinstance(context.report, ServiceReportContext):
492
494
  click.secho("Report to Schemathesis.io: ENABLED", bold=True)
schemathesis/constants.py CHANGED
@@ -20,9 +20,9 @@ SCHEMATHESIS_TEST_CASE_HEADER = "X-Schemathesis-TestCaseId"
20
20
  HYPOTHESIS_IN_MEMORY_DATABASE_IDENTIFIER = ":memory:"
21
21
  DISCORD_LINK = "https://discord.gg/R9ASRAmHnA"
22
22
  # Maximum test running time
23
- DEFAULT_DEADLINE = 15000 # pragma: no mutate
24
- DEFAULT_RESPONSE_TIMEOUT = 10000 # pragma: no mutate
25
- DEFAULT_STATEFUL_RECURSION_LIMIT = 5 # pragma: no mutate
23
+ DEFAULT_DEADLINE = 15000
24
+ DEFAULT_RESPONSE_TIMEOUT = 10000
25
+ DEFAULT_STATEFUL_RECURSION_LIMIT = 5
26
26
  HTTP_METHODS = frozenset({"get", "put", "post", "delete", "options", "head", "patch", "trace"})
27
27
  RECURSIVE_REFERENCE_ERROR_MESSAGE = (
28
28
  "Currently, Schemathesis can't generate data for this operation due to "
@@ -1,8 +1,8 @@
1
+ from dataclasses import dataclass
1
2
  from hashlib import sha1
2
3
  from json import JSONDecodeError
3
4
  from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NoReturn, Optional, Tuple, Type, Union
4
5
 
5
- import attr
6
6
  import hypothesis.errors
7
7
  import requests
8
8
  from jsonschema import ValidationError
@@ -139,15 +139,15 @@ def get_timeout_error(deadline: Union[float, int]) -> Type[CheckFailed]:
139
139
  return _get_hashed_exception("TimeoutError", str(deadline))
140
140
 
141
141
 
142
- @attr.s(slots=True)
142
+ @dataclass
143
143
  class InvalidSchema(Exception):
144
144
  """Schema associated with an API operation contains an error."""
145
145
 
146
146
  __module__ = "builtins"
147
- message: Optional[str] = attr.ib(default=None)
148
- path: Optional[str] = attr.ib(default=None)
149
- method: Optional[str] = attr.ib(default=None)
150
- full_path: Optional[str] = attr.ib(default=None)
147
+ message: Optional[str] = None
148
+ path: Optional[str] = None
149
+ method: Optional[str] = None
150
+ full_path: Optional[str] = None
151
151
 
152
152
  def as_failing_test_function(self) -> Callable:
153
153
  """Create a test function that will fail.
@@ -233,10 +233,10 @@ class InvalidRegularExpression(Exception):
233
233
  __module__ = "builtins"
234
234
 
235
235
 
236
- @attr.s # pragma: no mutate
236
+ @dataclass
237
237
  class HTTPError(Exception):
238
- response: "GenericResponse" = attr.ib() # pragma: no mutate
239
- url: str = attr.ib() # pragma: no mutate
238
+ response: "GenericResponse"
239
+ url: str
240
240
 
241
241
  @classmethod
242
242
  def raise_for_status(cls, response: requests.Response) -> None:
@@ -222,7 +222,7 @@ def skip_unnecessary_hypothesis_output() -> Generator:
222
222
  yield
223
223
 
224
224
 
225
- @hookimpl(hookwrapper=True) # pragma: no mutate
225
+ @hookimpl(hookwrapper=True)
226
226
  def pytest_pyfunc_call(pyfuncitem): # type:ignore
227
227
  """It is possible to have a Hypothesis exception in runtime.
228
228
 
schemathesis/failures.py CHANGED
@@ -1,9 +1,8 @@
1
+ from dataclasses import dataclass
1
2
  from typing import Any, Dict, List, Optional, Tuple, Union
2
3
 
3
- import attr
4
4
 
5
-
6
- @attr.s(slots=True, repr=False) # pragma: no mutate
5
+ # @dataclass(repr=False)
7
6
  class FailureContext:
8
7
  """Additional data specific to certain failure kind."""
9
8
 
@@ -18,36 +17,36 @@ class FailureContext:
18
17
  return (check_message or self.message,)
19
18
 
20
19
 
21
- @attr.s(slots=True, repr=False)
20
+ @dataclass(repr=False)
22
21
  class ValidationErrorContext(FailureContext):
23
22
  """Additional information about JSON Schema validation errors."""
24
23
 
25
- validation_message: str = attr.ib()
26
- schema_path: List[Union[str, int]] = attr.ib()
27
- schema: Union[Dict[str, Any], bool] = attr.ib()
28
- instance_path: List[Union[str, int]] = attr.ib()
29
- instance: Union[None, bool, float, str, list, Dict[str, Any]] = attr.ib()
30
- title: str = attr.ib(default="Non-conforming response payload")
31
- message: str = attr.ib(default="Response does not conform to the defined schema")
32
- type: str = attr.ib(default="json_schema")
24
+ validation_message: str
25
+ schema_path: List[Union[str, int]]
26
+ schema: Union[Dict[str, Any], bool]
27
+ instance_path: List[Union[str, int]]
28
+ instance: Union[None, bool, float, str, list, Dict[str, Any]]
29
+ title: str = "Non-conforming response payload"
30
+ message: str = "Response does not conform to the defined schema"
31
+ type: str = "json_schema"
33
32
 
34
33
  def unique_by_key(self, check_message: Optional[str]) -> Tuple[str, ...]:
35
34
  # Deduplicate by JSON Schema path. All errors that happened on this sub-schema will be deduplicated
36
35
  return ("/".join(map(str, self.schema_path)),)
37
36
 
38
37
 
39
- @attr.s(slots=True, repr=False)
38
+ @dataclass(repr=False)
40
39
  class JSONDecodeErrorContext(FailureContext):
41
40
  """Failed to decode JSON."""
42
41
 
43
- validation_message: str = attr.ib()
44
- document: str = attr.ib()
45
- position: int = attr.ib()
46
- lineno: int = attr.ib()
47
- colno: int = attr.ib()
48
- title: str = attr.ib(default="JSON deserialization error")
49
- message: str = attr.ib(default="Response is not a valid JSON")
50
- type: str = attr.ib(default="json_decode")
42
+ validation_message: str
43
+ document: str
44
+ position: int
45
+ lineno: int
46
+ colno: int
47
+ title: str = "JSON deserialization error"
48
+ message: str = "Response is not a valid JSON"
49
+ type: str = "json_decode"
51
50
 
52
51
  def unique_by_key(self, check_message: Optional[str]) -> Tuple[str, ...]:
53
52
  # Treat different JSON decoding failures as the same issue
@@ -56,90 +55,90 @@ class JSONDecodeErrorContext(FailureContext):
56
55
  return (self.title,)
57
56
 
58
57
 
59
- @attr.s(slots=True, repr=False)
58
+ @dataclass(repr=False)
60
59
  class ServerError(FailureContext):
61
- status_code: int = attr.ib()
62
- title: str = attr.ib(default="Internal server error")
63
- message: str = attr.ib(default="Server got itself in trouble")
64
- type: str = attr.ib(default="server_error")
60
+ status_code: int
61
+ title: str = "Internal server error"
62
+ message: str = "Server got itself in trouble"
63
+ type: str = "server_error"
65
64
 
66
65
 
67
- @attr.s(slots=True, repr=False)
66
+ @dataclass(repr=False)
68
67
  class MissingContentType(FailureContext):
69
68
  """Content type header is missing."""
70
69
 
71
- media_types: List[str] = attr.ib()
72
- title: str = attr.ib(default="Missing Content-Type header")
73
- message: str = attr.ib(default="Response is missing the `Content-Type` header")
74
- type: str = attr.ib(default="missing_content_type")
70
+ media_types: List[str]
71
+ title: str = "Missing Content-Type header"
72
+ message: str = "Response is missing the `Content-Type` header"
73
+ type: str = "missing_content_type"
75
74
 
76
75
 
77
- @attr.s(slots=True, repr=False)
76
+ @dataclass(repr=False)
78
77
  class UndefinedContentType(FailureContext):
79
78
  """Response has Content-Type that is not defined in the schema."""
80
79
 
81
- content_type: str = attr.ib()
82
- defined_content_types: List[str] = attr.ib()
83
- title: str = attr.ib(default="Undefined Content-Type")
84
- message: str = attr.ib(default="Response has `Content-Type` that is not declared in the schema")
85
- type: str = attr.ib(default="undefined_content_type")
80
+ content_type: str
81
+ defined_content_types: List[str]
82
+ title: str = "Undefined Content-Type"
83
+ message: str = "Response has `Content-Type` that is not declared in the schema"
84
+ type: str = "undefined_content_type"
86
85
 
87
86
 
88
- @attr.s(slots=True, repr=False)
87
+ @dataclass(repr=False)
89
88
  class UndefinedStatusCode(FailureContext):
90
89
  """Response has a status code that is not defined in the schema."""
91
90
 
92
91
  # Response's status code
93
- status_code: int = attr.ib()
92
+ status_code: int
94
93
  # Status codes as defined in schema
95
- defined_status_codes: List[str] = attr.ib()
94
+ defined_status_codes: List[str]
96
95
  # Defined status code with expanded wildcards
97
- allowed_status_codes: List[int] = attr.ib()
98
- title: str = attr.ib(default="Undefined status code")
99
- message: str = attr.ib(default="Response has a status code that is not declared in the schema")
100
- type: str = attr.ib(default="undefined_status_code")
96
+ allowed_status_codes: List[int]
97
+ title: str = "Undefined status code"
98
+ message: str = "Response has a status code that is not declared in the schema"
99
+ type: str = "undefined_status_code"
101
100
 
102
101
 
103
- @attr.s(slots=True, repr=False)
102
+ @dataclass(repr=False)
104
103
  class MissingHeaders(FailureContext):
105
104
  """Some required headers are missing."""
106
105
 
107
- missing_headers: List[str] = attr.ib()
108
- title: str = attr.ib(default="Missing required headers")
109
- message: str = attr.ib(default="Response is missing headers required by the schema")
110
- type: str = attr.ib(default="missing_headers")
106
+ missing_headers: List[str]
107
+ title: str = "Missing required headers"
108
+ message: str = "Response is missing headers required by the schema"
109
+ type: str = "missing_headers"
111
110
 
112
111
 
113
- @attr.s(slots=True, repr=False)
112
+ @dataclass(repr=False)
114
113
  class MalformedMediaType(FailureContext):
115
114
  """Media type name is malformed.
116
115
 
117
116
  Example: `application-json` instead of `application/json`
118
117
  """
119
118
 
120
- actual: str = attr.ib()
121
- defined: str = attr.ib()
122
- title: str = attr.ib(default="Malformed media type name")
123
- message: str = attr.ib(default="Media type name is not valid")
124
- type: str = attr.ib(default="malformed_media_type")
119
+ actual: str
120
+ defined: str
121
+ title: str = "Malformed media type name"
122
+ message: str = "Media type name is not valid"
123
+ type: str = "malformed_media_type"
125
124
 
126
125
 
127
- @attr.s(slots=True, repr=False)
126
+ @dataclass(repr=False)
128
127
  class ResponseTimeExceeded(FailureContext):
129
128
  """Response took longer than expected."""
130
129
 
131
- elapsed: float = attr.ib()
132
- deadline: int = attr.ib()
133
- title: str = attr.ib(default="Response time exceeded")
134
- message: str = attr.ib(default="Response time exceeds the deadline")
135
- type: str = attr.ib(default="response_time_exceeded")
130
+ elapsed: float
131
+ deadline: int
132
+ title: str = "Response time exceeded"
133
+ message: str = "Response time exceeds the deadline"
134
+ type: str = "response_time_exceeded"
136
135
 
137
136
 
138
- @attr.s(slots=True, repr=False)
137
+ @dataclass(repr=False)
139
138
  class RequestTimeout(FailureContext):
140
139
  """Request took longer than timeout."""
141
140
 
142
- timeout: int = attr.ib()
143
- title: str = attr.ib(default="Request timeout")
144
- message: str = attr.ib(default="The request timed out")
145
- type: str = attr.ib(default="request_timeout")
141
+ timeout: int
142
+ title: str = "Request timeout"
143
+ message: str = "The request timed out"
144
+ type: str = "request_timeout"
@@ -0,0 +1,269 @@
1
+ """Filtering system that allows users to filter API operations based on certain criteria."""
2
+ import re
3
+ from dataclasses import dataclass, field
4
+ from functools import partial
5
+ from types import SimpleNamespace
6
+ from typing import TYPE_CHECKING, Callable, List, Optional, Set, Tuple, Union
7
+
8
+ from typing_extensions import Protocol
9
+
10
+ from .exceptions import UsageError
11
+
12
+ if TYPE_CHECKING:
13
+ from .models import APIOperation
14
+
15
+
16
class HasAPIOperation(Protocol):
    """Any object exposing the API operation being filtered as its ``operation`` attribute."""

    operation: "APIOperation"


# Messages used for `UsageError` when a filter is defined incorrectly
ERROR_EXPECTED_AND_REGEX = "Passing expected value and regex simultaneously is not allowed"
ERROR_EMPTY_FILTER = "Filter can not be empty"
ERROR_FILTER_EXISTS = "Filter already exists"

# A predicate over an object that carries an API operation
MatcherFunc = Callable[[HasAPIOperation], bool]
# Either a single expected value or a list of acceptable values
FilterValue = Union[str, List[str]]
# A regular expression given as source text or as an already-compiled pattern
RegexValue = Union[str, re.Pattern]
26
+
27
+
28
@dataclass(repr=False, frozen=True)
class Matcher:
    """Encapsulates matching logic by various criteria.

    Matchers are compared and hashed only by ``_hash``, a value pre-computed by the
    alternate constructors below. Two matchers built from the same criteria are
    therefore equal, which lets filters built from them be detected as duplicates.
    """

    func: Callable[..., bool] = field(hash=False, compare=False)
    # A short description of a matcher. Primarily exists for debugging purposes
    label: str = field(hash=False, compare=False)
    # Compare & hash matchers by a pre-computed hash value
    _hash: int

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self.label}>"

    @classmethod
    def for_function(cls, func: MatcherFunc) -> "Matcher":
        """Matcher that uses the given function for matching operations.

        Any callable is accepted. Callables without a ``__name__`` (e.g. ``functools.partial``
        objects or instances defining ``__call__``) are labelled by their ``repr`` instead
        of failing with ``AttributeError``.
        """
        label = getattr(func, "__name__", None) or repr(func)
        return cls(func, label=label, _hash=hash(func))

    @classmethod
    def for_value(cls, attribute: str, expected: FilterValue) -> "Matcher":
        """Matcher that checks whether the specified attribute has the expected value.

        When ``expected`` is a list, the attribute matches if it equals any of its items.
        """
        if isinstance(expected, list):
            # Copy the list: the hash below is computed once from its current contents,
            # so later mutation of the caller's list must not change what this matcher accepts
            func = partial(by_value_list, attribute=attribute, expected=list(expected))
        else:
            func = partial(by_value, attribute=attribute, expected=expected)
        label = f"{attribute}={repr(expected)}"
        return cls(func, label=label, _hash=hash(label))

    @classmethod
    def for_regex(cls, attribute: str, regex: RegexValue) -> "Matcher":
        """Matcher that checks whether the specified attribute matches the provided regex.

        String patterns are compiled here; compiled patterns are used as-is.
        """
        if isinstance(regex, str):
            regex = re.compile(regex)
        func = partial(by_regex, attribute=attribute, regex=regex)
        label = f"{attribute}_regex={repr(regex)}"
        return cls(func, label=label, _hash=hash(label))

    def match(self, ctx: HasAPIOperation) -> bool:
        """Whether matcher matches the given operation."""
        return self.func(ctx)
68
+
69
+
70
def get_operation_attribute(operation: "APIOperation", attribute: str) -> str:
    """Read ``attribute`` from ``operation``; the ``method`` attribute is returned upper-cased."""
    raw = getattr(operation, attribute)
    # Only `method` is normalized; every other attribute is returned unchanged
    return raw.upper() if attribute == "method" else raw
76
+
77
+
78
def by_value(ctx: HasAPIOperation, attribute: str, expected: str) -> bool:
    """Check that the operation's ``attribute`` equals ``expected`` exactly."""
    actual = get_operation_attribute(ctx.operation, attribute)
    return actual == expected
80
+
81
+
82
def by_value_list(ctx: HasAPIOperation, attribute: str, expected: List[str]) -> bool:
    """Check that the operation's ``attribute`` is one of the ``expected`` values."""
    actual = get_operation_attribute(ctx.operation, attribute)
    return actual in expected
84
+
85
+
86
def by_regex(ctx: HasAPIOperation, attribute: str, regex: re.Pattern) -> bool:
    """Check that ``regex`` matches at the beginning of the operation's ``attribute`` (``re.match`` semantics)."""
    found = regex.match(get_operation_attribute(ctx.operation, attribute))
    return found is not None
89
+
90
+
91
@dataclass(repr=False, frozen=True)
class Filter:
    """A conjunction of matchers: an API operation passes only when every matcher accepts it."""

    matchers: Tuple[Matcher, ...]

    def __repr__(self) -> str:
        joined = " && ".join([m.label for m in self.matchers])
        return f"<{type(self).__name__}: [{joined}]>"

    def match(self, ctx: HasAPIOperation) -> bool:
        """Report whether the operation satisfies this filter.

        Stops at the first matcher that rejects the operation; an empty set of
        matchers accepts everything.
        """
        for matcher in self.matchers:
            if not matcher.match(ctx):
                return False
        return True
107
+
108
+
109
@dataclass
class FilterSet:
    """Combines multiple filters to apply inclusion and exclusion rules on API operations."""

    _includes: Set[Filter] = field(default_factory=set)
    _excludes: Set[Filter] = field(default_factory=set)

    def apply_to(self, operations: List["APIOperation"]) -> List["APIOperation"]:
        """Get a filtered list of the given operations that match the filters."""
        # `match` expects an object with an `operation` attribute, so wrap each operation
        return [operation for operation in operations if self.match(SimpleNamespace(operation=operation))]

    def match(self, ctx: HasAPIOperation) -> bool:
        """Determines whether the given operation should be included based on the defined filters.

        Returns True if the operation:
          - matches at least one INCLUDE filter OR no INCLUDE filters defined;
          - does not match any EXCLUDE filter;
        False otherwise.
        """
        # Exclude early if the operation is excluded by at least one EXCLUDE filter
        for filter_ in self._excludes:
            if filter_.match(ctx):
                return False
        if not self._includes:
            # No includes - nothing to filter out, include the operation
            return True
        # Otherwise check if the operation is included by at least one INCLUDE filter
        return any(filter_.match(ctx) for filter_ in self._includes)

    def is_empty(self) -> bool:
        """Whether the filter set does not contain any filters."""
        return not self._includes and not self._excludes

    def include(
        self,
        func: Optional[MatcherFunc] = None,
        *,
        name: Optional[FilterValue] = None,
        name_regex: Optional[RegexValue] = None,
        method: Optional[FilterValue] = None,
        method_regex: Optional[RegexValue] = None,
        path: Optional[FilterValue] = None,
        path_regex: Optional[RegexValue] = None,
    ) -> None:
        """Add a new INCLUDE filter.

        All criteria given in one call are combined with AND into a single filter.
        ``method`` values are case-insensitive (e.g. ``"get"`` and ``"GET"`` are equivalent);
        ``method_regex`` is matched against the upper-cased method.
        """
        self._add_filter(
            True,
            func=func,
            name=name,
            name_regex=name_regex,
            method=method,
            method_regex=method_regex,
            path=path,
            path_regex=path_regex,
        )

    def exclude(
        self,
        func: Optional[MatcherFunc] = None,
        *,
        name: Optional[FilterValue] = None,
        name_regex: Optional[RegexValue] = None,
        method: Optional[FilterValue] = None,
        method_regex: Optional[RegexValue] = None,
        path: Optional[FilterValue] = None,
        path_regex: Optional[RegexValue] = None,
    ) -> None:
        """Add a new EXCLUDE filter.

        All criteria given in one call are combined with AND into a single filter.
        ``method`` values are case-insensitive (e.g. ``"get"`` and ``"GET"`` are equivalent);
        ``method_regex`` is matched against the upper-cased method.
        """
        self._add_filter(
            False,
            func=func,
            name=name,
            name_regex=name_regex,
            method=method,
            method_regex=method_regex,
            path=path,
            path_regex=path_regex,
        )

    def _add_filter(
        self,
        include: bool,
        *,
        func: Optional[MatcherFunc] = None,
        name: Optional[FilterValue] = None,
        name_regex: Optional[RegexValue] = None,
        method: Optional[FilterValue] = None,
        method_regex: Optional[RegexValue] = None,
        path: Optional[FilterValue] = None,
        path_regex: Optional[RegexValue] = None,
    ) -> None:
        """Build a single filter from the given criteria and register it.

        :param include: ``True`` to register an INCLUDE filter, ``False`` for an EXCLUDE filter.
        :raises UsageError: If a value and a regex are both given for the same attribute,
            if no criteria are given at all, or if an equal filter is already registered
            (either as an INCLUDE or as an EXCLUDE filter).
        """
        matchers = []
        if func is not None:
            matchers.append(Matcher.for_function(func))
        for attribute, expected, regex in (
            ("verbose_name", name, name_regex),
            ("method", method, method_regex),
            ("path", path, path_regex),
        ):
            if expected is not None and regex is not None:
                # To match anything the regex should match the expected value, hence passing them together is useless
                raise UsageError(ERROR_EXPECTED_AND_REGEX)
            if expected is not None:
                if attribute == "method":
                    # The operation's method is compared upper-cased (see `get_operation_attribute`),
                    # so normalize expected values too; otherwise `method="get"` would never match
                    expected = self._normalize_method(expected)
                matchers.append(Matcher.for_value(attribute, expected))
            if regex is not None:
                matchers.append(Matcher.for_regex(attribute, regex))

        if not matchers:
            raise UsageError(ERROR_EMPTY_FILTER)
        filter_ = Filter(matchers=tuple(matchers))
        if filter_ in self._includes or filter_ in self._excludes:
            raise UsageError(ERROR_FILTER_EXISTS)
        if include:
            self._includes.add(filter_)
        else:
            self._excludes.add(filter_)

    @staticmethod
    def _normalize_method(value: FilterValue) -> FilterValue:
        """Upper-case an HTTP method name, or each name in a list of them."""
        if isinstance(value, list):
            return [item.upper() for item in value]
        return value.upper()
225
+
226
+
227
def attach_filter_chain(
    target: Callable,
    attribute: str,
    filter_func: Callable[..., None],
) -> None:
    """Attach a filtering function to an object, which allows chaining of filter criteria.

    For example:

    >>> def auth(): ...
    >>> filter_set = FilterSet()
    >>> attach_filter_chain(auth, "apply_to", filter_set.include)
    >>> auth.apply_to(method="GET", path="/users/")  # doctest: +ELLIPSIS
    <function auth at 0x...>

    This will add a new `apply_to` method to `auth` that matches only the `GET /users/` operation.

    The attached method forwards all criteria to ``filter_func`` and returns ``target``,
    so further calls can be chained on the result.

    :param target: Object that receives the new method.
    :param attribute: Name under which the method is attached; also used as its
        ``__name__`` and ``__qualname__``.
    :param filter_func: Callable that registers the criteria, e.g. ``FilterSet.include``.
    """

    def proxy(
        func: Optional[MatcherFunc] = None,
        *,
        name: Optional[FilterValue] = None,
        name_regex: Optional[RegexValue] = None,
        method: Optional[FilterValue] = None,
        method_regex: Optional[RegexValue] = None,
        path: Optional[FilterValue] = None,
        path_regex: Optional[RegexValue] = None,
    ) -> Callable:
        # pytest skips frames that define `__tracebackhide__`, so errors point at the caller
        __tracebackhide__ = True
        filter_func(
            func=func,
            name=name,
            name_regex=name_regex,
            method=method,
            method_regex=method_regex,
            path=path,
            path_regex=path_regex,
        )
        return target

    proxy.__qualname__ = attribute
    proxy.__name__ = attribute

    setattr(target, attribute, proxy)