castor-extractor 0.16.1__py3-none-any.whl → 0.16.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of castor-extractor might be problematic; consult the package's registry page for more details.

Files changed (41)
  1. CHANGELOG.md +12 -0
  2. castor_extractor/commands/extract_databricks.py +3 -0
  3. castor_extractor/commands/extract_salesforce.py +43 -0
  4. castor_extractor/commands/extract_salesforce_reporting.py +6 -6
  5. castor_extractor/utils/client/api.py +36 -27
  6. castor_extractor/utils/salesforce/__init__.py +3 -0
  7. castor_extractor/utils/salesforce/client.py +84 -0
  8. castor_extractor/utils/salesforce/client_test.py +21 -0
  9. castor_extractor/utils/salesforce/constants.py +13 -0
  10. castor_extractor/utils/salesforce/credentials.py +65 -0
  11. castor_extractor/{visualization/salesforce_reporting/client → utils/salesforce}/credentials_test.py +3 -2
  12. castor_extractor/visualization/domo/client/client.py +1 -1
  13. castor_extractor/visualization/powerbi/client/constants.py +3 -3
  14. castor_extractor/visualization/powerbi/client/rest.py +14 -8
  15. castor_extractor/visualization/powerbi/client/rest_test.py +61 -27
  16. castor_extractor/visualization/salesforce_reporting/__init__.py +1 -2
  17. castor_extractor/visualization/salesforce_reporting/client/__init__.py +1 -2
  18. castor_extractor/visualization/salesforce_reporting/client/rest.py +7 -90
  19. castor_extractor/visualization/salesforce_reporting/extract.py +10 -8
  20. castor_extractor/visualization/tableau/assets.py +5 -0
  21. castor_extractor/visualization/tableau/client/client.py +10 -0
  22. castor_extractor/visualization/tableau/gql_fields.py +30 -9
  23. castor_extractor/warehouse/databricks/client.py +20 -3
  24. castor_extractor/warehouse/databricks/client_test.py +14 -0
  25. castor_extractor/warehouse/databricks/credentials.py +1 -4
  26. castor_extractor/warehouse/databricks/extract.py +3 -2
  27. castor_extractor/warehouse/databricks/format.py +5 -4
  28. castor_extractor/warehouse/salesforce/__init__.py +6 -0
  29. castor_extractor/warehouse/salesforce/client.py +112 -0
  30. castor_extractor/warehouse/salesforce/constants.py +2 -0
  31. castor_extractor/warehouse/salesforce/extract.py +111 -0
  32. castor_extractor/warehouse/salesforce/format.py +67 -0
  33. castor_extractor/warehouse/salesforce/format_test.py +32 -0
  34. castor_extractor/warehouse/salesforce/soql.py +45 -0
  35. castor_extractor-0.16.4.dist-info/LICENCE +86 -0
  36. {castor_extractor-0.16.1.dist-info → castor_extractor-0.16.4.dist-info}/METADATA +2 -3
  37. {castor_extractor-0.16.1.dist-info → castor_extractor-0.16.4.dist-info}/RECORD +39 -27
  38. {castor_extractor-0.16.1.dist-info → castor_extractor-0.16.4.dist-info}/WHEEL +1 -1
  39. {castor_extractor-0.16.1.dist-info → castor_extractor-0.16.4.dist-info}/entry_points.txt +2 -1
  40. castor_extractor/visualization/salesforce_reporting/client/constants.py +0 -2
  41. castor_extractor/visualization/salesforce_reporting/client/credentials.py +0 -33
CHANGELOG.md CHANGED
@@ -1,5 +1,17 @@
1
1
  # Changelog
2
2
 
3
+ ## 0.16.4 - 2024-04-25
4
+
5
+ * Salesforce: extract sobjects and fields
6
+
7
+ ## 0.16.3 - 2024-04-24
8
+
9
+ * Databricks: Extract table owners
10
+
11
+ ## 0.16.2 - 2024-04-09
12
+
13
+ * PowerBI: Extract pages from report
14
+
3
15
  ## 0.16.1 - 2024-04-02
