castor-extractor 0.19.0__py3-none-any.whl → 0.19.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of castor-extractor has been flagged as potentially problematic; see the package's registry page for more details.

Files changed (83)
  1. CHANGELOG.md +29 -2
  2. castor_extractor/file_checker/templates/generic_warehouse.py +1 -1
  3. castor_extractor/knowledge/notion/client/client.py +44 -80
  4. castor_extractor/knowledge/notion/client/client_test.py +9 -4
  5. castor_extractor/knowledge/notion/client/constants.py +1 -0
  6. castor_extractor/knowledge/notion/client/endpoints.py +1 -1
  7. castor_extractor/knowledge/notion/client/pagination.py +9 -5
  8. castor_extractor/quality/soda/assets.py +1 -1
  9. castor_extractor/quality/soda/client/client.py +30 -83
  10. castor_extractor/quality/soda/client/credentials.py +0 -11
  11. castor_extractor/quality/soda/client/endpoints.py +3 -6
  12. castor_extractor/quality/soda/client/pagination.py +25 -0
  13. castor_extractor/utils/__init__.py +13 -2
  14. castor_extractor/utils/client/__init__.py +14 -0
  15. castor_extractor/utils/client/api/__init__.py +5 -0
  16. castor_extractor/utils/client/api/auth.py +76 -0
  17. castor_extractor/utils/client/api/auth_test.py +49 -0
  18. castor_extractor/utils/client/api/client.py +153 -0
  19. castor_extractor/utils/client/api/client_test.py +47 -0
  20. castor_extractor/utils/client/api/pagination.py +83 -0
  21. castor_extractor/utils/client/api/pagination_test.py +51 -0
  22. castor_extractor/utils/{safe_request_test.py → client/api/safe_request_test.py} +4 -1
  23. castor_extractor/utils/client/api/utils.py +9 -0
  24. castor_extractor/utils/client/api/utils_test.py +16 -0
  25. castor_extractor/utils/collection.py +34 -2
  26. castor_extractor/utils/collection_test.py +17 -3
  27. castor_extractor/utils/pager/__init__.py +0 -1
  28. castor_extractor/utils/retry.py +44 -0
  29. castor_extractor/utils/retry_test.py +26 -1
  30. castor_extractor/utils/salesforce/client.py +44 -49
  31. castor_extractor/utils/salesforce/client_test.py +2 -2
  32. castor_extractor/utils/salesforce/pagination.py +33 -0
  33. castor_extractor/visualization/domo/client/client.py +10 -5
  34. castor_extractor/visualization/domo/client/credentials.py +1 -1
  35. castor_extractor/visualization/domo/client/endpoints.py +19 -7
  36. castor_extractor/visualization/looker/api/credentials.py +1 -1
  37. castor_extractor/visualization/metabase/client/api/client.py +26 -11
  38. castor_extractor/visualization/metabase/client/api/credentials.py +1 -1
  39. castor_extractor/visualization/metabase/client/db/credentials.py +1 -1
  40. castor_extractor/visualization/mode/client/credentials.py +1 -1
  41. castor_extractor/visualization/qlik/client/engine/credentials.py +1 -1
  42. castor_extractor/visualization/salesforce_reporting/client/rest.py +4 -3
  43. castor_extractor/visualization/sigma/client/client.py +106 -111
  44. castor_extractor/visualization/sigma/client/credentials.py +11 -1
  45. castor_extractor/visualization/sigma/client/endpoints.py +1 -1
  46. castor_extractor/visualization/sigma/client/pagination.py +22 -18
  47. castor_extractor/visualization/tableau/tests/unit/rest_api/auth_test.py +0 -1
  48. castor_extractor/visualization/tableau/tests/unit/rest_api/credentials_test.py +0 -3
  49. castor_extractor/visualization/tableau_revamp/assets.py +11 -0
  50. castor_extractor/visualization/tableau_revamp/client/client.py +71 -151
  51. castor_extractor/visualization/tableau_revamp/client/client_metadata_api.py +95 -0
  52. castor_extractor/visualization/tableau_revamp/client/client_rest_api.py +128 -0
  53. castor_extractor/visualization/tableau_revamp/client/client_tsc.py +66 -0
  54. castor_extractor/visualization/tableau_revamp/client/{tsc_fields.py → rest_fields.py} +15 -2
  55. castor_extractor/visualization/tableau_revamp/constants.py +0 -2
  56. castor_extractor/visualization/tableau_revamp/extract.py +5 -11
  57. castor_extractor/warehouse/databricks/api_client.py +239 -0
  58. castor_extractor/warehouse/databricks/api_client_test.py +15 -0
  59. castor_extractor/warehouse/databricks/client.py +37 -490
  60. castor_extractor/warehouse/databricks/client_test.py +1 -99
  61. castor_extractor/warehouse/databricks/endpoints.py +28 -0
  62. castor_extractor/warehouse/databricks/lineage.py +141 -0
  63. castor_extractor/warehouse/databricks/lineage_test.py +34 -0
  64. castor_extractor/warehouse/databricks/pagination.py +22 -0
  65. castor_extractor/warehouse/databricks/sql_client.py +90 -0
  66. castor_extractor/warehouse/databricks/utils.py +44 -1
  67. castor_extractor/warehouse/databricks/utils_test.py +58 -1
  68. castor_extractor/warehouse/mysql/client.py +0 -2
  69. castor_extractor/warehouse/salesforce/client.py +12 -59
  70. castor_extractor/warehouse/salesforce/pagination.py +34 -0
  71. castor_extractor/warehouse/sqlserver/client.py +0 -1
  72. castor_extractor-0.19.6.dist-info/METADATA +903 -0
  73. {castor_extractor-0.19.0.dist-info → castor_extractor-0.19.6.dist-info}/RECORD +77 -60
  74. castor_extractor/utils/client/api.py +0 -87
  75. castor_extractor/utils/client/api_test.py +0 -24
  76. castor_extractor/utils/pager/pager_on_token.py +0 -52
  77. castor_extractor/utils/pager/pager_on_token_test.py +0 -73
  78. castor_extractor/visualization/sigma/client/client_test.py +0 -54
  79. castor_extractor-0.19.0.dist-info/METADATA +0 -207
  80. /castor_extractor/utils/{safe_request.py → client/api/safe_request.py} +0 -0
  81. {castor_extractor-0.19.0.dist-info → castor_extractor-0.19.6.dist-info}/LICENCE +0 -0
  82. {castor_extractor-0.19.0.dist-info → castor_extractor-0.19.6.dist-info}/WHEEL +0 -0
  83. {castor_extractor-0.19.0.dist-info → castor_extractor-0.19.6.dist-info}/entry_points.txt +0 -0
