castor-extractor 0.19.0__py3-none-any.whl → 0.19.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of castor-extractor might be problematic. Click here for more details.

Files changed (83) hide show
  1. CHANGELOG.md +29 -2
  2. castor_extractor/file_checker/templates/generic_warehouse.py +1 -1
  3. castor_extractor/knowledge/notion/client/client.py +44 -80
  4. castor_extractor/knowledge/notion/client/client_test.py +9 -4
  5. castor_extractor/knowledge/notion/client/constants.py +1 -0
  6. castor_extractor/knowledge/notion/client/endpoints.py +1 -1
  7. castor_extractor/knowledge/notion/client/pagination.py +9 -5
  8. castor_extractor/quality/soda/assets.py +1 -1
  9. castor_extractor/quality/soda/client/client.py +30 -83
  10. castor_extractor/quality/soda/client/credentials.py +0 -11
  11. castor_extractor/quality/soda/client/endpoints.py +3 -6
  12. castor_extractor/quality/soda/client/pagination.py +25 -0
  13. castor_extractor/utils/__init__.py +13 -2
  14. castor_extractor/utils/client/__init__.py +14 -0
  15. castor_extractor/utils/client/api/__init__.py +5 -0
  16. castor_extractor/utils/client/api/auth.py +76 -0
  17. castor_extractor/utils/client/api/auth_test.py +49 -0
  18. castor_extractor/utils/client/api/client.py +153 -0
  19. castor_extractor/utils/client/api/client_test.py +47 -0
  20. castor_extractor/utils/client/api/pagination.py +83 -0
  21. castor_extractor/utils/client/api/pagination_test.py +51 -0
  22. castor_extractor/utils/{safe_request_test.py → client/api/safe_request_test.py} +4 -1
  23. castor_extractor/utils/client/api/utils.py +9 -0
  24. castor_extractor/utils/client/api/utils_test.py +16 -0
  25. castor_extractor/utils/collection.py +34 -2
  26. castor_extractor/utils/collection_test.py +17 -3
  27. castor_extractor/utils/pager/__init__.py +0 -1
  28. castor_extractor/utils/retry.py +44 -0
  29. castor_extractor/utils/retry_test.py +26 -1
  30. castor_extractor/utils/salesforce/client.py +44 -49
  31. castor_extractor/utils/salesforce/client_test.py +2 -2
  32. castor_extractor/utils/salesforce/pagination.py +33 -0
  33. castor_extractor/visualization/domo/client/client.py +10 -5
  34. castor_extractor/visualization/domo/client/credentials.py +1 -1
  35. castor_extractor/visualization/domo/client/endpoints.py +19 -7
  36. castor_extractor/visualization/looker/api/credentials.py +1 -1
  37. castor_extractor/visualization/metabase/client/api/client.py +26 -11
  38. castor_extractor/visualization/metabase/client/api/credentials.py +1 -1
  39. castor_extractor/visualization/metabase/client/db/credentials.py +1 -1
  40. castor_extractor/visualization/mode/client/credentials.py +1 -1
  41. castor_extractor/visualization/qlik/client/engine/credentials.py +1 -1
  42. castor_extractor/visualization/salesforce_reporting/client/rest.py +4 -3
  43. castor_extractor/visualization/sigma/client/client.py +106 -111
  44. castor_extractor/visualization/sigma/client/credentials.py +11 -1
  45. castor_extractor/visualization/sigma/client/endpoints.py +1 -1
  46. castor_extractor/visualization/sigma/client/pagination.py +22 -18
  47. castor_extractor/visualization/tableau/tests/unit/rest_api/auth_test.py +0 -1
  48. castor_extractor/visualization/tableau/tests/unit/rest_api/credentials_test.py +0 -3
  49. castor_extractor/visualization/tableau_revamp/assets.py +11 -0
  50. castor_extractor/visualization/tableau_revamp/client/client.py +71 -151
  51. castor_extractor/visualization/tableau_revamp/client/client_metadata_api.py +95 -0
  52. castor_extractor/visualization/tableau_revamp/client/client_rest_api.py +128 -0
  53. castor_extractor/visualization/tableau_revamp/client/client_tsc.py +66 -0
  54. castor_extractor/visualization/tableau_revamp/client/{tsc_fields.py → rest_fields.py} +15 -2
  55. castor_extractor/visualization/tableau_revamp/constants.py +0 -2
  56. castor_extractor/visualization/tableau_revamp/extract.py +5 -11
  57. castor_extractor/warehouse/databricks/api_client.py +239 -0
  58. castor_extractor/warehouse/databricks/api_client_test.py +15 -0
  59. castor_extractor/warehouse/databricks/client.py +37 -490
  60. castor_extractor/warehouse/databricks/client_test.py +1 -99
  61. castor_extractor/warehouse/databricks/endpoints.py +28 -0
  62. castor_extractor/warehouse/databricks/lineage.py +141 -0
  63. castor_extractor/warehouse/databricks/lineage_test.py +34 -0
  64. castor_extractor/warehouse/databricks/pagination.py +22 -0
  65. castor_extractor/warehouse/databricks/sql_client.py +90 -0
  66. castor_extractor/warehouse/databricks/utils.py +44 -1
  67. castor_extractor/warehouse/databricks/utils_test.py +58 -1
  68. castor_extractor/warehouse/mysql/client.py +0 -2
  69. castor_extractor/warehouse/salesforce/client.py +12 -59
  70. castor_extractor/warehouse/salesforce/pagination.py +34 -0
  71. castor_extractor/warehouse/sqlserver/client.py +0 -1
  72. castor_extractor-0.19.6.dist-info/METADATA +903 -0
  73. {castor_extractor-0.19.0.dist-info → castor_extractor-0.19.6.dist-info}/RECORD +77 -60
  74. castor_extractor/utils/client/api.py +0 -87
  75. castor_extractor/utils/client/api_test.py +0 -24
  76. castor_extractor/utils/pager/pager_on_token.py +0 -52
  77. castor_extractor/utils/pager/pager_on_token_test.py +0 -73
  78. castor_extractor/visualization/sigma/client/client_test.py +0 -54
  79. castor_extractor-0.19.0.dist-info/METADATA +0 -207
  80. /castor_extractor/utils/{safe_request.py → client/api/safe_request.py} +0 -0
  81. {castor_extractor-0.19.0.dist-info → castor_extractor-0.19.6.dist-info}/LICENCE +0 -0
  82. {castor_extractor-0.19.0.dist-info → castor_extractor-0.19.6.dist-info}/WHEEL +0 -0
  83. {castor_extractor-0.19.0.dist-info → castor_extractor-0.19.6.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,49 @@
1
+ from typing import Dict, Optional
2
+
3
+ from .auth import BasicAuth, BearerAuth, CustomAuth
4
+
5
+
6
+ class _MockRequest:
7
+ def __init__(self):
8
+ self.headers = {}
9
+
10
+
11
class _CustomAuth(CustomAuth):
    """CustomAuth test double exposing one fixed custom header."""

    def _authentication_header(self) -> Dict[str, str]:
        header = {"custom-token": "token"}
        return header
14
+
15
+
16
class _BearAuth(BearerAuth):
    """BearerAuth test double whose fetched token is always "token"."""

    def fetch_token(self) -> Optional[str]:
        token = "token"
        return token
19
+
20
+
21
def test_BasicAuth():
    """BasicAuth writes the base64 of `simple:basic` in the Authorization header."""
    request = _MockRequest()
    basic = BasicAuth(username="simple", password="basic")

    basic(request)

    expected_headers = {"Authorization": "Basic c2ltcGxlOmJhc2lj"}
    assert request.headers == expected_headers
28
+
29
+
30
def test_BearerAuth():
    """The bearer header follows the stored token until `refresh_token` is called."""
    request = _MockRequest()
    bearer = _BearAuth()

    # first call: the header carries the value returned by fetch_token
    bearer(request)
    assert request.headers == {"Authorization": "Bearer token"}

    # a token set on the instance is sent as-is
    bearer._token = "expired_token"
    bearer(request)
    assert request.headers == {"Authorization": "Bearer expired_token"}

    # after a refresh, the header carries the fetched token again
    bearer.refresh_token()
    bearer(request)
    assert request.headers == {"Authorization": "Bearer token"}
43
+
44
+
45
def test_CustomAuth():
    """CustomAuth copies its `_authentication_header` into the request headers."""
    request = _MockRequest()
    custom = _CustomAuth()

    custom(request)

    expected_headers = {"custom-token": "token"}
    assert request.headers == expected_headers
@@ -0,0 +1,153 @@
1
+ import logging
2
+ from http import HTTPStatus
3
+ from typing import Dict, Literal, Optional, Tuple
4
+
5
+ import requests
6
+ from requests import Response
7
+
8
+ from ...retry import retry_request
9
+ from .auth import Auth
10
+ from .safe_request import RequestSafeMode, handle_response
11
+ from .utils import build_url
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ Headers = Optional[Dict[str, str]]
16
+
17
+ # https://requests.readthedocs.io/en/latest/api/#requests.request
18
+ HttpMethod = Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]
19
+
20
+ DEFAULT_TIMEOUT = 60
21
+ RETRY_ON_EXPIRED_TOKEN = 1
22
+
23
+
24
def _generate_payloads(
    method: HttpMethod,
    params: Optional[dict],
    data: Optional[dict],
    pagination_params: Optional[dict],
) -> Tuple[dict, dict]:
    """
    Merge the pagination parameters into the payload matching the method

    - GET: pagination parameters are merged into the query `params`
    - POST: pagination parameters are merged into the body `data`

    Returns:
        the tuple (data, params) - mind the order

    Raises:
        ValueError: for any method other than GET and POST
    """
    extra = pagination_params or {}
    query = params or {}
    body = data or {}

    if method == "GET":
        return body, {**query, **extra}
    if method == "POST":
        return {**body, **extra}, query
    raise ValueError(f"Method {method} is not yet supported")