4
16
 
5
17
  * Systematically escape nul bytes on CSV write
@@ -1,7 +1,10 @@
1
+ import logging
1
2
  from argparse import ArgumentParser
2
3
 
3
4
  from castor_extractor.warehouse import databricks # type: ignore
4
5
 
6
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
7
+
5
8
 
6
9
  def main():
7
10
  parser = ArgumentParser()
@@ -0,0 +1,43 @@
1
+ import logging
2
+ from argparse import ArgumentParser
3
+
4
+ from castor_extractor.warehouse import salesforce # type: ignore
5
+
6
# command-line entry point: emit INFO-level logs as "LEVEL - message"
logging.basicConfig(format="%(levelname)s - %(message)s", level=logging.INFO)
7
+
8
+
9
def main():
    """Parse the command-line options and run the Salesforce extraction."""
    parser = ArgumentParser()

    # (short flag, long flag, help text); argparse derives the attribute
    # name from the long flag, e.g. --client-id -> client_id
    string_options = (
        ("-u", "--username", "Salesforce username"),
        ("-p", "--password", "Salesforce password"),
        ("-c", "--client-id", "Salesforce client id"),
        ("-s", "--client-secret", "Salesforce client secret"),
        ("-t", "--security-token", "Salesforce security token"),
        ("-b", "--base-url", "Salesforce instance URL"),
        ("-o", "--output", "Directory to write to"),
    )
    for short_flag, long_flag, help_text in string_options:
        parser.add_argument(short_flag, long_flag, help=help_text)

    parser.add_argument(
        "--skip-existing",
        dest="skip_existing",
        action="store_true",
        help="Skips files already extracted instead of replacing them",
    )
    parser.set_defaults(skip_existing=False)

    parsed = parser.parse_args()

    salesforce.extract_all(
        username=parsed.username,
        password=parsed.password,
        client_id=parsed.client_id,
        client_secret=parsed.client_secret,
        security_token=parsed.security_token,
        base_url=parsed.base_url,
        output_directory=parsed.output,
        skip_existing=parsed.skip_existing,
    )
@@ -11,23 +11,23 @@ def main():
11
11
 
12
12
  parser.add_argument("-u", "--username", help="Salesforce username")
13
13
  parser.add_argument("-p", "--password", help="Salesforce password")
14
- parser.add_argument("-k", "--consumer-key", help="Salesforce consumer key")
14
+ parser.add_argument("-c", "--client-id", help="Salesforce client id")
15
15
  parser.add_argument(
16
- "-s", "--consumer-secret", help="Salesforce consumer secret"
16
+ "-s", "--client-secret", help="Salesforce client secret"
17
17
  )
18
18
  parser.add_argument(
19
19
  "-t", "--security-token", help="Salesforce security token"
20
20
  )
21
- parser.add_argument("-l", "--url", help="Salesforce instance URL")
21
+ parser.add_argument("-b", "--base-url", help="Salesforce instance URL")
22
22
  parser.add_argument("-o", "--output", help="Directory to write to")
23
23
 
24
24
  args = parser.parse_args()
25
25
  salesforce_reporting.extract_all(
26
26
  username=args.username,
27
27
  password=args.password,
28
- consumer_key=args.consumer_key,
29
- consumer_secret=args.consumer_secret,
28
+ client_id=args.client_id,
29
+ client_secret=args.client_secret,
30
30
  security_token=args.security_token,
31
- instance_url=args.url,
31
+ base_url=args.base_url,
32
32
  output_directory=args.output,
33
33
  )
@@ -1,23 +1,25 @@
1
- import json
2
- from typing import Optional
1
+ import logging
2
+ from typing import Any, Callable, Dict, Literal, Optional
3
3
 
4
4
  import requests
5
5
 
6
- from ...warehouse.databricks.credentials import DatabricksCredentials
6
+ logger = logging.getLogger(__name__)
7
7
 
8
8
  DEFAULT_TIMEOUT_MS = 30_000