@@ -1,14 +1,38 @@
1
1
  import logging
2
- from typing import Iterator, Optional, Tuple
3
-
4
- from requests import Response
5
-
6
- from ...utils.client.api import APIClient
2
+ from functools import partial
3
+ from typing import Iterator, Optional
4
+
5
+ import requests
6
+
7
+ from ...utils import (
8
+ APIClient,
9
+ BearerAuth,
10
+ build_url,
11
+ fetch_all_pages,
12
+ handle_response,
13
+ )
7
14
  from .constants import DEFAULT_API_VERSION, DEFAULT_PAGINATION_LIMIT
8
15
  from .credentials import SalesforceCredentials
16
+ from .pagination import SalesforcePagination
9
17
 
10
18
  logger = logging.getLogger(__name__)
11
19
 
20
+ SALESFORCE_TIMEOUT_S = 120
21
+
22
+
23
class SalesforceAuth(BearerAuth):
    """Bearer authentication for the Salesforce REST API.

    The access token is requested from the OAuth 2.0 token endpoint, using the
    payload built by ``SalesforceCredentials.token_request_payload()``.
    """

    _AUTH_ENDPOINT = "services/oauth2/token"

    def __init__(self, credentials: SalesforceCredentials):
        self._host = credentials.base_url
        self._token_payload = credentials.token_request_payload()

    def fetch_token(self) -> Optional[str]:
        """Request a new access token from Salesforce.

        Returns:
            The ``access_token`` field of the token endpoint response.
        """
        url = build_url(self._host, self._AUTH_ENDPOINT)
        # The HTTP verb is implied by `requests.post`; it must NOT be passed
        # positionally (the 2nd positional argument is `data`, so a stray
        # "POST" would be sent as the request body).
        # NOTE(review): the payload stays in the query string as in previous
        # releases; a form-encoded body (`data=`) would keep credentials out
        # of URLs and access logs — consider switching.
        response = requests.post(
            url,
            params=self._token_payload,
            timeout=SALESFORCE_TIMEOUT_S,
        )
        handled_response = handle_response(response)
        return handled_response["access_token"]
35
+
12
36
 
13
37
  class SalesforceBaseClient(APIClient):
