castor-extractor 0.19.0__py3-none-any.whl → 0.19.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of castor-extractor might be problematic. Click here for more details.
- CHANGELOG.md +29 -2
- castor_extractor/file_checker/templates/generic_warehouse.py +1 -1
- castor_extractor/knowledge/notion/client/client.py +44 -80
- castor_extractor/knowledge/notion/client/client_test.py +9 -4
- castor_extractor/knowledge/notion/client/constants.py +1 -0
- castor_extractor/knowledge/notion/client/endpoints.py +1 -1
- castor_extractor/knowledge/notion/client/pagination.py +9 -5
- castor_extractor/quality/soda/assets.py +1 -1
- castor_extractor/quality/soda/client/client.py +30 -83
- castor_extractor/quality/soda/client/credentials.py +0 -11
- castor_extractor/quality/soda/client/endpoints.py +3 -6
- castor_extractor/quality/soda/client/pagination.py +25 -0
- castor_extractor/utils/__init__.py +13 -2
- castor_extractor/utils/client/__init__.py +14 -0
- castor_extractor/utils/client/api/__init__.py +5 -0
- castor_extractor/utils/client/api/auth.py +76 -0
- castor_extractor/utils/client/api/auth_test.py +49 -0
- castor_extractor/utils/client/api/client.py +153 -0
- castor_extractor/utils/client/api/client_test.py +47 -0
- castor_extractor/utils/client/api/pagination.py +83 -0
- castor_extractor/utils/client/api/pagination_test.py +51 -0
- castor_extractor/utils/{safe_request_test.py → client/api/safe_request_test.py} +4 -1
- castor_extractor/utils/client/api/utils.py +9 -0
- castor_extractor/utils/client/api/utils_test.py +16 -0
- castor_extractor/utils/collection.py +34 -2
- castor_extractor/utils/collection_test.py +17 -3
- castor_extractor/utils/pager/__init__.py +0 -1
- castor_extractor/utils/retry.py +44 -0
- castor_extractor/utils/retry_test.py +26 -1
- castor_extractor/utils/salesforce/client.py +44 -49
- castor_extractor/utils/salesforce/client_test.py +2 -2
- castor_extractor/utils/salesforce/pagination.py +33 -0
- castor_extractor/visualization/domo/client/client.py +10 -5
- castor_extractor/visualization/domo/client/credentials.py +1 -1
- castor_extractor/visualization/domo/client/endpoints.py +19 -7
- castor_extractor/visualization/looker/api/credentials.py +1 -1
- castor_extractor/visualization/metabase/client/api/client.py +26 -11
- castor_extractor/visualization/metabase/client/api/credentials.py +1 -1
- castor_extractor/visualization/metabase/client/db/credentials.py +1 -1
- castor_extractor/visualization/mode/client/credentials.py +1 -1
- castor_extractor/visualization/qlik/client/engine/credentials.py +1 -1
- castor_extractor/visualization/salesforce_reporting/client/rest.py +4 -3
- castor_extractor/visualization/sigma/client/client.py +106 -111
- castor_extractor/visualization/sigma/client/credentials.py +11 -1
- castor_extractor/visualization/sigma/client/endpoints.py +1 -1
- castor_extractor/visualization/sigma/client/pagination.py +22 -18
- castor_extractor/visualization/tableau/tests/unit/rest_api/auth_test.py +0 -1
- castor_extractor/visualization/tableau/tests/unit/rest_api/credentials_test.py +0 -3
- castor_extractor/visualization/tableau_revamp/assets.py +11 -0
- castor_extractor/visualization/tableau_revamp/client/client.py +71 -151
- castor_extractor/visualization/tableau_revamp/client/client_metadata_api.py +95 -0
- castor_extractor/visualization/tableau_revamp/client/client_rest_api.py +128 -0
- castor_extractor/visualization/tableau_revamp/client/client_tsc.py +66 -0
- castor_extractor/visualization/tableau_revamp/client/{tsc_fields.py → rest_fields.py} +15 -2
- castor_extractor/visualization/tableau_revamp/constants.py +0 -2
- castor_extractor/visualization/tableau_revamp/extract.py +5 -11
- castor_extractor/warehouse/databricks/api_client.py +239 -0
- castor_extractor/warehouse/databricks/api_client_test.py +15 -0
- castor_extractor/warehouse/databricks/client.py +37 -490
- castor_extractor/warehouse/databricks/client_test.py +1 -99
- castor_extractor/warehouse/databricks/endpoints.py +28 -0
- castor_extractor/warehouse/databricks/lineage.py +141 -0
- castor_extractor/warehouse/databricks/lineage_test.py +34 -0
- castor_extractor/warehouse/databricks/pagination.py +22 -0
- castor_extractor/warehouse/databricks/sql_client.py +90 -0
- castor_extractor/warehouse/databricks/utils.py +44 -1
- castor_extractor/warehouse/databricks/utils_test.py +58 -1
- castor_extractor/warehouse/mysql/client.py +0 -2
- castor_extractor/warehouse/salesforce/client.py +12 -59
- castor_extractor/warehouse/salesforce/pagination.py +34 -0
- castor_extractor/warehouse/sqlserver/client.py +0 -1
- castor_extractor-0.19.6.dist-info/METADATA +903 -0
- {castor_extractor-0.19.0.dist-info → castor_extractor-0.19.6.dist-info}/RECORD +77 -60
- castor_extractor/utils/client/api.py +0 -87
- castor_extractor/utils/client/api_test.py +0 -24
- castor_extractor/utils/pager/pager_on_token.py +0 -52
- castor_extractor/utils/pager/pager_on_token_test.py +0 -73
- castor_extractor/visualization/sigma/client/client_test.py +0 -54
- castor_extractor-0.19.0.dist-info/METADATA +0 -207
- /castor_extractor/utils/{safe_request.py → client/api/safe_request.py} +0 -0
- {castor_extractor-0.19.0.dist-info → castor_extractor-0.19.6.dist-info}/LICENCE +0 -0
- {castor_extractor-0.19.0.dist-info → castor_extractor-0.19.6.dist-info}/WHEEL +0 -0
- {castor_extractor-0.19.0.dist-info → castor_extractor-0.19.6.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from typing import Dict, Optional
|
|
2
|
+
|
|
3
|
+
from .auth import BasicAuth, BearerAuth, CustomAuth
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class _MockRequest:
    """Bare-bones request double exposing only a `headers` mapping."""

    def __init__(self):
        # starts empty; the tests below inspect it after calling an auth object
        self.headers = dict()
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class _CustomAuth(CustomAuth):
    """CustomAuth test double returning a fixed, non-standard auth header."""

    def _authentication_header(self) -> Dict[str, str]:
        header: Dict[str, str] = {"custom-token": "token"}
        return header
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class _BearAuth(BearerAuth):
    """BearerAuth test double whose token fetch always yields "token"."""

    def fetch_token(self) -> Optional[str]:
        fresh_token = "token"
        return fresh_token
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def test_BasicAuth():
    """BasicAuth must write an `Authorization: Basic <base64>` header."""
    request = _MockRequest()
    basic_auth = BasicAuth(username="simple", password="basic")

    basic_auth(request)

    # "c2ltcGxlOmJhc2lj" is the base64 encoding of "simple:basic"
    expected_headers = {"Authorization": "Basic c2ltcGxlOmJhc2lj"}
    assert request.headers == expected_headers
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def test_BearerAuth():
    """BearerAuth sends its current token, and a new one after a refresh."""
    request = _MockRequest()
    bearer_auth = _BearAuth()

    # first call: the header carries the value returned by `fetch_token`
    bearer_auth(request)
    assert request.headers == {"Authorization": "Bearer token"}

    # a stale token stored on the auth object is sent as-is ...
    bearer_auth._token = "expired_token"
    bearer_auth(request)
    assert request.headers == {"Authorization": "Bearer expired_token"}

    # ... until an explicit refresh replaces it
    bearer_auth.refresh_token()
    bearer_auth(request)
    assert request.headers == {"Authorization": "Bearer token"}
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def test_CustomAuth():
    """CustomAuth must copy its custom authentication header to the request."""
    request = _MockRequest()
    custom_auth = _CustomAuth()

    custom_auth(request)

    expected_headers = {"custom-token": "token"}
    assert request.headers == expected_headers
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from http import HTTPStatus
|
|
3
|
+
from typing import Dict, Literal, Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import requests
|
|
6
|
+
from requests import Response
|
|
7
|
+
|
|
8
|
+
from ...retry import retry_request
|
|
9
|
+
from .auth import Auth
|
|
10
|
+
from .safe_request import RequestSafeMode, handle_response
|
|
11
|
+
from .utils import build_url
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
Headers = Optional[Dict[str, str]]
|
|
16
|
+
|
|
17
|
+
# https://requests.readthedocs.io/en/latest/api/#requests.request
|
|
18
|
+
HttpMethod = Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]
|
|
19
|
+
|
|
20
|
+
DEFAULT_TIMEOUT = 60
|
|
21
|
+
RETRY_ON_EXPIRED_TOKEN = 1
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _generate_payloads(
    method: HttpMethod,
    params: Optional[dict],
    data: Optional[dict],
    pagination_params: Optional[dict],
) -> Tuple[dict, dict]:
    """
    Merge the pagination parameters into the relevant part of the request.

    - GET: pagination parameters are merged into `params`
    - POST: pagination parameters are merged into `data`
    - any other method: ValueError

    Missing (None/empty) inputs are treated as empty dicts. On key clashes
    the pagination parameters win.

    Returns the `(data, params)` pair, in that order.
    """
    extra = pagination_params or {}
    query = params or {}
    body = data or {}

    if method not in ("GET", "POST"):
        raise ValueError(f"Method {method} is not yet supported")

    if method == "GET":
        return body, {**query, **extra}
    return {**body, **extra}, query
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class APIClient:
    """
    Interface to easily query REST-API with GET and POST requests

    Args:
        auth: auth object attached to every request (passed to `requests`
            as the `auth` argument)
        host: base url of the API
        headers: common headers to all calls that will be made
        timeout: timeout passed to `requests` for each request
        safe_mode: ignore certain exceptions based on status codes; when
            omitted, each client gets its own `RequestSafeMode()` instance

    Note:
        If the auth implements a refreshing mechanism (refresh_token)
        the token is automatically refreshed once upon receiving the
        401: UNAUTHORIZED status code
    """

    def __init__(
        self,
        auth: Auth,
        host: Optional[str] = None,
        headers: Headers = None,
        timeout: int = DEFAULT_TIMEOUT,
        safe_mode: Optional[RequestSafeMode] = None,
    ):
        self.base_headers = headers or {}
        self._host = host
        self._timeout = timeout
        self._auth = auth
        # Build the default here rather than in the signature: a default
        # written as `RequestSafeMode()` is evaluated once at definition
        # time and would be shared by every client instance.
        self._safe_mode = (
            safe_mode if safe_mode is not None else RequestSafeMode()
        )

    def _call(
        self,
        method: HttpMethod,
        endpoint: str,
        *,
        headers: Headers = None,
        params: Optional[dict] = None,
        data: Optional[dict] = None,
        pagination_params: Optional[dict] = None,
    ) -> Response:
        """
        Send a single request and return the raw `Response`.

        Per-call `headers` override the client's base headers on key
        clashes. `pagination_params` are merged into `params` (GET) or
        `data` (POST) by `_generate_payloads`, which raises ValueError
        for any other method.
        """
        headers = headers or {}

        data, params = _generate_payloads(
            method=method,
            params=params,
            data=data,
            pagination_params=pagination_params,
        )

        url = build_url(self._host, endpoint)

        # NOTE(review): `data` is never None at this point (it defaults to
        # `{}`), so a `json` payload is handed to `requests` even for GET
        # calls — confirm this is intended.
        return requests.request(
            method=method,
            url=url,
            auth=self._auth,
            headers={**self.base_headers, **headers},
            params=params,
            json=data,
            timeout=self._timeout,
        )

    def _call_with_token_refresh(self, method: HttpMethod, endpoint: str, **kwargs):
        """
        Shared implementation of `_get` and `_post`: perform the call,
        refresh the token on 401, then delegate to `handle_response`.
        """
        response = self._call(method=method, endpoint=endpoint, **kwargs)
        if response.status_code == HTTPStatus.UNAUTHORIZED:
            # Refresh before handing over to `handle_response`, so that the
            # retry driven by `retry_request` (presumably triggered by the
            # error raised for the 401) runs with a new token.
            self._auth.refresh_token()

        return handle_response(response, safe_mode=self._safe_mode)

    @retry_request(
        status_codes=(HTTPStatus.UNAUTHORIZED,),
        max_retries=RETRY_ON_EXPIRED_TOKEN,
    )
    def _get(
        self,
        endpoint: str,
        *,
        headers: Headers = None,
        params: Optional[dict] = None,
        data: Optional[dict] = None,
        pagination_params: Optional[dict] = None,
    ):
        """
        GET `endpoint` and return whatever `handle_response` yields.

        Retried up to RETRY_ON_EXPIRED_TOKEN time(s) on 401, after a
        token refresh.
        """
        return self._call_with_token_refresh(
            "GET",
            endpoint,
            params=params,
            data=data,
            pagination_params=pagination_params,
            headers=headers,
        )

    @retry_request(
        status_codes=(HTTPStatus.UNAUTHORIZED,),
        max_retries=RETRY_ON_EXPIRED_TOKEN,
    )
    def _post(
        self,
        endpoint: str,
        *,
        headers: Headers = None,
        data: Optional[dict] = None,
        pagination_params: Optional[dict] = None,
    ):
        """
        POST to `endpoint` and return whatever `handle_response` yields.

        Retried up to RETRY_ON_EXPIRED_TOKEN time(s) on 401, after a
        token refresh.
        """
        return self._call_with_token_refresh(
            "POST",
            endpoint,
            data=data,
            pagination_params=pagination_params,
            headers=headers,
        )
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from unittest.mock import patch
|
|
2
|
+
|
|
3
|
+
from requests import PreparedRequest, Request, Session
|
|
4
|
+
|
|
5
|
+
from .auth import BasicAuth
|
|
6
|
+
from .client import APIClient
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MockSession:
    """
    Session double usable as a context manager: instead of sending
    anything, `request` returns the prepared request it would have sent.
    """

    def __init__(self):
        self.request_data = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # nothing to release; returning None lets exceptions propagate
        return None

    def request(self, **kwargs) -> PreparedRequest:
        # `timeout` is dropped before building the `Request` (must be present)
        del kwargs["timeout"]
        return Session().prepare_request(request=Request(**kwargs))
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@patch("requests.sessions.Session", MockSession)
def test__get():
    """`APIClient._call` must build the expected method, headers and url."""
    client = APIClient(
        auth=BasicAuth(username="user_id", password="secret"),
        host="https://example.api.com/v1/",
        headers={"content-type": "test"},
        timeout=9,
    )

    # with the patched session, the "response" is the prepared request
    prepared = client._call("GET", "endpoint")

    # method
    assert prepared.method == "GET"

    # headers: base headers + basic authentication
    sent_headers = prepared.headers
    assert sent_headers["content-type"] == "test"
    assert sent_headers["Authorization"] == "Basic dXNlcl9pZDpzZWNyZXQ="

    # url: host and endpoint joined
    assert prepared.url == "https://example.api.com/v1/endpoint"
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from functools import partial
|
|
5
|
+
from time import sleep
|
|
6
|
+
from typing import Callable, Iterator, Optional, Type, Union
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class FetchNextPageBy(Enum):
    """
    Which keyword argument of the request callable carries the
    information needed to fetch the next page.

    Each member's value is the name of that keyword argument:
    - PAYLOAD: "pagination_params"
    - URL: "endpoint"
    """

    PAYLOAD = "pagination_params"
    URL = "endpoint"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class PaginationModel(BaseModel):
    """
    Abstract description of one page of a paginated API response.

    Subclasses implement `is_last`, `next_page_payload` and
    `page_results`; `fetch_all_pages` then uses them to walk through
    every page.
    """

    # name of the request argument used to ask for the next page
    fetch_by: FetchNextPageBy = FetchNextPageBy.PAYLOAD
    # parameters that were used to fetch this page (None for the first one)
    current_page_payload: Optional[dict] = None

    @abstractmethod
    def is_last(self) -> bool:
        """Tell whether this page is the final one (stops the pagination)"""

    @abstractmethod
    def next_page_payload(self) -> Optional[Union[dict, str]]:
        """Value to send in order to request the page following this one"""

    @abstractmethod
    def page_results(self) -> list:
        """Results contained in this page"""

    def next_page_parameters(self) -> dict:
        """Keyword arguments to pass to the request to get the next page"""
        argument_name = self.fetch_by.value
        return {argument_name: self.next_page_payload()}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def fetch_all_pages(
    request: Callable,
    pagination_model: Type[PaginationModel],
    rate_limit: Optional[int] = None,
) -> Iterator:
    """
    Lazily yield every result of a paginated API.

    `request` is first called without arguments and its response is parsed
    with `pagination_model`. As long as the parsed page is not the last
    one, `request` is called again with the keyword arguments given by
    `next_page_parameters()`. When `rate_limit` is truthy, it is passed to
    `sleep` before each follow-up call.
    """
    page_index = 1
    current_page = pagination_model(**request())

    while not current_page.is_last():
        logger.info(f"Fetching page number {page_index}")
        yield from current_page.page_results()

        call_kwargs = current_page.next_page_parameters()
        if rate_limit:
            sleep(rate_limit)

        raw_page = request(**call_kwargs)
        current_page = pagination_model(
            current_page_payload=call_kwargs,
            **raw_page,
        )
        page_index += 1

    # the loop exits before emitting the final page: do it here
    logger.info(f"Fetching page number {page_index}")
    yield from current_page.page_results()
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from pydantic import ConfigDict, Field
|
|
4
|
+
from pydantic.alias_generators import to_camel
|
|
5
|
+
|
|
6
|
+
from .pagination import PaginationModel, fetch_all_pages
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class _TestPagination(PaginationModel):
    """Pagination model of the fake API served by `_request` below."""

    # identifier of the following page; None on the last page
    next_page: Optional[str] = None
    entries: list = Field(default_factory=list)

    # the fake API answers with camelCase keys (e.g. "nextPage")
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )

    def is_last(self) -> bool:
        """A page without a `next_page` identifier ends the pagination"""
        has_next = self.next_page is not None
        return not has_next

    def next_page_payload(self) -> dict:
        """The next page is requested through the `page` parameter"""
        return dict(page=self.next_page)

    def page_results(self) -> list:
        """Entries of the current page"""
        return self.entries
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _request(pagination_params: Optional[dict] = None):
    """
    Fake paginated API made of 3 pages holding the integers 1 to 11.

    Without pagination parameters the first page is returned; afterwards
    the `page` parameter selects the page. Any other value raises an
    AssertionError.
    """
    if not pagination_params:
        return {
            "nextPage": "next_page_id",
            "entries": [1, 2, 3, 4, 5],
        }

    # (requested page id, id of the page after it, entries)
    following_pages = (
        ("next_page_id", "next_page_id_2", [6, 7, 8, 9, 10]),
        ("next_page_id_2", None, [11]),
    )
    requested = pagination_params.get("page")
    for page_id, next_id, entries in following_pages:
        if requested == page_id:
            return {"nextPage": next_id, "entries": entries}

    raise AssertionError(f"call has unexpected parameters: {pagination_params}")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def test__TestPagination():
    """Walking through the 3 fake pages must yield 1..11 in order."""
    fetched = list(fetch_all_pages(_request, _TestPagination))
    assert fetched == list(range(1, 12))