41
+
42
+
43
class APIClient:
    """
    Interface to easily query REST-API with GET and POST requests

    Args:
        auth: auth class to enable logging to the API
        host: base url of the API
        headers: common headers to all calls that will be made
        timeout: timeout handed over, as-is, to each `requests.request` call
        safe_mode: ignore certain exceptions based on status codes;
            when omitted, a dedicated `RequestSafeMode()` is created
            for this client

    Note:
        If the auth implements a refreshing mechanism (refresh_token)
        the token is automatically refreshed once upon receiving the
        401: UNAUTHORIZED status code
    """

    def __init__(
        self,
        auth: Auth,
        host: Optional[str] = None,
        headers: Headers = None,
        timeout: int = DEFAULT_TIMEOUT,
        safe_mode: Optional[RequestSafeMode] = None,
    ):
        self.base_headers = headers or {}
        self._host = host
        self._timeout = timeout
        self._auth = auth
        # Built here rather than in the signature: a default written as
        # `RequestSafeMode()` is evaluated once, at definition time, and
        # the very same object would then be shared by every client created
        # without an explicit safe_mode.
        # NOTE(review): matters as soon as RequestSafeMode holds state
        # (e.g. keeps track of caught errors) - confirm in safe_request.py
        self._safe_mode = RequestSafeMode() if safe_mode is None else safe_mode

    def _call(
        self,
        method: HttpMethod,
        endpoint: str,
        *,
        headers: Headers = None,
        params: Optional[dict] = None,
        data: Optional[dict] = None,
        pagination_params: Optional[dict] = None,
    ) -> Response:
        """
        Send the request and return the raw response (no status check)

        Args:
            method: "GET" or "POST", any other value raises a ValueError
            endpoint: combined with the host by `build_url`
            headers: headers specific to this call, they take precedence
                over the client's base headers
            params: query parameters
            data: body of the request, sent as JSON
            pagination_params: merged into `params` (GET) or `data` (POST)
        """
        headers = headers or {}

        data, params = _generate_payloads(
            method=method,
            params=params,
            data=data,
            pagination_params=pagination_params,
        )

        url = build_url(self._host, endpoint)

        return requests.request(
            method=method,
            url=url,
            auth=self._auth,
            headers={**self.base_headers, **headers},
            params=params,
            json=data,
            timeout=self._timeout,
        )

    @retry_request(
        status_codes=(HTTPStatus.UNAUTHORIZED,),
        max_retries=RETRY_ON_EXPIRED_TOKEN,
    )
    def _get(
        self,
        endpoint: str,
        *,
        headers: Headers = None,
        params: Optional[dict] = None,
        data: Optional[dict] = None,
        pagination_params: Optional[dict] = None,
    ):
        """
        GET the endpoint and return the outcome of `handle_response`

        The token is refreshed before a 401 response is handed over to
        `handle_response`, so that the retry configured by the decorator
        runs with new credentials.
        NOTE(review): this relies on handle_response raising an HTTPError
        for a 401 - confirm in safe_request.py
        """
        response = self._call(
            method="GET",
            endpoint=endpoint,
            params=params,
            data=data,
            pagination_params=pagination_params,
            headers=headers,
        )
        if response.status_code == HTTPStatus.UNAUTHORIZED:
            self._auth.refresh_token()

        return handle_response(response, safe_mode=self._safe_mode)

    @retry_request(
        status_codes=(HTTPStatus.UNAUTHORIZED,),
        max_retries=RETRY_ON_EXPIRED_TOKEN,
    )
    def _post(
        self,
        endpoint: str,
        *,
        headers: Headers = None,
        data: Optional[dict] = None,
        pagination_params: Optional[dict] = None,
    ):
        """
        POST to the endpoint and return the outcome of `handle_response`

        Same token-refresh behaviour as `_get`: a 401 response triggers
        `refresh_token` before being handed over to `handle_response`.
        """
        response = self._call(
            method="POST",
            endpoint=endpoint,
            data=data,
            pagination_params=pagination_params,
            headers=headers,
        )
        if response.status_code == HTTPStatus.UNAUTHORIZED:
            self._auth.refresh_token()

        return handle_response(response, safe_mode=self._safe_mode)