14
38
  """
@@ -22,45 +46,24 @@ class SalesforceBaseClient(APIClient):
22
46
  PATH_TPL = "services/data/v{version}/{suffix}"
23
47
 
24
48
  def __init__(self, credentials: SalesforceCredentials):
25
- super().__init__(host=credentials.base_url)
26
- self._token = self._access_token(credentials)
27
-
28
- def _access_token(self, credentials: SalesforceCredentials) -> str:
29
- url = self.build_url(self._host, "services/oauth2/token")
30
- response = self._call(
31
- url, "POST", params=credentials.token_request_payload()
49
+ auth = SalesforceAuth(credentials)
50
+ super().__init__(
51
+ host=credentials.base_url, auth=auth, timeout=SALESFORCE_TIMEOUT_S
32
52
  )
33
- return response["access_token"]
34
53
 
35
- def _full_url(self, suffix: str) -> str:
54
+ def _endpoint(self, suffix: str) -> str:
36
55
  path = self.PATH_TPL.format(version=self.api_version, suffix=suffix)
37
- return self.build_url(self._host, path)
56
+ return path
38
57
 
39
58
  @property
40
- def query_url(self) -> str:
59
+ def query_endpoint(self) -> str:
41
60
  """Returns the query API url"""
42
- return self._full_url("query")
61
+ return self._endpoint("query")
43
62
 
44
63
  @property
45
- def tooling_url(self) -> str:
64
+ def tooling_endpoint(self) -> str:
46
65
  """Returns the tooling API url"""
47
- return self._full_url("tooling/query")
48
-
49
- @staticmethod
50
- def _query_processor(response: Response) -> Tuple[dict, Optional[str]]:
51
- results = response.json()
52
- return results["records"], results.get("nextRecordsUrl")
53
-
54
- def _has_reached_pagination_limit(self, page_number: int) -> bool:
55
- return page_number > self.pagination_limit
56
-
57
- def _query_first_page(self, query: str) -> Tuple[Iterator[dict], str]:
58
- url = self.query_url
59
- logger.info("querying page 0")
60
- records, next_page_url = self._call(
61
- url, params={"q": query}, processor=self._query_processor
62
- )
63
- return records, next_page_url
66
+ return self._endpoint("tooling/query")
64
67
 
65
68
  def _query_all(self, query: str) -> Iterator[dict]:
66
69
  """
@@ -68,17 +71,9 @@ class SalesforceBaseClient(APIClient):
68
71
 
69
72
  more: https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/dome_query.htm
