schemathesis 3.19.7__py3-none-any.whl → 3.20.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- schemathesis/_compat.py +3 -2
- schemathesis/_hypothesis.py +21 -6
- schemathesis/_xml.py +177 -0
- schemathesis/auths.py +48 -10
- schemathesis/cli/__init__.py +77 -19
- schemathesis/cli/callbacks.py +42 -18
- schemathesis/cli/context.py +2 -1
- schemathesis/cli/output/default.py +102 -34
- schemathesis/cli/sanitization.py +15 -0
- schemathesis/code_samples.py +141 -0
- schemathesis/constants.py +1 -24
- schemathesis/exceptions.py +127 -26
- schemathesis/experimental/__init__.py +85 -0
- schemathesis/extra/pytest_plugin.py +10 -4
- schemathesis/fixups/__init__.py +8 -2
- schemathesis/fixups/fast_api.py +11 -1
- schemathesis/fixups/utf8_bom.py +7 -1
- schemathesis/hooks.py +63 -0
- schemathesis/lazy.py +10 -4
- schemathesis/loaders.py +57 -0
- schemathesis/models.py +120 -96
- schemathesis/parameters.py +3 -0
- schemathesis/runner/__init__.py +3 -0
- schemathesis/runner/events.py +55 -20
- schemathesis/runner/impl/core.py +54 -54
- schemathesis/runner/serialization.py +75 -34
- schemathesis/sanitization.py +248 -0
- schemathesis/schemas.py +21 -6
- schemathesis/serializers.py +32 -3
- schemathesis/service/serialization.py +5 -1
- schemathesis/specs/graphql/loaders.py +44 -13
- schemathesis/specs/graphql/schemas.py +56 -25
- schemathesis/specs/openapi/_hypothesis.py +11 -23
- schemathesis/specs/openapi/definitions.py +572 -0
- schemathesis/specs/openapi/loaders.py +100 -49
- schemathesis/specs/openapi/parameters.py +2 -2
- schemathesis/specs/openapi/schemas.py +87 -13
- schemathesis/specs/openapi/security.py +1 -0
- schemathesis/stateful.py +2 -2
- schemathesis/utils.py +30 -9
- schemathesis-3.20.1.dist-info/METADATA +342 -0
- {schemathesis-3.19.7.dist-info → schemathesis-3.20.1.dist-info}/RECORD +45 -39
- schemathesis-3.19.7.dist-info/METADATA +0 -291
- {schemathesis-3.19.7.dist-info → schemathesis-3.20.1.dist-info}/WHEEL +0 -0
- {schemathesis-3.19.7.dist-info → schemathesis-3.20.1.dist-info}/entry_points.txt +0 -0
- {schemathesis-3.19.7.dist-info → schemathesis-3.20.1.dist-info}/licenses/LICENSE +0 -0
schemathesis/runner/impl/core.py
CHANGED
|
@@ -14,6 +14,7 @@ import requests
|
|
|
14
14
|
from _pytest.logging import LogCaptureHandler, catching_logs
|
|
15
15
|
from hypothesis.errors import HypothesisException, InvalidArgument
|
|
16
16
|
from hypothesis_jsonschema._canonicalise import HypothesisRefResolutionError
|
|
17
|
+
from jsonschema.exceptions import ValidationError
|
|
17
18
|
from requests.auth import HTTPDigestAuth, _basic_auth_str
|
|
18
19
|
|
|
19
20
|
from ... import failures, hooks
|
|
@@ -29,8 +30,8 @@ from ...exceptions import (
|
|
|
29
30
|
CheckFailed,
|
|
30
31
|
DeadlineExceeded,
|
|
31
32
|
InvalidRegularExpression,
|
|
32
|
-
InvalidSchema,
|
|
33
33
|
NonCheckError,
|
|
34
|
+
OperationSchemaError,
|
|
34
35
|
SkipTest,
|
|
35
36
|
get_grouped_exception,
|
|
36
37
|
)
|
|
@@ -184,7 +185,7 @@ class BaseRunner:
|
|
|
184
185
|
headers=headers,
|
|
185
186
|
**kwargs,
|
|
186
187
|
)
|
|
187
|
-
except
|
|
188
|
+
except OperationSchemaError as exc:
|
|
188
189
|
yield from handle_schema_error(
|
|
189
190
|
exc,
|
|
190
191
|
results,
|
|
@@ -229,7 +230,7 @@ class EventStream:
|
|
|
229
230
|
|
|
230
231
|
|
|
231
232
|
def handle_schema_error(
|
|
232
|
-
error:
|
|
233
|
+
error: OperationSchemaError,
|
|
233
234
|
results: TestResultSet,
|
|
234
235
|
data_generation_methods: Iterable[DataGenerationMethod],
|
|
235
236
|
recursion_level: int,
|
|
@@ -299,7 +300,6 @@ def run_test(
|
|
|
299
300
|
method=operation.method.upper(),
|
|
300
301
|
path=operation.full_path,
|
|
301
302
|
verbose_name=operation.verbose_name,
|
|
302
|
-
overridden_headers=headers,
|
|
303
303
|
data_generation_method=data_generation_methods,
|
|
304
304
|
)
|
|
305
305
|
# To simplify connecting `before` and `after` events in external systems
|
|
@@ -364,8 +364,8 @@ def run_test(
|
|
|
364
364
|
except SkipTest:
|
|
365
365
|
status = Status.skip
|
|
366
366
|
result.mark_skipped()
|
|
367
|
-
except AssertionError
|
|
368
|
-
error = reraise(
|
|
367
|
+
except AssertionError: # comes from `hypothesis-jsonschema`
|
|
368
|
+
error = reraise(operation)
|
|
369
369
|
status = Status.error
|
|
370
370
|
result.add_error(error)
|
|
371
371
|
except HypothesisRefResolutionError:
|
|
@@ -440,7 +440,7 @@ def has_all_not_found(results: TestResultSet) -> bool:
|
|
|
440
440
|
else:
|
|
441
441
|
# There are non-404 responses, no reason to check any other response
|
|
442
442
|
return False
|
|
443
|
-
# Only happens if all responses are 404,
|
|
443
|
+
# Only happens if all responses are 404, or there are no responses at all.
|
|
444
444
|
# In the first case, it returns True, for the latter - False
|
|
445
445
|
return has_not_found
|
|
446
446
|
|
|
@@ -466,16 +466,14 @@ def get_invalid_regular_expression_message(warnings: List[WarningMessage]) -> Op
|
|
|
466
466
|
return None
|
|
467
467
|
|
|
468
468
|
|
|
469
|
-
def reraise(
|
|
470
|
-
traceback = format_exception(error, True)
|
|
471
|
-
if "assert type_ in TYPE_STRINGS" in traceback:
|
|
472
|
-
message = "Invalid type name"
|
|
473
|
-
else:
|
|
474
|
-
message = "Unknown schema error"
|
|
469
|
+
def reraise(operation: APIOperation) -> OperationSchemaError:
|
|
475
470
|
try:
|
|
476
|
-
|
|
477
|
-
except
|
|
478
|
-
return
|
|
471
|
+
operation.schema.validate()
|
|
472
|
+
except ValidationError as exc:
|
|
473
|
+
return OperationSchemaError.from_jsonschema_error(
|
|
474
|
+
exc, path=operation.path, method=operation.method, full_path=operation.schema.get_full_path(operation.path)
|
|
475
|
+
)
|
|
476
|
+
return OperationSchemaError("Unknown schema error")
|
|
479
477
|
|
|
480
478
|
|
|
481
479
|
def deduplicate_errors(errors: List[Exception]) -> Generator[Exception, None, None]:
|
|
@@ -490,6 +488,7 @@ def deduplicate_errors(errors: List[Exception]) -> Generator[Exception, None, No
|
|
|
490
488
|
|
|
491
489
|
|
|
492
490
|
def run_checks(
|
|
491
|
+
*,
|
|
493
492
|
case: Case,
|
|
494
493
|
checks: Iterable[CheckFunction],
|
|
495
494
|
check_results: List[Check],
|
|
@@ -623,24 +622,7 @@ def network_test(
|
|
|
623
622
|
headers["User-Agent"] = USER_AGENT
|
|
624
623
|
timeout = prepare_timeout(request_timeout)
|
|
625
624
|
if not dry_run:
|
|
626
|
-
|
|
627
|
-
case,
|
|
628
|
-
checks,
|
|
629
|
-
targets,
|
|
630
|
-
result,
|
|
631
|
-
session,
|
|
632
|
-
timeout,
|
|
633
|
-
store_interactions,
|
|
634
|
-
headers,
|
|
635
|
-
feedback,
|
|
636
|
-
request_tls_verify,
|
|
637
|
-
request_cert,
|
|
638
|
-
max_response_time,
|
|
639
|
-
)
|
|
640
|
-
add_cases(
|
|
641
|
-
case,
|
|
642
|
-
response,
|
|
643
|
-
_network_test,
|
|
625
|
+
args = (
|
|
644
626
|
checks,
|
|
645
627
|
targets,
|
|
646
628
|
result,
|
|
@@ -653,6 +635,8 @@ def network_test(
|
|
|
653
635
|
request_cert,
|
|
654
636
|
max_response_time,
|
|
655
637
|
)
|
|
638
|
+
response = _network_test(case, *args)
|
|
639
|
+
add_cases(case, response, _network_test, *args)
|
|
656
640
|
|
|
657
641
|
|
|
658
642
|
def _network_test(
|
|
@@ -695,14 +679,22 @@ def _network_test(
|
|
|
695
679
|
run_targets(targets, context)
|
|
696
680
|
status = Status.success
|
|
697
681
|
try:
|
|
698
|
-
run_checks(
|
|
682
|
+
run_checks(
|
|
683
|
+
case=case,
|
|
684
|
+
checks=checks,
|
|
685
|
+
check_results=check_results,
|
|
686
|
+
result=result,
|
|
687
|
+
response=response,
|
|
688
|
+
elapsed_time=context.response_time * 1000,
|
|
689
|
+
max_response_time=max_response_time,
|
|
690
|
+
)
|
|
699
691
|
except CheckFailed:
|
|
700
692
|
status = Status.failure
|
|
701
693
|
raise
|
|
702
694
|
finally:
|
|
695
|
+
feedback.add_test_case(case, response)
|
|
703
696
|
if store_interactions:
|
|
704
697
|
result.store_requests_response(case, response, status, check_results)
|
|
705
|
-
feedback.add_test_case(case, response)
|
|
706
698
|
return response
|
|
707
699
|
|
|
708
700
|
|
|
@@ -742,13 +734,7 @@ def wsgi_test(
|
|
|
742
734
|
result.mark_executed()
|
|
743
735
|
headers = _prepare_wsgi_headers(headers, auth, auth_type)
|
|
744
736
|
if not dry_run:
|
|
745
|
-
|
|
746
|
-
case, checks, targets, result, headers, store_interactions, feedback, max_response_time
|
|
747
|
-
)
|
|
748
|
-
add_cases(
|
|
749
|
-
case,
|
|
750
|
-
response,
|
|
751
|
-
_wsgi_test,
|
|
737
|
+
args = (
|
|
752
738
|
checks,
|
|
753
739
|
targets,
|
|
754
740
|
result,
|
|
@@ -757,6 +743,8 @@ def wsgi_test(
|
|
|
757
743
|
feedback,
|
|
758
744
|
max_response_time,
|
|
759
745
|
)
|
|
746
|
+
response = _wsgi_test(case, *args)
|
|
747
|
+
add_cases(case, response, _wsgi_test, *args)
|
|
760
748
|
|
|
761
749
|
|
|
762
750
|
def _wsgi_test(
|
|
@@ -782,14 +770,22 @@ def _wsgi_test(
|
|
|
782
770
|
status = Status.success
|
|
783
771
|
check_results: List[Check] = []
|
|
784
772
|
try:
|
|
785
|
-
run_checks(
|
|
773
|
+
run_checks(
|
|
774
|
+
case=case,
|
|
775
|
+
checks=checks,
|
|
776
|
+
check_results=check_results,
|
|
777
|
+
result=result,
|
|
778
|
+
response=response,
|
|
779
|
+
elapsed_time=context.response_time * 1000,
|
|
780
|
+
max_response_time=max_response_time,
|
|
781
|
+
)
|
|
786
782
|
except CheckFailed:
|
|
787
783
|
status = Status.failure
|
|
788
784
|
raise
|
|
789
785
|
finally:
|
|
786
|
+
feedback.add_test_case(case, response)
|
|
790
787
|
if store_interactions:
|
|
791
788
|
result.store_wsgi_response(case, response, headers, elapsed, status, check_results)
|
|
792
|
-
feedback.add_test_case(case, response)
|
|
793
789
|
return response
|
|
794
790
|
|
|
795
791
|
|
|
@@ -833,13 +829,7 @@ def asgi_test(
|
|
|
833
829
|
headers = headers or {}
|
|
834
830
|
|
|
835
831
|
if not dry_run:
|
|
836
|
-
|
|
837
|
-
case, checks, targets, result, store_interactions, headers, feedback, max_response_time
|
|
838
|
-
)
|
|
839
|
-
add_cases(
|
|
840
|
-
case,
|
|
841
|
-
response,
|
|
842
|
-
_asgi_test,
|
|
832
|
+
args = (
|
|
843
833
|
checks,
|
|
844
834
|
targets,
|
|
845
835
|
result,
|
|
@@ -848,6 +838,8 @@ def asgi_test(
|
|
|
848
838
|
feedback,
|
|
849
839
|
max_response_time,
|
|
850
840
|
)
|
|
841
|
+
response = _asgi_test(case, *args)
|
|
842
|
+
add_cases(case, response, _asgi_test, *args)
|
|
851
843
|
|
|
852
844
|
|
|
853
845
|
def _asgi_test(
|
|
@@ -869,12 +861,20 @@ def _asgi_test(
|
|
|
869
861
|
status = Status.success
|
|
870
862
|
check_results: List[Check] = []
|
|
871
863
|
try:
|
|
872
|
-
run_checks(
|
|
864
|
+
run_checks(
|
|
865
|
+
case=case,
|
|
866
|
+
checks=checks,
|
|
867
|
+
check_results=check_results,
|
|
868
|
+
result=result,
|
|
869
|
+
response=response,
|
|
870
|
+
elapsed_time=context.response_time * 1000,
|
|
871
|
+
max_response_time=max_response_time,
|
|
872
|
+
)
|
|
873
873
|
except CheckFailed:
|
|
874
874
|
status = Status.failure
|
|
875
875
|
raise
|
|
876
876
|
finally:
|
|
877
|
+
feedback.add_test_case(case, response)
|
|
877
878
|
if store_interactions:
|
|
878
879
|
result.store_requests_response(case, response, status, check_results)
|
|
879
|
-
feedback.add_test_case(case, response)
|
|
880
880
|
return response
|
|
@@ -4,44 +4,71 @@ They all consist of primitive types and don't have references to schemas, app, e
|
|
|
4
4
|
"""
|
|
5
5
|
import logging
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
|
-
from typing import Any, Dict, List, Optional, Set, Tuple
|
|
7
|
+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
|
8
8
|
|
|
9
9
|
import requests
|
|
10
|
+
from requests.structures import CaseInsensitiveDict
|
|
10
11
|
|
|
12
|
+
from ..code_samples import EXCLUDED_HEADERS
|
|
11
13
|
from ..exceptions import FailureContext, InternalError, make_unique_by_key
|
|
12
|
-
from ..models import Case, Check, Interaction, Request, Response, Status, TestResult
|
|
13
|
-
from ..utils import
|
|
14
|
+
from ..models import Case, Check, Interaction, Request, Response, Status, TestResult, serialize_payload
|
|
15
|
+
from ..utils import WSGIResponse, format_exception
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
@dataclass
|
|
17
19
|
class SerializedCase:
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
path_template: str
|
|
20
|
+
# Case data
|
|
21
|
+
id: str
|
|
21
22
|
path_parameters: Optional[Dict[str, Any]]
|
|
22
|
-
|
|
23
|
+
headers: Optional[Dict[str, Any]]
|
|
23
24
|
cookies: Optional[Dict[str, Any]]
|
|
24
|
-
|
|
25
|
-
|
|
25
|
+
query: Optional[Dict[str, Any]]
|
|
26
|
+
body: Optional[str]
|
|
26
27
|
media_type: Optional[str]
|
|
28
|
+
data_generation_method: Optional[str]
|
|
29
|
+
# Operation data
|
|
30
|
+
method: str
|
|
31
|
+
url: str
|
|
32
|
+
path_template: str
|
|
33
|
+
verbose_name: str
|
|
34
|
+
# Transport info
|
|
35
|
+
verify: bool
|
|
36
|
+
# Headers coming from sources outside data generation
|
|
37
|
+
extra_headers: Dict[str, Any]
|
|
27
38
|
|
|
28
39
|
@classmethod
|
|
29
|
-
def from_case(cls, case: Case, headers: Optional[Dict[str, Any]]) -> "SerializedCase":
|
|
40
|
+
def from_case(cls, case: Case, headers: Optional[Dict[str, Any]], verify: bool) -> "SerializedCase":
|
|
41
|
+
# `headers` include not only explicitly provided headers but also ones added by hooks, custom auth, etc.
|
|
42
|
+
request_data = case.prepare_code_sample_data(headers)
|
|
43
|
+
serialized_body = _serialize_body(request_data.body)
|
|
30
44
|
return cls(
|
|
31
|
-
|
|
32
|
-
curl_code=case.as_curl_command(headers),
|
|
33
|
-
path_template=case.path,
|
|
45
|
+
id=case.id,
|
|
34
46
|
path_parameters=case.path_parameters,
|
|
35
|
-
|
|
47
|
+
headers=dict(case.headers) if case.headers is not None else None,
|
|
36
48
|
cookies=case.cookies,
|
|
37
|
-
|
|
49
|
+
query=case.query,
|
|
50
|
+
body=serialized_body,
|
|
51
|
+
media_type=case.media_type,
|
|
38
52
|
data_generation_method=case.data_generation_method.as_short_name()
|
|
39
53
|
if case.data_generation_method is not None
|
|
40
54
|
else None,
|
|
41
|
-
|
|
55
|
+
method=case.method,
|
|
56
|
+
url=request_data.url,
|
|
57
|
+
path_template=case.path,
|
|
58
|
+
verbose_name=case.operation.verbose_name,
|
|
59
|
+
verify=verify,
|
|
60
|
+
extra_headers=request_data.headers,
|
|
42
61
|
)
|
|
43
62
|
|
|
44
63
|
|
|
64
|
+
def _serialize_body(body: Optional[Union[str, bytes]]) -> Optional[str]:
    """Serialize a request body with ``serialize_payload``.

    ``None`` is passed through unchanged; text bodies are encoded as UTF-8
    before serialization, byte bodies are serialized as they are.
    """
    if body is None:
        return None
    payload = body.encode("utf-8") if isinstance(body, str) else body
    return serialize_payload(payload)
|
|
70
|
+
|
|
71
|
+
|
|
45
72
|
@dataclass
|
|
46
73
|
class SerializedCheck:
|
|
47
74
|
# Check name
|
|
@@ -75,23 +102,14 @@ class SerializedCheck:
|
|
|
75
102
|
response = Response.from_wsgi(check.response, check.elapsed)
|
|
76
103
|
else:
|
|
77
104
|
response = None
|
|
78
|
-
headers =
|
|
79
|
-
history =
|
|
80
|
-
case = check.example
|
|
81
|
-
while case.source is not None:
|
|
82
|
-
if isinstance(case.source.response, requests.Response):
|
|
83
|
-
history_response = Response.from_requests(case.source.response)
|
|
84
|
-
else:
|
|
85
|
-
history_response = Response.from_wsgi(case.source.response, case.source.elapsed)
|
|
86
|
-
entry = SerializedHistoryEntry(
|
|
87
|
-
case=SerializedCase.from_case(case.source.case, headers), response=history_response
|
|
88
|
-
)
|
|
89
|
-
history.append(entry)
|
|
90
|
-
case = case.source.case
|
|
105
|
+
headers = _get_headers(request.headers)
|
|
106
|
+
history = get_serialized_history(check.example)
|
|
91
107
|
return cls(
|
|
92
108
|
name=check.name,
|
|
93
109
|
value=check.value,
|
|
94
|
-
example=SerializedCase.from_case(
|
|
110
|
+
example=SerializedCase.from_case(
|
|
111
|
+
check.example, headers, verify=response.verify if response is not None else True
|
|
112
|
+
),
|
|
95
113
|
message=check.message,
|
|
96
114
|
request=request,
|
|
97
115
|
response=response,
|
|
@@ -100,27 +118,50 @@ class SerializedCheck:
|
|
|
100
118
|
)
|
|
101
119
|
|
|
102
120
|
|
|
121
|
+
def _get_headers(headers: Union[Dict[str, Any], CaseInsensitiveDict]) -> Dict[str, str]:
    """Flatten headers to a single value per name, dropping ``EXCLUDED_HEADERS``.

    :param headers: Header mapping whose values are either lists of values or
        plain values (the annotation admits both a regular dict and a
        ``CaseInsensitiveDict``).
    :return: A new dict with one value per header name.
    """
    flattened: Dict[str, str] = {}
    for key, value in headers.items():
        if key in EXCLUDED_HEADERS:
            continue
        if isinstance(value, (list, tuple)):
            # Multi-valued header: keep the first value. A header without any
            # value is dropped instead of raising `IndexError`.
            if not value:
                continue
            value = value[0]
        # Plain (non-list) values are kept whole; indexing them with `[0]` would
        # reduce a string to its first character.
        flattened[key] = value
    return flattened
|
|
123
|
+
|
|
124
|
+
|
|
103
125
|
@dataclass
|
|
104
126
|
class SerializedHistoryEntry:
|
|
105
127
|
case: SerializedCase
|
|
106
128
|
response: Response
|
|
107
129
|
|
|
108
130
|
|
|
131
|
+
def get_serialized_history(case: Case) -> List[SerializedHistoryEntry]:
    """Serialize the chain of case/response pairs that led to ``case``.

    Follows the ``source`` links starting with the immediate source of ``case``
    and produces one ``SerializedHistoryEntry`` per link, in that order.
    """
    entries = []
    source = case.source
    while source is not None:
        raw_response = source.response
        request_headers = _get_headers(raw_response.request.headers)
        if isinstance(raw_response, requests.Response):
            history_response = Response.from_requests(raw_response)
            verify = history_response.verify
        else:
            history_response = Response.from_wsgi(raw_response, source.elapsed)
            # Responses that are not `requests.Response` are recorded with `verify=True`.
            verify = True
        entries.append(
            SerializedHistoryEntry(
                case=SerializedCase.from_case(source.case, request_headers, verify=verify),
                response=history_response,
            )
        )
        source = source.case.source
    return entries
|
|
148
|
+
|
|
149
|
+
|
|
109
150
|
@dataclass
|
|
110
151
|
class SerializedError:
|
|
111
152
|
exception: str
|
|
112
153
|
exception_with_traceback: str
|
|
113
|
-
example: Optional[SerializedCase]
|
|
114
154
|
title: Optional[str]
|
|
115
155
|
|
|
116
156
|
@classmethod
|
|
117
157
|
def from_error(
|
|
118
|
-
cls,
|
|
158
|
+
cls,
|
|
159
|
+
exception: Exception,
|
|
160
|
+
title: Optional[str] = None,
|
|
119
161
|
) -> "SerializedError":
|
|
120
162
|
return cls(
|
|
121
163
|
exception=format_exception(exception),
|
|
122
164
|
exception_with_traceback=format_exception(exception, True),
|
|
123
|
-
example=SerializedCase.from_case(case, headers) if case else None,
|
|
124
165
|
title=title,
|
|
125
166
|
)
|
|
126
167
|
|
|
@@ -179,7 +220,7 @@ class SerializedTestResult:
|
|
|
179
220
|
data_generation_method=[m.as_short_name() for m in result.data_generation_method],
|
|
180
221
|
checks=[SerializedCheck.from_check(check) for check in result.checks],
|
|
181
222
|
logs=[formatter.format(record) for record in result.logs],
|
|
182
|
-
errors=[SerializedError.from_error(
|
|
223
|
+
errors=[SerializedError.from_error(error) for error in result.errors],
|
|
183
224
|
interactions=[SerializedInteraction.from_interaction(interaction) for interaction in result.interactions],
|
|
184
225
|
)
|
|
185
226
|
|
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
import threading
|
|
2
|
+
from collections.abc import MutableMapping, MutableSequence
|
|
3
|
+
from dataclasses import dataclass, replace
|
|
4
|
+
from typing import TYPE_CHECKING, Any, FrozenSet, Optional, Union, cast
|
|
5
|
+
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
|
|
6
|
+
|
|
7
|
+
from requests import PreparedRequest
|
|
8
|
+
|
|
9
|
+
from .utils import NOT_SET
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from .models import Case, CaseSource, Request
|
|
13
|
+
from .runner.serialization import SerializedCase, SerializedCheck, SerializedInteraction
|
|
14
|
+
from .utils import GenericResponse
|
|
15
|
+
|
|
16
|
+
# Names whose values are always sanitized (matched exactly).
DEFAULT_KEYS_TO_SANITIZE = frozenset(
    {
        "phpsessid",
        "xsrf-token",
        "_csrf",
        "_csrf_token",
        "_session",
        "_xsrf",
        "aiohttp_session",
        "api_key",
        "api-key",
        "apikey",
        "auth",
        "authorization",
        "connect.sid",
        "cookie",
        "credentials",
        "csrf",
        "csrf_token",
        "csrf-token",
        "csrftoken",
        "ip_address",
        "mysql_pwd",
        "passwd",
        "password",
        "private_key",
        "private-key",
        "privatekey",
        "remote_addr",
        "remote-addr",
        "secret",
        "session",
        "sessionid",
        "set_cookie",
        "set-cookie",
        "token",
        "x_api_key",
        "x-api-key",
        "x_csrftoken",
        "x-csrftoken",
        "x_forwarded_for",
        "x-forwarded-for",
        "x_real_ip",
        "x-real-ip",
    }
)

# Substrings that mark a key as potentially sensitive.
DEFAULT_SENSITIVE_MARKERS = frozenset(
    {
        "token",
        "key",
        "secret",
        "password",
        "auth",
        "session",
        "passwd",
        "credential",
    }
)

# Text that replaces every sanitized value.
DEFAULT_REPLACEMENT = "[Filtered]"
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclass
class Config:
    """Configuration class for sanitizing sensitive data.

    :param FrozenSet[str] keys_to_sanitize: The exact keys to sanitize (case-insensitive).
    :param FrozenSet[str] sensitive_markers: Markers indicating potentially sensitive keys (case-insensitive).
    :param str replacement: The replacement string for sanitized values.
    """

    keys_to_sanitize: FrozenSet[str] = DEFAULT_KEYS_TO_SANITIZE
    sensitive_markers: FrozenSet[str] = DEFAULT_SENSITIVE_MARKERS
    replacement: str = DEFAULT_REPLACEMENT

    def __post_init__(self) -> None:
        # Keys are matched in lower-cased form, so values passed straight to the
        # constructor are normalized here as well; otherwise a mixed-case entry
        # would never match and the documented case-insensitivity would not hold.
        self.keys_to_sanitize = frozenset(key.lower() for key in self.keys_to_sanitize)
        self.sensitive_markers = frozenset(marker.lower() for marker in self.sensitive_markers)

    def with_keys_to_sanitize(self, *keys: str) -> "Config":
        """Create a new configuration with additional keys to sanitize."""
        return replace(self, keys_to_sanitize=self.keys_to_sanitize.union(key.lower() for key in keys))

    def without_keys_to_sanitize(self, *keys: str) -> "Config":
        """Create a new configuration without certain keys to sanitize."""
        return replace(self, keys_to_sanitize=self.keys_to_sanitize.difference(key.lower() for key in keys))

    def with_sensitive_markers(self, *markers: str) -> "Config":
        """Create a new configuration with additional sensitive markers."""
        return replace(self, sensitive_markers=self.sensitive_markers.union(marker.lower() for marker in markers))

    def without_sensitive_markers(self, *markers: str) -> "Config":
        """Create a new configuration without certain sensitive markers."""
        return replace(self, sensitive_markers=self.sensitive_markers.difference(marker.lower() for marker in markers))
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
_thread_local = threading.local()
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _get_default_sanitization_config() -> Config:
    """Return the default sanitization config stored on ``_thread_local``.

    A fresh ``Config()`` is created and stored the first time the attribute is
    found to be missing.
    """
    try:
        return _thread_local.default_sanitization_config
    except AttributeError:
        # Nothing stored yet - start from the stock configuration.
        fresh = Config()
        _thread_local.default_sanitization_config = fresh
        return fresh
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def configure(config: Config) -> None:
    """Store ``config`` on ``_thread_local`` as the default sanitization config."""
    local_state = _thread_local
    local_state.default_sanitization_config = config
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def sanitize_value(item: Any, *, config: Optional["Config"] = None) -> None:
    """Sanitize sensitive values within a given item, in place.

    In a mutable mapping, the value of every key that is listed in
    ``config.keys_to_sanitize`` or contains one of ``config.sensitive_markers``
    is replaced with ``config.replacement`` (wrapped in a one-element list when
    the original value was a list). Nested mutable mappings and sequences are
    processed recursively; anything else is left untouched.

    :param item: The object to sanitize.
    :param config: Sanitization settings; the default config is used when omitted.
    """
    config = config or _get_default_sanitization_config()
    if isinstance(item, MutableMapping):
        for key in list(item.keys()):
            # Only string keys can be matched by name. Other hashable keys
            # (ints, tuples, ...) have no `.lower()` and are skipped instead of
            # crashing; their values are still visited by the recursion below.
            if not isinstance(key, str):
                continue
            lower_key = key.lower()
            if lower_key in config.keys_to_sanitize or any(marker in lower_key for marker in config.sensitive_markers):
                item[key] = [config.replacement] if isinstance(item[key], list) else config.replacement
        nested = item.values()
    elif isinstance(item, MutableSequence):
        nested = item
    else:
        return
    for value in nested:
        if isinstance(value, (MutableMapping, MutableSequence)):
            sanitize_value(value, config=config)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _sanitize_case_fields(case: "Case", config: Optional[Config]) -> None:
    """Sanitize the data containers of a single case, without touching its history."""
    for attribute in ("path_parameters", "headers", "cookies", "query"):
        container = getattr(case, attribute)
        if container is not None:
            sanitize_value(container, config=config)
    if case.body not in (None, NOT_SET):
        sanitize_value(case.body, config=config)


def sanitize_case(case: "Case", *, config: Optional[Config] = None) -> None:
    """Sanitize sensitive values within a given case.

    Path parameters, headers, cookies, query and body are sanitized in place.
    When the case has a ``source``, its whole history is sanitized as well.
    """
    _sanitize_case_fields(case, config)
    if case.source is not None:
        sanitize_history(case.source, config=config)


def sanitize_history(source: "CaseSource", *, config: Optional[Config] = None) -> None:
    """Sanitize the history of case/response pairs starting at ``source``.

    Each entry is visited exactly once. Only the case's own fields are handled
    per entry: going through ``sanitize_case`` here would descend into the rest
    of the chain again at every step and repeat the same work many times over.
    """
    current: Optional["CaseSource"] = source
    while current is not None:
        _sanitize_case_fields(current.case, config)
        sanitize_response(current.response, config=config)
        current = current.case.source
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def sanitize_response(response: "GenericResponse", *, config: Optional[Config] = None) -> None:
    """Sanitize the headers of ``response`` in place."""
    headers = response.headers
    sanitize_value(headers, config=config)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def sanitize_request(request: Union[PreparedRequest, "Request"], *, config: Optional[Config] = None) -> None:
    """Sanitize the URL and headers of a request in place.

    :param request: Either a ``PreparedRequest`` (URL kept in ``url``) or a
        ``Request`` (URL kept in ``uri``).
    :param config: Sanitization settings; the default config is used when omitted.
    """
    if isinstance(request, PreparedRequest):
        # A prepared request without a URL has nothing to sanitize here. It must
        # not fall through to the `Request` branch, which reads `uri` - an
        # attribute this branch never touches on prepared requests.
        if request.url:
            request.url = sanitize_url(request.url, config=config)
    else:
        request = cast("Request", request)
        request.uri = sanitize_url(request.uri, config=config)
    # Sanitize headers
    sanitize_value(request.headers, config=config)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def sanitize_output(
    case: "Case", response: Optional["GenericResponse"] = None, *, config: Optional[Config] = None
) -> None:
    """Sanitize a case and, when given, its response and that response's request."""
    sanitize_case(case, config=config)
    if response is None:
        return
    sanitize_response(response, config=config)
    sanitize_request(response.request, config=config)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def sanitize_url(url: str, *, config: Optional[Config] = None) -> str:
    """Sanitize sensitive parts of a given URL.

    The part of the authority before the last ``@`` is replaced with
    ``config.replacement``, and the query parameters are passed through
    ``sanitize_value`` before the URL is put back together.
    """
    config = config or _get_default_sanitization_config()
    parts = urlsplit(url)

    # Authority: everything up to the last "@" is replaced as a whole.
    _, separator, host = parts.netloc.rpartition("@")
    netloc = f"{config.replacement}@{host}" if separator else parts.netloc

    # Query: parse, sanitize the parsed mapping, then encode it again.
    params = parse_qs(parts.query, keep_blank_values=True)
    sanitize_value(params, config=config)

    return urlunsplit(parts._replace(netloc=netloc, query=urlencode(params, doseq=True)))
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def sanitize_serialized_check(check: "SerializedCheck", *, config: Optional[Config] = None) -> None:
    """Sanitize a serialized check: its request, response headers, example and history."""
    sanitize_request(check.request, config=config)
    if check.response:
        sanitize_value(check.response.headers, config=config)
    sanitize_serialized_case(check.example, config=config)
    for history_entry in check.history:
        sanitize_serialized_case(history_entry.case, config=config)
        sanitize_value(history_entry.response.headers, config=config)
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def sanitize_serialized_case(case: "SerializedCase", *, config: Optional[Config] = None) -> None:
    """Sanitize every data container of a serialized case that is set."""
    containers = (case.path_parameters, case.headers, case.cookies, case.query, case.extra_headers)
    for container in containers:
        if container is None:
            continue
        sanitize_value(container, config=config)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def sanitize_serialized_interaction(interaction: "SerializedInteraction", *, config: Optional[Config] = None) -> None:
    """Sanitize a serialized interaction: its request, response headers and all of its checks."""
    sanitize_request(interaction.request, config=config)
    response_headers = interaction.response.headers
    sanitize_value(response_headers, config=config)
    for serialized_check in interaction.checks:
        sanitize_serialized_check(serialized_check, config=config)
|