@@ -0,0 +1,47 @@
1
+ from unittest.mock import patch
2
+
3
+ from requests import PreparedRequest, Request, Session
4
+
5
+ from .auth import BasicAuth
6
+ from .client import APIClient
7
+
8
+
9
class MockSession:
    """
    Session replacement which sends nothing: `request` only builds and
    returns the PreparedRequest, so that tests can inspect it
    """

    def __init__(self):
        self.request_data = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return None

    def request(self, **kwargs) -> PreparedRequest:
        # `timeout` is removed before the remaining kwargs feed Request()
        del kwargs["timeout"]
        return Session().prepare_request(request=Request(**kwargs))
24
+
25
+
26
@patch("requests.sessions.Session", MockSession)
def test__get():
    """`_call` combines method, base headers, basic auth, host and endpoint."""
    client = APIClient(
        auth=BasicAuth(username="user_id", password="secret"),
        host="https://example.api.com/v1/",
        headers={"content-type": "test"},
        timeout=9,
    )

    prepared = client._call("GET", "endpoint")

    # test method
    assert prepared.method == "GET"

    # test headers
    sent_headers = prepared.headers
    assert sent_headers["content-type"] == "test"
    assert sent_headers["Authorization"] == "Basic dXNlcl9pZDpzZWNyZXQ="

    # test url
    assert prepared.url == "https://example.api.com/v1/endpoint"