70
73
  """
71
- records, next_page_path = self._query_first_page(query)
72
- yield from records
73
-
74
- page_count = 1
75
- while next_page_path and not self._has_reached_pagination_limit(
76
- page_count
77
- ):
78
- logger.info(f"querying page {page_count}")
79
- url = self.build_url(self._host, next_page_path)
80
- records, next_page_path = self._call(
81
- url, processor=self._query_processor
82
- )
83
- yield from records
84
- page_count += 1
74
+ request = partial(
75
+ self._get,
76
+ endpoint=self.query_endpoint,
77
+ params={"q": query},
78
+ )
79
+ yield from fetch_all_pages(request, SalesforcePagination)
@@ -17,5 +17,5 @@ def test_SalesforceBaseClient__urls(mock_call):
17
17
  )
18
18
  client = SalesforceBaseClient(credentials)
19
19
 
20
- assert client.query_url == "https://url/services/data/v59.0/query"
21
- assert client.tooling_url == "https://url/services/data/v59.0/tooling/query"
20
+ assert client.query_endpoint == "services/data/v59.0/query"
21
+ assert client.tooling_endpoint == "services/data/v59.0/tooling/query"
@@ -0,0 +1,33 @@
1
+ from typing import Optional
2
+
3
+ from pydantic import ConfigDict
4
+ from pydantic.alias_generators import to_camel
5
+
6
+ from ...utils import (
7
+ FetchNextPageBy,
8
+ PaginationModel,
9
+ )
10
+
11
+ LIMIT_RECORDS_PER_PAGE = 2000
12
+
13
+
14
class SalesforcePagination(PaginationModel):
    """Pagination model for Salesforce SOQL query responses.

    Each response carries a batch of ``records`` and, when more results
    remain, a ``nextRecordsUrl`` path to the next batch. Camel-case keys
    (e.g. ``nextRecordsUrl``) are mapped to snake-case fields by the alias
    generator.
    """

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
    )
    fetch_by: FetchNextPageBy = FetchNextPageBy.URL
    records: list
    next_records_url: Optional[str] = None

    def is_last(self) -> bool:
        """Return True when no further batch is available.

        Only the presence of ``nextRecordsUrl`` is checked: Salesforce does not
        guarantee the batch size, so a batch smaller than the usual maximum
        (2000 records) can still be followed by more batches. Stopping on a
        short batch would silently truncate the query results.
        """
        return not self.next_records_url

    def next_page_payload(self) -> Optional[str]:
        """Return the path of the next batch, or None on the last batch."""
        return self.next_records_url

    def page_results(self) -> list:
        """Return the records of the current batch."""
        return self.records
@@ -153,11 +153,16 @@ class DomoClient:
153
153
 
154
154
  return all_results
155
155
 
156
- def _datasources(self, page_id: str) -> RawData:
157
- endpoint = self._endpoint_factory.page_content(page_id)
158
- page_content = self._get_element(endpoint)
156
+ def _datasources(self, card_ids: List[int]) -> RawData:
157
+ """Yields all distinct datasources associated to the given cards"""
158
+ if not card_ids:
159
+ return empty_iterator()
160
+
161
+ endpoint = self._endpoint_factory.cards_metadata(card_ids)
162
+ cards_metadata = self._get_element(endpoint)
163
+
159
164
  processed: set[str] = set()
160
- for card in page_content.get("cards", []):
165
+ for card in cards_metadata:
161
166
  for datasource in card["datasources"]:
162
167
  id_ = datasource["dataSourceId"]
163
168
  if id_ in processed:
@@ -195,7 +200,7 @@ class DomoClient:
195
200
  if not detail:
196
201
  continue
197
202
 
198
- datasources = self._datasources(page_id)
203
+ datasources = self._datasources(detail.get("cardIds", []))
199
204
  yield {
200
205
  **detail,
201
206
  "datasources": list(datasources),
@@ -1,6 +1,6 @@
1
1
  from typing import Dict, Optional
2
2
 
3
- from pydantic import Field, SecretStr
3
+ from pydantic import Field
4
4
  from pydantic_settings import BaseSettings, SettingsConfigDict
5
5
  from requests.auth import HTTPBasicAuth
6
6
 
@@ -1,10 +1,19 @@
1
1
  from dataclasses import dataclass, field
2
- from typing import Optional
2
+ from typing import List, Optional
3
3
 
4
- _DOMO_PUBLIC_URL = "https://api.domo.com"
5
4
  _AUTH_URL = (
6
5
  "grant_type=client_credentials&scope=data%20dashboard%20audit%20user"
7
6
  )
7
+ _DOMO_PUBLIC_URL = "https://api.domo.com"
8
+ _MAX_URL_LENGTH = 2048
9
+
10
+
11
class URLTooLongException(Exception):
    """Raised when a generated URL is longer than the allowed maximum.

    The message reports the actual length, the limit, and the offending URL.
    """

    def __init__(self, url: str, max_size: int):
        url_length = len(url)
        super().__init__(
            f"The URL is too long ({url_length} > {max_size}) : {url}"
        )
8
17
 
9
18
 
10
19
  @dataclass
@@ -70,8 +79,11 @@ class EndpointFactory:
70
79
  is_private=True,
71
80
  )
72
81
 
73
- def page_content(self, page_id: str) -> Endpoint:
74
- return Endpoint(
75
- base_url=f"{self.base_url}/api/content/v3/stacks/{page_id}/cards?parts=datasources",
76
- is_private=True,
77
- )
82
def cards_metadata(self, card_ids: List[int]) -> Endpoint:
    """Build the private endpoint listing the given cards with their datasources.

    Raises:
        URLTooLongException: if the resulting URL exceeds ``_MAX_URL_LENGTH``.
    """
    card_urns = ",".join(str(card_id) for card_id in card_ids)
    cards_url = (
        f"{self.base_url}/api/content/v1/cards"
        f"?urns={card_urns}&parts=datasources"
    )
    if len(cards_url) > _MAX_URL_LENGTH:
        raise URLTooLongException(cards_url, _MAX_URL_LENGTH)
    return Endpoint(base_url=cards_url, is_private=True)
@@ -1,5 +1,5 @@
1
1
  from looker_sdk.rtl.api_settings import SettingsConfig
2
- from pydantic import Field, SecretStr
2
+ from pydantic import Field
3
3
  from pydantic_settings import BaseSettings, SettingsConfigDict
4
4
 
5
5
  from ..constants import DEFAULT_LOOKER_TIMEOUT_SECOND, LOOKER_ENV_PREFIX
@@ -1,10 +1,16 @@
1
1
  import logging
2
- from typing import Dict, Iterator, List, cast
2
+ from http import HTTPStatus
3
+ from typing import Any, Dict, Iterator, List, Optional, cast
3
4
 
4
5
  import requests
5
- from requests import HTTPError, Response
6
-
7
- from .....utils import JsonType, SerializedAsset
6
+ from requests import HTTPError
7
+
8
+ from .....utils import (
9
+ JsonType,
10
+ RequestSafeMode,
11
+ SerializedAsset,
12
+ handle_response,
13
+ )
8
14
  from ...assets import EXPORTED_FIELDS, MetabaseAsset
9
15
  from ...errors import MetabaseLoginError, SuperuserCredentialsRequired
10
16
  from ...types import IdsType
@@ -13,6 +19,17 @@ from .credentials import MetabaseApiCredentials
13
19
 
14
20
  logger = logging.getLogger(__name__)
15
21
 
22
+ # Safe mode
23
+ VOLUME_IGNORED = 5
24
+ IGNORED_ERROR_CODES = (
25
+ HTTPStatus.BAD_REQUEST,
26
+ HTTPStatus.NOT_FOUND,
27
+ )
28
+ METABASE_SAFE_MODE = RequestSafeMode(
29
+ max_errors=VOLUME_IGNORED,
30
+ status_codes=IGNORED_ERROR_CODES,
31
+ )
32
+
16
33
  URL_TEMPLATE = "{base_url}/api/{endpoint}"
17
34
 
18
35
  ROOT_KEY = "root"
@@ -31,12 +48,14 @@ class ApiClient:
31
48
  def __init__(
32
49
  self,
33
50
  credentials: MetabaseApiCredentials,
51
+ safe_mode: Optional[RequestSafeMode] = None,
34
52
  ):
35
53
  self.base_url = credentials.base_url
36
54
 
37
55
  self._credentials = credentials
38
56
  self._session = requests.Session()
39
57
  self._session_id = self._login()
58
+ self.safe_mode = safe_mode or METABASE_SAFE_MODE
40
59
  self._check_permissions() # verify that the given user is superuser
41
60
 
42
61
  @staticmethod
@@ -57,8 +76,8 @@ class ApiClient:
57
76
  }
58
77
 
59
78
  @staticmethod
60
- def _answer(response: Response):
61
- answer = response.json()
79
+ def _answer(response: Any):
80
+ answer = response
62
81
  if isinstance(answer, Dict) and DATA_KEY in answer:
63
82
  # v0.41 of Metabase introduced embedded data for certain calls
64
83
  # {'data': [{ }, ...] , 'total': 15, 'limit': None, 'offset': None}"
@@ -69,7 +88,7 @@ class ApiClient:
69
88
  url = self._url(endpoint)
70
89
  headers = self._headers()
71
90
  response = self._session.get(url=url, headers=headers)
72
- response.raise_for_status() # check for errors
91
+ response = handle_response(response, safe_mode=self.safe_mode)
73
92
  return self._answer(response)
74
93
 
75
94
  def _check_permissions(self) -> None:
@@ -110,10 +129,6 @@ class ApiClient:
110
129
  return ids
111
130
 
112
131
  def _dashboards(self) -> Iterator[dict]:
113
- """
114
- GET /api/dashboard is deprecated
115
- https://github.com/metabase/metabase/pull/35235
116
- """
117
132
  collection_ids = self._fetch_ids(MetabaseAsset.COLLECTION)
118
133
  for _id in collection_ids:
119
134
  collection = self._call(f"collection/{_id}/items?models=dashboard")
@@ -1,4 +1,4 @@
1
- from pydantic import Field, SecretStr
1
+ from pydantic import Field
2
2
  from pydantic_settings import BaseSettings, SettingsConfigDict
3
3
 
4
4
  METABASE_API_ENV_PREFIX = "CASTOR_METABASE_API_"
@@ -1,6 +1,6 @@
1
1
  from typing import Optional
2
2
 
3
- from pydantic import Field, SecretStr
3
+ from pydantic import Field
4
4
  from pydantic_settings import BaseSettings, SettingsConfigDict
5
5
 
6
6
  METABASE_DB_ENV_PREFIX = "CASTOR_METABASE_DB_"
@@ -1,4 +1,4 @@
1
- from pydantic import Field, SecretStr, field_validator
1
+ from pydantic import Field, field_validator
2
2
  from pydantic_settings import BaseSettings, SettingsConfigDict
3
3
 
4
4
  MODE_ENV_PREFIX = "CASTOR_MODE_ANALYTICS_"
@@ -1,4 +1,4 @@
1
- from pydantic import Field, SecretStr, field_validator
1
+ from pydantic import Field, field_validator
2
2
  from pydantic_settings import BaseSettings, SettingsConfigDict
3
3
 
4
4
  from .....utils import validate_baseurl
@@ -1,6 +1,7 @@
1
1
  import logging
2
2
  from typing import Dict, Iterator, List, Optional
3
3
 
4
+ from ....utils import build_url
4
5
  from ....utils.salesforce import SalesforceBaseClient
5
6
  from ..assets import SalesforceReportingAsset
6
7
  from .soql import queries
@@ -28,15 +29,15 @@ class SalesforceReportingClient(SalesforceBaseClient):
28
29
 
29
30
  if asset_type == SalesforceReportingAsset.DASHBOARDS:
30
31
  path = f"lightning/r/Dashboard/{asset['Id']}/view"
31
- return self.build_url(self._host, path)
32
+ return build_url(self._host, path)
32
33
 
33
34
  if asset_type == SalesforceReportingAsset.FOLDERS:
34
35
  path = asset["attributes"]["url"].lstrip("/")
35
- return self.build_url(self._host, path)
36
+ return build_url(self._host, path)
36
37
 
37
38
  if asset_type == SalesforceReportingAsset.REPORTS:
38
39
  path = f"lightning/r/Report/{asset['Id']}/view"
39
- return self.build_url(self._host, path)
40
+ return build_url(self._host, path)
40
41
 
41
42
  return None
42
43
 
@@ -1,150 +1,142 @@
1
- import logging
2
- from typing import Dict, Iterator, List, Optional, Tuple
3
- from urllib.parse import urljoin
1
+ from functools import partial
2
+ from http import HTTPStatus
3
+ from typing import Callable, Dict, Iterator, List, Optional, Tuple
4
4
 
5
5
  import requests
6
6
 
7
+ from ....utils import (
8
+ APIClient,
9
+ BearerAuth,
10
+ RequestSafeMode,
11
+ build_url,
12
+ fetch_all_pages,
13
+ handle_response,
14
+ )
7
15
  from ..assets import SigmaAsset
8
16
  from .credentials import SigmaCredentials
9
- from .endpoints import EndpointFactory
10
- from .pagination import Pagination
11
-
12
- logger = logging.getLogger()
17
+ from .endpoints import SigmaEndpointFactory
18
+ from .pagination import SIGMA_API_LIMIT, SigmaPagination
13
19
 
20
+ _CONTENT_TYPE = "application/x-www-form-urlencoded"
14
21
 
15
- DATA_ELEMENTS: Tuple[str, ...] = (
22
+ _DATA_ELEMENTS: Tuple[str, ...] = (
16
23
  "input-table",
17
24
  "pivot-table",
18
25
  "table",
19
26
  "visualization",
20
27
  "viz",
21
28
  )
22
- _CONTENT_TYPE = "application/x-www-form-urlencoded"
23
29
 
30
+ _AUTH_TIMEOUT_S = 60
24
31
 
25
- class SigmaClient:
26
- """Client used for all Sigma's assets extractions"""
27
-
28
- def __init__(self, credentials: SigmaCredentials):
29
- self.host = credentials.host
30
- self.client_id = credentials.client_id
31
- self.api_token = credentials.api_token
32
- self.grant_type = credentials.grant_type
33
- self.headers: Optional[Dict[str, str]] = None
34
-
35
- def _get_token(self) -> Dict[str, str]:
36
- auth_endpoint = EndpointFactory.authentication()
37
- token_api_path = urljoin(self.host, auth_endpoint)
38
- token_response = requests.post( # noqa: S113
39
- token_api_path,
40
- data={
41
- "grant_type": self.grant_type,
42
- "client_id": self.client_id,
43
- "client_secret": self.api_token,
44
- },
45
- )
46
- if token_response.status_code != requests.codes.OK:
47
- raise ValueError("Couldn't fetch the token in the API")
48
- return token_response.json()
49
-
50
- def _get_headers(self, reset=False) -> Dict[str, str]:
51
- """
52
- If reset is True, will re-create the headers with a new authentication token
53
-
54
- Note : From this [documentation](https://help.sigmacomputing.com/docs/api-authentication-with-curl),
55
- instead of re-creating a token we could refresh it, but I don't see any benefit.
56
- """
57
- if reset or not self.headers:
58
- headers = {"Content-Type": _CONTENT_TYPE}
59
- token = self._get_token()
60
- headers["Authorization"] = f"Bearer {token['access_token']}"
61
- self.headers = headers
62
- return self.headers
63
-
64
- def _get(self, endpoint_url: str) -> dict:
65
- url = urljoin(self.host, endpoint_url)
66
- result = requests.get(url, headers=self._get_headers()) # noqa: S113
67
-
68
- if result.status_code == requests.codes.UNAUTHORIZED:
69
- logger.info("Regenerating access token")
70
- result = requests.get( # noqa: S113
71
- url, headers=self._get_headers(reset=True)
72
- )
32
+ _SIGMA_HEADERS = {
33
+ "Content-Type": _CONTENT_TYPE,
34
+ }
73
35
 