9
- APICredentials = DatabricksCredentials
9
+
10
+ # https://requests.readthedocs.io/en/latest/api/#requests.request
11
+ HttpMethod = Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]
10
12
 
11
13
 
12
14
  class APIClient:
13
15
  """
14
16
  API client
15
- - used for Databricks Unity Catalog for now
16
- - authentication via access token for now
17
+ - authentication via access token
17
18
  """
18
19
 
19
- def __init__(self, credentials: APICredentials):
20
- self.credentials = credentials
20
+ def __init__(self, host: str, token: Optional[str] = None):
21
+ self._host = host
22
+ self._token = token or ""
21
23
  self._timeout = DEFAULT_TIMEOUT_MS
22
24
 
23
25
  @staticmethod
@@ -26,25 +28,32 @@ class APIClient:
26
28
  host = "https://" + host
27
29
  return f"{host.strip('/')}/{path}"
28
30
 
29
- def _headers(self):
30
- return {
31
- "Content-type": "application/json",
32
- "Authorization": f"Bearer {self.credentials.token}",
33
- }
34
-
35
- def get(self, path: str, payload: Optional[dict] = None) -> dict:
36
- """
37
- path: REST API operation path, such as /api/2.0/clusters/get
38
- """
39
- url = self.build_url(self.credentials.host, path)
40
- response = requests.get(
41
- url,
42
- data=json.dumps(payload or dict()),
43
- headers=self._headers(),
44
- timeout=self._timeout,
31
+ def _headers(self) -> Dict[str, str]:
32
+ if self._token:
33
+ return {"Authorization": f"Bearer {self._token}"}
34
+ return dict()
35
+
36
+ def _call(
37
+ self,
38
+ url: str,
39
+ method: HttpMethod = "GET",
40
+ *,
41
+ params: Optional[dict] = None,
42
+ data: Optional[dict] = None,
43
+ processor: Optional[Callable] = None,
44
+ ) -> Any:
45
+ logger.debug(f"Calling {method} on {url}")
46
+ result = requests.request(
47
+ method, url, headers=self._headers(), params=params, json=data
45
48
  )
49
+ result.raise_for_status()
46
50
 
47
- if response.content:
48
- return json.loads(response.content)
51
+ if processor:
52
+ return processor(result)
49
53
 
50
- return {}
54
+ return result.json()
55
+
56
+ def get(self, path: str, payload: Optional[dict] = None) -> dict:
57
+ """path: REST API operation path, such as /api/2.0/clusters/get"""
58
+ url = self.build_url(self._host, path)
59
+ return self._call(url=url, data=payload)
@@ -0,0 +1,3 @@
1
+ from .client import SalesforceBaseClient
2
+ from .constants import Keys
3
+ from .credentials import SalesforceCredentials, to_credentials
@@ -0,0 +1,84 @@
1
+ import logging
2
+ from typing import Iterator, Optional, Tuple
3
+
4
+ from requests import Response
5
+
6
+ from ...utils.client.api import APIClient
7
+ from .constants import DEFAULT_API_VERSION, DEFAULT_PAGINATION_LIMIT
8
+ from .credentials import SalesforceCredentials
9
+
10
# module-level logger, named after this module
logger = logging.getLogger(name=__name__)
11
+
12
+
13
class SalesforceBaseClient(APIClient):
    """
    Salesforce API client.
    https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/intro_rest.htm

    An OAuth2 access token is requested at construction time; SOQL queries
    are then run against the (paginated) query endpoint.
    """

    api_version = DEFAULT_API_VERSION
    # `_query_all` fetches the first page plus at most this many further pages
    pagination_limit = DEFAULT_PAGINATION_LIMIT

    PATH_TPL = "services/data/v{version}/{suffix}"

    def __init__(self, credentials: SalesforceCredentials):
        super().__init__(host=credentials.base_url)
        self._token = self._access_token(credentials)

    def _access_token(self, credentials: SalesforceCredentials) -> str:
        """Request an OAuth2 access token for `credentials` and return it."""
        url = self.build_url(self._host, "services/oauth2/token")
        # NOTE(review): the payload (password, client secret, ...) is sent as
        # URL query parameters; a form-encoded body would keep the secrets
        # out of URLs and server logs.
        response = self._call(
            url, "POST", params=credentials.token_request_payload()
        )
        return response["access_token"]

    def _full_url(self, suffix: str) -> str:
        """Build `<host>/services/data/v<api_version>/<suffix>`."""
        path = self.PATH_TPL.format(version=self.api_version, suffix=suffix)
        return self.build_url(self._host, path)

    @property
    def query_url(self) -> str:
        """Returns the query API url"""
        return self._full_url("query")

    @property
    def tooling_url(self) -> str:
        """Returns the tooling API url"""
        return self._full_url("tooling/query")

    @staticmethod
    def _query_processor(response: Response) -> Tuple[list, Optional[str]]:
        """Split a query response into (records, next page path or None)."""
        results = response.json()
        return results["records"], results.get("nextRecordsUrl")

    def _has_reached_pagination_limit(self, page_number: int) -> bool:
        return page_number > self.pagination_limit

    def _query_first_page(self, query: str) -> Tuple[list, Optional[str]]:
        """Run `query`; return its first records and the next page path."""
        url = self.query_url
        logger.info("querying page 0")
        records, next_page_url = self._call(
            url, params={"q": query}, processor=self._query_processor
        )
        return records, next_page_url

    def _query_all(self, query: str) -> Iterator[dict]:
        """
        Run a SOQL query over salesforce API, following pagination.

        more: https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/dome_query.htm
        """
        records, next_page_path = self._query_first_page(query)
        yield from records

        page_count = 1
        while next_page_path:
            if self._has_reached_pagination_limit(page_count):
                logger.warning(
                    f"pagination limit ({self.pagination_limit}) reached: "
                    "remaining records are not extracted"
                )
                break
            logger.info(f"querying page {page_count}")
            # Salesforce returns nextRecordsUrl as an absolute path
            # ("/services/data/..."); build_url already adds a "/"
            url = self.build_url(self._host, next_page_path.lstrip("/"))
            # the next page path must be refreshed from every response,
            # otherwise the same page would be fetched again and again
            records, next_page_path = self._call(
                url, processor=self._query_processor
            )
            yield from records
            page_count += 1
@@ -0,0 +1,21 @@
1
+ from unittest.mock import patch
2
+
3
+ from .client import SalesforceBaseClient
4
+ from .credentials import SalesforceCredentials
5
+
6
+
7
@patch.object(SalesforceBaseClient, "_call")
def test_SalesforceBaseClient__urls(mock_call):
    # the constructor requests a token through `_call`: fake the OAuth answer
    mock_call.return_value = {"access_token": "the_token"}
    client = SalesforceBaseClient(
        SalesforceCredentials(
            base_url="url",
            client_id="key",
            client_secret="secret",
            password="pw",
            security_token="token",
            username="usr",
        )
    )

    api_root = "https://url/services/data/v59.0"
    assert client.query_url == f"{api_root}/query"
    assert client.tooling_url == f"{api_root}/tooling/query"
@@ -0,0 +1,13 @@
1
# Salesforce REST API version, rendered into "services/data/v{version}/..."
DEFAULT_API_VERSION = 59.0
# default page budget for paginated SOQL queries (SalesforceBaseClient)
DEFAULT_PAGINATION_LIMIT = 100
3
+
4
+
5
class Keys:
    """Parameter names under which Salesforce credentials are passed"""

    BASE_URL = "base_url"
    CLIENT_ID = "client_id"
    CLIENT_SECRET = "client_secret"  # noqa: S105
    PASSWORD = "password"  # noqa: S105
    SECURITY_TOKEN = "security_token"  # noqa: S105
    USERNAME = "username"
@@ -0,0 +1,65 @@
1
+ from typing import Dict
2
+
3
+ from ...utils import from_env
4
+ from .constants import Keys
5
+
6
# names handed to `from_env` by `to_credentials` when a parameter is missing
_BASE_URL = "CASTOR_SALESFORCE_BASE_URL"
_CLIENT_ID = "CASTOR_SALESFORCE_CLIENT_ID"
_CLIENT_SECRET = "CASTOR_SALESFORCE_CLIENT_SECRET"  # noqa: S105
_PASSWORD = "CASTOR_SALESFORCE_PASSWORD"  # noqa: S105
_SECURITY_TOKEN = "CASTOR_SALESFORCE_SECURITY_TOKEN"  # noqa: S105
_USERNAME = "CASTOR_SALESFORCE_USERNAME"
12
+
13
+
14
class SalesforceCredentials:
    """
    Credentials for the Salesforce REST API (OAuth2 password grant).

    The security token is appended to the password at construction time;
    the combined value is what `token_request_payload` sends as `password`.
    """

    def __init__(
        self,
        *,
        username: str,
        password: str,
        security_token: str,
        client_id: str,
        client_secret: str,
        base_url: str,
    ):
        self.base_url = base_url
        self.client_id = client_id
        self.client_secret = client_secret
        self.username = username
        # the password grant takes "password + security token" as one value
        self.password = password + security_token

    def token_request_payload(self) -> Dict[str, str]:
        """
        Parameters sent to the OAuth2 token endpoint to obtain an access token
        """
        return dict(
            grant_type="password",
            client_id=self.client_id,
            client_secret=self.client_secret,
            username=self.username,
            password=self.password,
        )
46
+
47
+
48
def to_credentials(params: dict) -> SalesforceCredentials:
    """
    Build SalesforceCredentials from `params`.

    Any missing or empty parameter falls back to `from_env` with the
    matching CASTOR_SALESFORCE_* name.
    """

    def _resolve(key: str, env_name: str):
        # an explicit (truthy) parameter wins over the fallback
        return params.get(key) or from_env(env_name)

    # keyword arguments are evaluated left to right, so the fallbacks are
    # looked up in the same order as before
    return SalesforceCredentials(
        username=_resolve(Keys.USERNAME, _USERNAME),
        password=_resolve(Keys.PASSWORD, _PASSWORD),
        security_token=_resolve(Keys.SECURITY_TOKEN, _SECURITY_TOKEN),
        client_id=_resolve(Keys.CLIENT_ID, _CLIENT_ID),
        client_secret=_resolve(Keys.CLIENT_SECRET, _CLIENT_SECRET),
        base_url=_resolve(Keys.BASE_URL, _BASE_URL),
    )
@@ -5,9 +5,10 @@ def test_Credentials_token_request_payload():
5
5
  creds = SalesforceCredentials(
6
6
  username="giphy",
7
7
  password="1312",
8
- consumer_key="degenie",
9
- consumer_secret="fautpasledire",
8
+ client_id="degenie",
9
+ client_secret="fautpasledire",
10
10
  security_token="yo",
11
+ base_url="man",
11
12
  )
12
13
 
13
14
  payload = creds.token_request_payload()
@@ -14,7 +14,7 @@ RawData = Iterator[dict]
14
14
 
15
15
  DOMO_PUBLIC_URL = "https://api.domo.com"
16
16
  FORMAT = "%Y-%m-%d %I:%M:%S %p"
17
- DEFAULT_TIMEOUT = 300
17
+ DEFAULT_TIMEOUT = 500
18
18
  TOKEN_EXPIRATION_SECONDS = timedelta(seconds=3000) # auth token lasts 1 hour
19
19
 
20
20
  IGNORED_ERROR_CODES = (
@@ -14,18 +14,18 @@ POST = "POST"
14
14
  class Urls:
15
15
  """PowerBi's urls"""
16
16
 
17
- REST_API_BASE_PATH = "https://api.powerbi.com/v1.0/myorg"
18
17
  CLIENT_APP_BASE = "https://login.microsoftonline.com/"
19
18
  DEFAULT_SCOPE = "https://analysis.windows.net/powerbi/api/.default"
19
+ REST_API_BASE_PATH = "https://api.powerbi.com/v1.0/myorg"
20
20
 
21
21
  # PBI rest API Routes
22
22
  ACTIVITY_EVENTS = f"{REST_API_BASE_PATH}/admin/activityevents"
23
- DATASETS = f"{REST_API_BASE_PATH}/admin/datasets"
24
23
  DASHBOARD = f"{REST_API_BASE_PATH}/admin/dashboards"
24
+ DATASETS = f"{REST_API_BASE_PATH}/admin/datasets"
25
25
  GROUPS = f"{REST_API_BASE_PATH}/admin/groups"
26
+ METADATA_GET = f"{REST_API_BASE_PATH}/admin/workspaces/scanResult"
26
27
  METADATA_POST = f"{REST_API_BASE_PATH}/admin/workspaces/getInfo"
27
28
  METADATA_WAIT = f"{REST_API_BASE_PATH}/admin/workspaces/scanStatus"
28
- METADATA_GET = f"{REST_API_BASE_PATH}/admin/workspaces/scanResult"
29
29
  REPORTS = f"{REST_API_BASE_PATH}/admin/reports"
30
30
  WORKSPACE_IDS = (
31
31
  "https://api.powerbi.com/v1.0/myorg/admin/workspaces/modified"
@@ -64,19 +64,15 @@ class Client:
64
64
 
65
65
  def __init__(self, credentials: Credentials):
66
66
  self.creds = credentials
67
-
68
- def _access_token(self) -> dict:
69
67
  client_app = f"{Urls.CLIENT_APP_BASE}{self.creds.tenant_id}"
70
- app = msal.ConfidentialClientApplication(
68
+ self.app = msal.ConfidentialClientApplication(
71
69
  client_id=self.creds.client_id,
72
70
  authority=client_app,
73
71
  client_credential=self.creds.secret,
74
72
  )
75
73
 
76
- token = app.acquire_token_silent(self.creds.scopes, account=None)
77
-
78
- if not token:
79
- token = app.acquire_token_for_client(scopes=self.creds.scopes)
74
+ def _access_token(self) -> dict:
75
+ token = self.app.acquire_token_for_client(scopes=self.creds.scopes)
80
76
 
81
77
  if Keys.ACCESS_TOKEN not in token:
82
78
  raise ValueError(f"No access token in token response: {token}")
@@ -248,7 +244,17 @@ class Client:
248
244
  Returns a list of reports for the organization.
249
245
  https://learn.microsoft.com/en-us/rest/api/power-bi/admin/reports-get-reports-as-admin
250
246
  """
251
- return self._get(Urls.REPORTS)[Keys.VALUE]
247
+ reports = self._get(Urls.REPORTS)[Keys.VALUE]
248
+ for report in reports:
249
+ report_id = report.get("id")
250
+ try:
251
+ url = Urls.REPORTS + f"/{report_id}/pages"
252
+ pages = self._get(url)[Keys.VALUE]
253
+ report["pages"] = pages
254
+ except (requests.HTTPError, requests.exceptions.Timeout) as e:
255
+ logger.debug(e)
256
+ continue
257
+ return reports
252
258
 
253
259
  def _dashboards(self) -> List[Dict]:
254
260
  """
@@ -31,7 +31,6 @@ def test__access_token(mock_app):
31
31
  # init mocks
32
32
  valid_response = {"access_token": "mock_token"}
33
33
  returning_valid_token = Mock(return_value=valid_response)
34
- mock_app.return_value.acquire_token_silent = Mock(return_value=None)
35
34
  mock_app.return_value.acquire_token_for_client = returning_valid_token
36
35
 
37
36
  # init client
@@ -40,30 +39,29 @@ def test__access_token(mock_app):
40
39
  # generated token
41
40
  assert client._access_token() == valid_response
42
41
 
43
- # via silent endpoint token
44
- mock_app.return_value.acquire_token_silent = returning_valid_token
45
- mock_app.return_value.acquire_token_for_client = None
46
- assert client._access_token() == valid_response
47
-
48
42
  # token missing in response
49
43
  invalid_response = {"not_access_token": "666"}
50
44
  returning_invalid_token = Mock(return_value=invalid_response)
51
- mock_app.return_value.acquire_token_silent = returning_invalid_token
45
+ mock_app.return_value.acquire_token_for_client = returning_invalid_token
52
46
 
53
47
  with pytest.raises(ValueError):
54
48
  client._access_token()
55
49
 
56
50
 
51
+ @patch.object(msal, "ConfidentialClientApplication")
57
52
  @patch.object(Client, "_access_token")
58
- def test__headers(mock_access_token):
53
+ def test__headers(mock_access_token, mock_app):
54
+ mock_app.return_value = None
59
55
  client = _client()
60
56
  mock_access_token.return_value = {Keys.ACCESS_TOKEN: "666"}
61
57
  assert client._header() == {"Authorization": "Bearer 666"}
62
58
 
63
59
 
60
+ @patch.object(msal, "ConfidentialClientApplication")
64
61
  @patch("requests.request")
65
62
  @patch.object(Client, "_access_token")
66
- def test__get(mocked_access_token, mocked_request):
63
+ def test__get(mocked_access_token, mocked_request, mock_app):
64
+ mock_app.return_value = None
67
65
  client = _client()
68
66
  mocked_access_token.return_value = {Keys.ACCESS_TOKEN: "666"}
69
67
  fact = {"fact": "Approximately 24 cat skins can make a coat.", "length": 43}
@@ -81,9 +79,11 @@ def test__get(mocked_access_token, mocked_request):
81
79
  result = client._get("https/whatev.er")
82
80
 
83
81
 
82
+ @patch.object(msal, "ConfidentialClientApplication")
84
83
  @patch("requests.request")
85
84
  @patch.object(Client, "_access_token")
86
- def test__workspace_ids(_, mocked_request):
85
+ def test__workspace_ids(_, mocked_request, mock_app):
86
+ mock_app.return_value = None
87
87
  client = _client()
88
88
  mocked_request.return_value = Mock(
89
89
  json=lambda: [{"id": 1000}, {"id": 1001}, {"id": 1003}],
@@ -112,9 +112,11 @@ def test__workspace_ids(_, mocked_request):
112
112
  )
113
113
 
114
114
 
115
+ @patch.object(msal, "ConfidentialClientApplication")
115
116
  @patch("requests.request")
116
117
  @patch.object(Client, "_access_token")
117
- def test__post_default(_, mocked_request):
118
+ def test__post_default(_, mocked_request, mock_app):
119
+ mock_app.return_value = None
118
120
  client = _client()
119
121
  url = "https://estcequecestbientotleweekend.fr/"
120
122
  params = QueryParams.METADATA_SCAN
@@ -129,9 +131,11 @@ def test__post_default(_, mocked_request):
129
131
  )
130
132
 
131
133
 
134
+ @patch.object(msal, "ConfidentialClientApplication")
132
135
  @patch("requests.request")
133
136
  @patch.object(Client, "_access_token")
134
- def test__post_with_processor(_, mocked_request):
137
+ def test__post_with_processor(_, mocked_request, mock_app):
138
+ mock_app.return_value = None
135
139
  client = _client()
136
140
  url = "https://estcequecestbientotleweekend.fr/"
137
141
  params = QueryParams.METADATA_SCAN
@@ -146,9 +150,11 @@ def test__post_with_processor(_, mocked_request):
146
150
  assert result == 1000
147
151
 
148
152
 
153
+ @patch.object(msal, "ConfidentialClientApplication")
149
154
  @patch("requests.request")
150
155
  @patch.object(Client, "_access_token")
151
- def test__datasets(_, mocked_request):
156
+ def test__datasets(_, mocked_request, mock_app):
157
+ mock_app.return_value = None
152
158
  client = _client()
153
159
  mocked_request.return_value = Mock(
154
160
  json=lambda: {"value": [{"id": 1, "type": "dataset"}]},
@@ -164,27 +170,50 @@ def test__datasets(_, mocked_request):
164
170
  assert datasets == [{"id": 1, "type": "dataset"}]
165
171
 
166
172
 
173
+ @patch.object(msal, "ConfidentialClientApplication")
167
174
  @patch("requests.request")
168
175
  @patch.object(Client, "_access_token")
169
- def test__reports(_, mocked_request):
176
+ def test__reports(_, mocked_request, mock_app):
177
+ mock_app.return_value = None
170
178
  client = _client()
171
- mocked_request.return_value = Mock(
172
- json=lambda: {"value": [{"id": 1, "type": "report"}]},
173
- )
179
+ page_url = f"{Urls.REPORTS}/1/pages"
180
+ calls = [
181
+ call(GET, Urls.REPORTS, data=None, headers=ANY, params=None),
182
+ call(
183
+ GET,
184
+ page_url,
185
+ data=None,
186
+ headers=ANY,
187
+ params=None,
188
+ ),
189
+ ]
190
+ mocked_request.side_effect = [
191
+ Mock(json=lambda: {"value": [{"id": 1, "type": "report"}]}),
192
+ Mock(
193
+ json=lambda: {
194
+ "value": [
195
+ {"name": "page_name", "displayName": "page", "order": 0}
196
+ ]
197
+ }
198
+ ),
199
+ ]
174
200
  reports = client._reports()
175
- mocked_request.assert_called_with(
176
- GET,
177
- Urls.REPORTS,
178
- data=None,
179
- headers=ANY,
180
- params=None,
181
- )
182
- assert reports == [{"id": 1, "type": "report"}]
201
+ mocked_request.assert_has_calls(calls)
202
+
203
+ assert reports == [
204
+ {
205
+ "id": 1,
206
+ "type": "report",
207
+ "pages": [{"name": "page_name", "displayName": "page", "order": 0}],
208
+ }
209
+ ]
183
210
 
184
211
 
212
+ @patch.object(msal, "ConfidentialClientApplication")
185
213
  @patch("requests.request")
186
214
  @patch.object(Client, "_access_token")
187
- def test__dashboards(_, mocked_request):
215
+ def test__dashboards(_, mocked_request, mock_app):
216
+ mock_app.return_value = None
188
217
  client = _client()
189
218
  mocked_request.return_value = Mock(
190
219
  json=lambda: {"value": [{"id": 1, "type": "dashboard"}]},
@@ -200,6 +229,7 @@ def test__dashboards(_, mocked_request):
200
229
  assert dashboards == [{"id": 1, "type": "dashboard"}]
201
230
 
202
231
 
232
+ @patch.object(msal, "ConfidentialClientApplication")
203
233
  @patch.object(Client, "_workspace_ids")
204
234
  @patch.object(Client, "_create_scan")
205
235
  @patch.object(Client, "_wait_for_scan_result")
@@ -209,7 +239,9 @@ def test__metadata(
209
239
  mocked_wait,
210
240
  mocked_create_scan,
211
241
  mocked_workspace_ids,
242
+ mock_app,
212
243
  ):
244
+ mock_app.return_value = None
213
245
  mocked_workspace_ids.return_value = list(range(200))
214
246
  mocked_create_scan.return_value = 314
215
247
  mocked_wait.return_value = True
@@ -240,8 +272,10 @@ _CALLS = [
240
272
  ]
241
273
 
242
274
 
275
+ @patch.object(msal, "ConfidentialClientApplication")
243
276
  @patch.object(Client, "_call")
244
- def test__activity_events(mocked):
277
+ def test__activity_events(mocked, mock_app):
278
+ mock_app.return_value = None
245
279
  client = _client()
246
280
  mocked.side_effect = _CALLS
247
281
 
@@ -1,4 +1,3 @@
1
1
  from .assets import SalesforceReportingAsset
2
- from .client import SalesforceClient
3
- from .client.credentials import SalesforceCredentials
2
+ from .client import SalesforceReportingClient
4
3
  from .extract import extract_all