@@ -0,0 +1,83 @@
1
+ import logging
2
+ from abc import abstractmethod
3
+ from enum import Enum
4
+ from functools import partial
5
+ from time import sleep
6
+ from typing import Callable, Iterator, Optional, Type, Union
7
+
8
+ from pydantic import BaseModel
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class FetchNextPageBy(Enum):
    """
    Which APIClient._call() argument receives the payload
    of the next page during the pagination:
    - PAYLOAD: the `pagination_params` argument
    - URL: the `endpoint` argument
    """

    PAYLOAD = "pagination_params"
    URL = "endpoint"
24
+
25
+
26
class PaginationModel(BaseModel):
    """
    Base abstract class describing one page of a paginated API

    Sub-classes implement `is_last`, `next_page_payload` and
    `page_results`; `fetch_all_pages` can then walk through every
    page of the API
    """

    # name of the request argument carrying the next page's payload
    fetch_by: FetchNextPageBy = FetchNextPageBy.PAYLOAD
    # parameters used to request the current page (None when not provided)
    current_page_payload: Optional["dict"] = None

    @abstractmethod
    def is_last(self) -> bool:
        """Stopping condition for the pagination"""

    @abstractmethod
    def next_page_payload(self) -> Optional[Union[dict, str]]:
        """Payload enabling to generate the request for the next page"""

    @abstractmethod
    def page_results(self) -> list:
        """List of results of the current page"""

    def next_page_parameters(self) -> dict:
        """Keyword argument (name and value) to request the next page"""
        argument_name = self.fetch_by.value
        return {argument_name: self.next_page_payload()}