74
- try:
75
- return result.json()
76
- except Exception as e:
77
- logger.warning(
78
- f"Couldn't deserialize result from url {url}."
79
- f" with status code {result.status_code} and"
80
- f" exception {type(e)}"
81
- )
82
- return dict()
83
-
84
- def _get_with_pagination(self, endpoint_url: str) -> Iterator[dict]:
85
- pagination = Pagination(next_page="0")
86
-
87
- while pagination.next_page is not None:
88
- paginated_url = pagination.generate_url(endpoint_url)
89
- response = self._get(paginated_url)
90
- pagination = Pagination(
91
- next_page=response.get("nextPage"),
92
- entries=response.get("entries"),
93
- total=response.get("total"),
94
- )
95
- yield from pagination.entries
36
+ _VOLUME_IGNORED = 10_000
37
+ _IGNORED_ERROR_CODES = (
38
+ HTTPStatus.BAD_REQUEST,
39
+ HTTPStatus.BAD_GATEWAY,
40
+ HTTPStatus.INTERNAL_SERVER_ERROR,
41
+ HTTPStatus.CONFLICT,
42
+ HTTPStatus.NOT_FOUND,
43
+ )
44
+ SIGMA_SAFE_MODE = RequestSafeMode(
45
+ max_errors=_VOLUME_IGNORED,
46
+ status_codes=_IGNORED_ERROR_CODES,
47
+ )
48
+
49
+
50
class SigmaBearerAuth(BearerAuth):
    """Bearer authentication for the Sigma API.

    A new access token is obtained by POSTing the form payload to Sigma's
    authentication endpoint.
    """

    def __init__(self, host: str, token_payload: Dict[str, str]):
        self.authentication_url = build_url(
            host, SigmaEndpointFactory.authentication()
        )
        self.token_payload = token_payload

    def fetch_token(self):
        """Return the ``access_token`` field of the authentication response."""
        response = requests.post(
            self.authentication_url,
            data=self.token_payload,
            timeout=_AUTH_TIMEOUT_S,
        )
        payload = handle_response(response)
        return payload["access_token"]