|
|
@@ -4,7 +4,10 @@ from http import HTTPStatus
|
|
|
4
4
|
import pytest
|
|
5
5
|
from requests import HTTPError, Response
|
|
6
6
|
|
|
7
|
-
from .safe_request import
|
|
7
|
+
from .safe_request import (
|
|
8
|
+
RequestSafeMode,
|
|
9
|
+
handle_response,
|
|
10
|
+
)
|
|
8
11
|
|
|
9
12
|
|
|
10
13
|
def mock_response(status_code: int):
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from .utils import build_url
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def test_APIClient_build_url():
    """`build_url` must cope with a missing scheme and a trailing slash."""
    path = "api/2.1/unity-catalog/tables"
    expected = "https://3.14.azuredatabricks.net/api/2.1/unity-catalog/tables"

    host_variants = (
        "3.14.azuredatabricks.net",  # bare host
        "https://3.14.azuredatabricks.net",  # host with scheme
        "https://3.14.azuredatabricks.net/",  # host with trailing slash
    )
    for host in host_variants:
        assert expected == build_url(host, path)
|
|
@@ -1,12 +1,23 @@
|
|
|
1
|
-
from
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from typing import (
|
|
3
|
+
Any,
|
|
4
|
+
Dict,
|
|
5
|
+
Iterable,
|
|
6
|
+
List,
|
|
7
|
+
Sequence,
|
|
8
|
+
Set,
|
|
9
|
+
TypeVar,
|
|
10
|
+
)
|
|
2
11
|
|
|
3
12
|
from .object import getproperty
|
|
4
13
|
from .type import Getter
|
|
5
14
|
|
|
15
|
+
T = TypeVar("T")
|
|
16
|
+
|
|
6
17
|
|
|
7
18
|
def group_by(identifier: Getter, elements: Sequence) -> Dict[Any, List]:
|
|
8
19
|
"""Groups the elements by the given key"""
|
|
9
|
-
groups: Dict[Any, List] =
|
|
20
|
+
groups: Dict[Any, List] = defaultdict(list)
|
|
10
21
|
for element in elements:
|
|
11
22
|
key = getproperty(element, identifier)
|
|
12
23
|
groups[key].append(element)
|
|
@@ -52,3 +63,24 @@ def empty_iterator():
|
|
|
52
63
|
Remark: missing return type is on purpose, it breaks the typing
|
|
53
64
|
"""
|
|
54
65
|
return iter([])
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def deduplicate(
    identifier: Getter,
    elements: Iterable[T],
) -> List[T]:
    """
    Return the given elements without duplicates, two elements being
    duplicates when `identifier` resolves to the same key for both.

    The first occurrence of each key is kept and the order is preserved.
    """
    # dicts keep insertion order, and `setdefault` never overwrites:
    # each key therefore stays bound to its first element
    first_by_key: Dict[Any, T] = {}
    for element in elements:
        first_by_key.setdefault(getproperty(element, identifier), element)

    return list(first_by_key.values())
|
|
@@ -1,6 +1,4 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
3
|
-
from .collection import mapping_from_rows
|
|
1
|
+
from .collection import deduplicate, mapping_from_rows
|
|
4
2
|
|
|
5
3
|
|
|
6
4
|
def test__mapping_from_rows__basic_mapping():
|
|
@@ -58,3 +56,19 @@ def test__mapping_from_rows__multiple_valid_rows():
|
|
|
58
56
|
result = mapping_from_rows(rows, "id", "name")
|
|
59
57
|
expected = {1: "Alice", 2: "Bob", 3: "Charlie"}
|
|
60
58
|
assert result == expected
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def test_deduplicate():
    """Only the first element of each `id` survives, in original order."""
    uniques = [
        {"id": str(index), "name": f"element_{index}"} for index in (1, 2, 3)
    ]
    duplicates = [
        {"id": duplicated_id, "name": "duplicate"}
        for duplicated_id in ("3", "3", "2")
    ]

    assert deduplicate("id", uniques + duplicates) == uniques
|
castor_extractor/utils/retry.py
CHANGED
|
@@ -6,6 +6,7 @@ from typing import Any, Callable, Sequence, Tuple, Type, Union
|
|
|
6
6
|
|
|
7
7
|
from pydantic import BaseModel, PositiveInt, PrivateAttr
|
|
8
8
|
from pydantic.fields import Field
|
|
9
|
+
from requests import HTTPError
|
|
9
10
|
|
|
10
11
|
logger = logging.getLogger(__name__)
|
|
11
12
|
|
|
@@ -115,3 +116,46 @@ def retry(
|
|
|
115
116
|
return _func
|
|
116
117
|
|
|
117
118
|
return _wrapper
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def retry_request(
    status_codes: Sequence[int],
    max_retries: int = 1,
    base_ms: int = 10,
    jitter_ms: int = 20,
    strategy: RetryStrategy = DEFAULT_STRATEGY,
    log_exc_info: bool = False,
) -> Callable:
    """
    Decorator retrying a function that raised an `HTTPError` whose
    response status code is listed in `status_codes`.

    Args:
        status_codes: HTTP status codes for which the call is retried
        max_retries, base_ms, jitter_ms, strategy: forwarded to `Retry`,
            which decides whether (and after which delay) to try again
        log_exc_info: forwarded to `Retry.check`

    Raises:
        HTTPError: immediately when the error has no response or a status
            code outside of `status_codes`; otherwise once `Retry.check`
            refuses another attempt.
    """
    # local import: the file-level import block is outside of this hunk
    from functools import wraps

    retryable_codes = tuple(status_codes)

    def _wrapper(func: Callable) -> Callable:
        def _try(*args, **kwargs) -> WrapperReturnType:
            try:
                return None, func(*args, **kwargs)
            except HTTPError as err:
                # an HTTPError may carry no response at all: treat it as
                # non-retryable instead of failing on `None.status_code`
                response = err.response
                if (
                    response is None
                    or response.status_code not in retryable_codes
                ):
                    raise
                name = getattr(func, "__name__", repr(func))
                logger.warning(f"Exception within {name}")
                return err, None

        @wraps(func)  # keep the decorated function's name and docstring
        def _func(*args, **kwargs) -> Any:
            # one `Retry` per call, so attempts are counted per invocation
            retry = Retry(
                max_retries=max_retries,
                base_ms=base_ms,
                jitter_ms=jitter_ms,
                strategy=strategy,
            )
            while True:
                err, result = _try(*args, **kwargs)
                if err is None:
                    return result
                if retry.check(err, log_exc_info):
                    continue
                raise err

        return _func

    return _wrapper
|
|
@@ -1,11 +1,15 @@
|
|
|
1
|
+
from http import HTTPStatus
|
|
1
2
|
from statistics import variance
|
|
2
3
|
from time import time
|
|
3
4
|
from typing import List
|
|
5
|
+
from unittest.mock import patch
|
|
4
6
|
|
|
5
7
|
import pytest
|
|
8
|
+
import requests
|
|
6
9
|
from pydantic.error_wrappers import ValidationError
|
|
10
|
+
from requests import HTTPError, Response
|
|
7
11
|
|
|
8
|
-
from .retry import MS_IN_SEC, Retry, RetryStrategy
|
|
12
|
+
from .retry import MS_IN_SEC, Retry, RetryStrategy, retry_request
|
|
9
13
|
|
|
10
14
|
|
|
11
15
|
def test_retry_field_validations():
|
|
@@ -75,3 +79,24 @@ def test_retry_strategy__check():
|
|
|
75
79
|
assert retry.check(error) is False
|
|
76
80
|
delta_ms = int((after - before) * MS_IN_SEC)
|
|
77
81
|
assert _within(delta_ms, 315, 345)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@patch("requests.get")
def test_retry_request(mocked_get):
    """A call that keeps answering 401 is retried, then the error is raised."""
    unauthorized = Response()
    unauthorized.status_code = HTTPStatus.UNAUTHORIZED
    mocked_get.return_value = unauthorized

    @retry_request(status_codes=(HTTPStatus.UNAUTHORIZED,), max_retries=3)
    def get():
        response = requests.get("hello")
        response.raise_for_status()
        return response.json()

    with pytest.raises(HTTPError):
        get()

    expected_calls = 1 + 3  # 1 call + 3 retries
    assert mocked_get.call_count == expected_calls
|