55
+
56
+
57
def fetch_all_pages(
    request: Callable,
    pagination_model: Type[PaginationModel],
    rate_limit: Optional[int] = None,
) -> Iterator:
    """
    Lazily yield all results of a Paginated API

    Args:
        request: called without argument for the first page, then with the
            keyword argument built by `next_page_parameters` for each
            following page; every response payload is expanded as keyword
            arguments into `pagination_model`
        pagination_model: model built out of each response payload
        rate_limit: when truthy, passed to `sleep` before each call
            following the first one
    """
    page_number = 1
    page = pagination_model(**request())

    while True:
        # checked before the results are handed over, as in a `while not` loop
        reached_end = page.is_last()

        logger.info(f"Fetching page number {page_number}")
        yield from page.page_results()
        if reached_end:
            return

        next_page_parameters = page.next_page_parameters()
        if rate_limit:
            sleep(rate_limit)
        page = pagination_model(
            current_page_payload=next_page_parameters,
            **request(**next_page_parameters),
        )
        page_number += 1
@@ -0,0 +1,51 @@
1
+ from typing import Optional
2
+
3
+ from pydantic import ConfigDict, Field
4
+ from pydantic.alias_generators import to_camel
5
+
6
+ from .pagination import PaginationModel, fetch_all_pages
7
+
8
+
9
class _TestPagination(PaginationModel):
    """Pagination driven by a `nextPage` token, results held in `entries`."""

    next_page: Optional[str] = None
    entries: list = Field(default_factory=list)

    # camelCase aliases: the payload key `nextPage` populates `next_page`
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    def is_last(self) -> bool:
        """The last page is the one without a next page token"""
        return self.next_page is None

    def next_page_payload(self) -> dict:
        """The next page is requested through its token, under `page`"""
        return dict(page=self.next_page)

    def page_results(self) -> list:
        """Results of the page are its entries"""
        return self.entries
27
+
28
+
29
+ def _request(pagination_params: Optional[dict] = None):
30
+ if not pagination_params:
31
+ return {
32
+ "nextPage": "next_page_id",
33
+ "entries": [1, 2, 3, 4, 5],
34
+ }
35
+ if pagination_params.get("page") == "next_page_id":
36
+ return {
37
+ "nextPage": "next_page_id_2",
38
+ "entries": [6, 7, 8, 9, 10],
39
+ }
40
+ if pagination_params.get("page") == "next_page_id_2":
41
+ return {
42
+ "nextPage": None,
43
+ "entries": [11],
44
+ }
45
+
46
+ raise AssertionError(f"call has unexpected parameters: {pagination_params}")
47
+
48
+
49
def test__TestPagination():
    """The entries of the 3 pages are returned, in order"""
    fetched = list(fetch_all_pages(_request, _TestPagination))
    assert fetched == list(range(1, 12))
@@ -4,7 +4,10 @@ from http import HTTPStatus
4
4
  import pytest
5
5
  from requests import HTTPError, Response
6
6
 
7
- from .safe_request import RequestSafeMode, handle_response
7
+ from .safe_request import (
8
+ RequestSafeMode,
9
+ handle_response,
10
+ )
8
11
 
9
12
 
10
13
  def mock_response(status_code: int):