96
62
 
97
- def _per_workbook_get_pages(self, workbook_id: str) -> Iterator[dict]:
98
- endpoint = EndpointFactory.pages(workbook_id)
99
- yield from self._get_with_pagination(endpoint)
100
63
 
101
- def _per_page_get_elements(
64
+ class SigmaClient(APIClient):
65
+ def __init__(
102
66
  self,
103
- workbook_id: str,
104
- page_id: str,
105
- ) -> Iterator[dict]:
106
- endpoint = EndpointFactory.elements(workbook_id, page_id)
107
- yield from self._get_with_pagination(endpoint)
67
+ credentials: SigmaCredentials,
68
+ safe_mode: Optional[RequestSafeMode] = None,
69
+ ):
70
+ auth = SigmaBearerAuth(
71
+ host=credentials.host,
72
+ token_payload=credentials.token_payload,
73
+ )
74
+ super().__init__(
75
+ host=credentials.host,
76
+ auth=auth,
77
+ headers=_SIGMA_HEADERS,
78
+ safe_mode=safe_mode or SIGMA_SAFE_MODE,
79
+ )
80
+
81
def _get_paginated(self, endpoint: str) -> Callable:
    """Return a GET request on `endpoint`, pre-set with Sigma's page size.

    The returned callable is the request that `fetch_all_pages` drives.
    """
    pagination_params = {"limit": SIGMA_API_LIMIT}
    return partial(self._get, endpoint=endpoint, params=pagination_params)