@@ -0,0 +1,9 @@
1
+ from typing import Optional
2
+
3
+
4
def build_url(host: Optional[str], endpoint: str) -> str:
    """
    Build the url to call out of the host and the endpoint

    - without host, the endpoint is returned untouched
    - a host given without scheme is prefixed with `https://`, whereas an
      explicit `http://` or `https://` scheme is kept as provided
    - host and endpoint are joined by exactly one `/`, whatever the
      trailing slashes of the host and the leading slashes of the endpoint
    """
    if not host:
        return endpoint
    # startswith accepts a tuple; lower() makes the scheme check
    # case-insensitive without altering the host itself
    if not host.lower().startswith(("http://", "https://")):
        host = "https://" + host
    return f"{host.rstrip('/')}/{endpoint.lstrip('/')}"
@@ -0,0 +1,16 @@
1
+ from .utils import build_url
2
+
3
+
4
def test_APIClient_build_url():
    """Every supported spelling of the host leads to the same url"""
    path = "api/2.1/unity-catalog/tables"
    expected = "https://3.14.azuredatabricks.net/api/2.1/unity-catalog/tables"

    hosts = (
        "3.14.azuredatabricks.net",  # bare host
        "https://3.14.azuredatabricks.net",  # host with scheme
        "https://3.14.azuredatabricks.net/",  # host with trailing slash
    )
    for host in hosts:
        assert expected == build_url(host, path)
@@ -1,12 +1,23 @@
1
- from typing import Any, DefaultDict, Dict, List, Sequence
1
+ from collections import defaultdict
2
+ from typing import (
3
+ Any,
4
+ Dict,
5
+ Iterable,
6
+ List,
7
+ Sequence,
8
+ Set,
9
+ TypeVar,
10
+ )
2
11
 
3
12
  from .object import getproperty
4
13
  from .type import Getter
5
14
 
15
+ T = TypeVar("T")
16
+
6
17
 
7
18
  def group_by(identifier: Getter, elements: Sequence) -> Dict[Any, List]:
8
19
  """Groups the elements by the given key"""
9
- groups: Dict[Any, List] = DefaultDict(list)
20
+ groups: Dict[Any, List] = defaultdict(list)
10
21
  for element in elements:
11
22
  key = getproperty(element, identifier)
12
23
  groups[key].append(element)
@@ -52,3 +63,24 @@ def empty_iterator():
52
63
  Remark: missing return type is on purpose, it breaks the typing