108
85
 
109
86
  def _get_all_datasets(self) -> Iterator[dict]:
110
- endpoint = EndpointFactory.datasets()
111
- yield from self._get_with_pagination(endpoint)
87
+ request = self._get_paginated(endpoint=SigmaEndpointFactory.datasets())
88
+ yield from fetch_all_pages(request, SigmaPagination)
112
89
 
113
90
  def _get_all_files(self) -> Iterator[dict]:
114
- endpoint = EndpointFactory.files()
115
- yield from self._get_with_pagination(endpoint)
91
+ request = self._get_paginated(endpoint=SigmaEndpointFactory.files())
92
+ yield from fetch_all_pages(request, SigmaPagination)
116
93
 
117
94
  def _get_all_members(self) -> Iterator[dict]:
118
- endpoint = EndpointFactory.members()
119
- yield from self._get(endpoint)
95
+ request = self._get_paginated(endpoint=SigmaEndpointFactory.members())
96
+ yield from fetch_all_pages(request, SigmaPagination)
120
97
 
121
98
  def _get_all_workbooks(self) -> Iterator[dict]:
122
- endpoint = EndpointFactory.workbooks()
123
- yield from self._get_with_pagination(endpoint)
99
+ request = self._get_paginated(endpoint=SigmaEndpointFactory.workbooks())
100
+ yield from fetch_all_pages(request, SigmaPagination)
101
+
102
def _get_elements_per_page(
    self, page: dict, workbook_id: str
) -> Iterator[dict]:
    """Yield the data elements of one workbook page, tagged with their origin.

    Elements whose ``type`` is not listed in ``_DATA_ELEMENTS`` are skipped;
    each yielded element is enriched with ``workbook_id`` and ``page_id``.
    """
    page_id = page["pageId"]
    endpoint = SigmaEndpointFactory.elements(workbook_id, page_id)
    request = self._get_paginated(endpoint)
    all_elements = fetch_all_pages(request, SigmaPagination)
    data_elements = (
        element
        for element in all_elements
        if element.get("type") in _DATA_ELEMENTS
    )
    for element in data_elements:
        yield {**element, "workbook_id": workbook_id, "page_id": page_id}