53
64
  """
54
65
  return iter([])
66
+
67
+
68
def deduplicate(
    identifier: Getter,
    elements: Iterable[T],
) -> List[T]:
    """
    Return the elements without duplicates, two elements being duplicates
    when `getproperty` gives them the same identifier.
    The first occurrence wins and the original order is preserved.
    """
    first_by_key: Dict[Any, T] = {}

    for candidate in elements:
        # setdefault leaves an already registered key untouched
        first_by_key.setdefault(getproperty(candidate, identifier), candidate)

    return list(first_by_key.values())
@@ -1,6 +1,4 @@
1
- import pytest
2
-
3
- from .collection import mapping_from_rows
1
+ from .collection import deduplicate, mapping_from_rows
4
2
 
5
3
 
6
4
  def test__mapping_from_rows__basic_mapping():
@@ -58,3 +56,19 @@ def test__mapping_from_rows__multiple_valid_rows():
58
56
  result = mapping_from_rows(rows, "id", "name")
59
57
  expected = {1: "Alice", 2: "Bob", 3: "Charlie"}
60
58
  assert result == expected
59
+
60
+
61
def test_deduplicate():
    """Only the first element of each `id` is kept, in order of appearance"""
    e1 = {"id": "1", "name": "element_1"}
    e2 = {"id": "2", "name": "element_2"}
    e3 = {"id": "3", "name": "element_3"}
    unique = [e1, e2, e3]

    duplicates = [
        {"id": "3", "name": "duplicate"},
        {"id": "3", "name": "duplicate"},
        {"id": "2", "name": "duplicate"},
    ]

    assert deduplicate("id", unique + duplicates) == unique
@@ -1,3 +1,2 @@
1
1
  from .pager import AbstractPager, Pager, PagerLogger, PagerStopStrategy
2
2
  from .pager_on_id import PagerOnId, PagerOnIdLogger
3
- from .pager_on_token import PagerOnToken
@@ -6,6 +6,7 @@ from typing import Any, Callable, Sequence, Tuple, Type, Union
6
6
 
7
7
  from pydantic import BaseModel, PositiveInt, PrivateAttr
8
8
  from pydantic.fields import Field
9
+ from requests import HTTPError
9
10
 
10
11
  logger = logging.getLogger(__name__)
11
12
 
@@ -115,3 +116,46 @@ def retry(
115
116
  return _func
116
117
 
117
118
  return _wrapper
119
+
120
+
121
def retry_request(
    status_codes: Sequence[int],
    max_retries: int = 1,
    base_ms: int = 10,
    jitter_ms: int = 20,
    strategy: RetryStrategy = DEFAULT_STRATEGY,
    log_exc_info: bool = False,
) -> Callable:
    """
    Decorator retrying a call which raised an HTTPError, provided the
    status code of the failed response is one of `status_codes`

    Args:
        status_codes: status codes of the responses worth a retry
        max_retries: forwarded to `Retry`
        base_ms: forwarded to `Retry`
        jitter_ms: forwarded to `Retry`
        strategy: forwarded to `Retry`
        log_exc_info: forwarded to `Retry.check`

    Raises:
        HTTPError: immediately when the error carries no response or a
            status code out of `status_codes`; otherwise as soon as
            `Retry.check` refuses a new attempt
    """
    # imported here so that the module-level imports stay untouched
    from functools import wraps

    retryable_codes = tuple(status_codes)

    def _wrapper(func: Callable) -> Callable:
        def _try(*args, **kwargs) -> WrapperReturnType:
            try:
                return None, func(*args, **kwargs)
            except HTTPError as err:
                # an HTTPError may be raised without any response attached:
                # there is then no status code to inspect, nothing to retry
                response = err.response
                if (
                    response is None
                    or response.status_code not in retryable_codes
                ):
                    raise
                logger.warning(f"Exception within {func.__name__}")
                return err, None

        @wraps(func)
        def _func(*args, **kwargs) -> Any:
            # one Retry per call of the decorated function
            retry = Retry(
                max_retries=max_retries,
                base_ms=base_ms,
                jitter_ms=jitter_ms,
                strategy=strategy,
            )
            while True:
                err, result = _try(*args, **kwargs)
                if err is None:
                    return result
                if retry.check(err, log_exc_info):
                    continue
                raise err

        return _func

    return _wrapper
@@ -1,11 +1,15 @@
1
+ from http import HTTPStatus
1
2
  from statistics import variance
2
3
  from time import time
3
4
  from typing import List
5
+ from unittest.mock import patch
4
6
 
5
7
  import pytest
8
+ import requests
6
9
  from pydantic.error_wrappers import ValidationError
10
+ from requests import HTTPError, Response
7
11
 
8
- from .retry import MS_IN_SEC, Retry, RetryStrategy
12
+ from .retry import MS_IN_SEC, Retry, RetryStrategy, retry_request
9
13
 
10
14
 
11
15
  def test_retry_field_validations():
@@ -75,3 +79,24 @@ def test_retry_strategy__check():
75
79
  assert retry.check(error) is False
76
80
  delta_ms = int((after - before) * MS_IN_SEC)
77
81
  assert _within(delta_ms, 315, 345)
82
+
83
+
84
@patch("requests.get")
def test_retry_request(mocked_get):
    """A call always answered by 401 is retried `max_retries` times, then raises"""
    unauthorized = Response()
    unauthorized.status_code = HTTPStatus.UNAUTHORIZED
    mocked_get.return_value = unauthorized

    max_retries = 3

    @retry_request(
        status_codes=(HTTPStatus.UNAUTHORIZED,),
        max_retries=max_retries,
    )
    def get():
        response = requests.get("hello")
        response.raise_for_status()
        return response.json()

    with pytest.raises(HTTPError):
        get()

    # 1 call + 3 retries
    assert mocked_get.call_count == 1 + max_retries