124
118
 
125
119
  def _get_all_elements(self, workbooks: List[dict]) -> Iterator[dict]:
126
120
  for workbook in workbooks:
127
121
  workbook_id = workbook["workbookId"]
128
- pages = self._per_workbook_get_pages(workbook_id)
122
+
123
+ request = self._get_paginated(
124
+ SigmaEndpointFactory.pages(workbook_id)
125
+ )
126
+ pages = fetch_all_pages(request, SigmaPagination)
129
127
 
130
128
  for page in pages:
131
- page_id = page["pageId"]
132
- elements = self._per_page_get_elements(workbook_id, page_id)
133
- for element in elements:
134
- if element.get("type") not in DATA_ELEMENTS:
135
- continue
136
- yield {
137
- **element,
138
- "workbook_id": workbook_id,
139
- "page_id": page_id,
140
- }
129
+ yield from self._get_elements_per_page(
130
+ page=page, workbook_id=workbook_id
131
+ )
141
132
 
142
133
  def _get_all_lineages(self, elements: List[dict]) -> Iterator[dict]:
143
134
  for element in elements:
144
135
  workbook_id = element["workbook_id"]
145
136
  element_id = element["elementId"]
146
- endpoint = EndpointFactory.lineage(workbook_id, element_id)
147
- lineage = self._get(endpoint)
137
+ lineage = self._get(
138
+ endpoint=SigmaEndpointFactory.lineage(workbook_id, element_id)
139
+ )
148
140
  yield {
149
141
  **lineage,
150
142
  "workbook_id": workbook_id,
@@ -154,8 +146,11 @@ class SigmaClient:
154
146
  def _get_all_queries(self, workbooks: List[dict]) -> Iterator[dict]:
155
147
  for workbook in workbooks:
156
148
  workbook_id = workbook["workbookId"]
157
- endpoint = EndpointFactory.queries(workbook_id)
158
- queries = self._get_with_pagination(endpoint)
149
+ request = self._get_paginated(
150
+ SigmaEndpointFactory.queries(workbook_id)
151
+ )
152
+ queries = fetch_all_pages(request, SigmaPagination)
153
+
159
154
  for query in queries:
160
155
  yield {**query, "workbook_id": workbook_id}
161
156