castor-extractor 0.16.1__py3-none-any.whl → 0.16.4__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of castor-extractor might be problematic.

Files changed (41)
  1. CHANGELOG.md +12 -0
  2. castor_extractor/commands/extract_databricks.py +3 -0
  3. castor_extractor/commands/extract_salesforce.py +43 -0
  4. castor_extractor/commands/extract_salesforce_reporting.py +6 -6
  5. castor_extractor/utils/client/api.py +36 -27
  6. castor_extractor/utils/salesforce/__init__.py +3 -0
  7. castor_extractor/utils/salesforce/client.py +84 -0
  8. castor_extractor/utils/salesforce/client_test.py +21 -0
  9. castor_extractor/utils/salesforce/constants.py +13 -0
  10. castor_extractor/utils/salesforce/credentials.py +65 -0
  11. castor_extractor/{visualization/salesforce_reporting/client → utils/salesforce}/credentials_test.py +3 -2
  12. castor_extractor/visualization/domo/client/client.py +1 -1
  13. castor_extractor/visualization/powerbi/client/constants.py +3 -3
  14. castor_extractor/visualization/powerbi/client/rest.py +14 -8
  15. castor_extractor/visualization/powerbi/client/rest_test.py +61 -27
  16. castor_extractor/visualization/salesforce_reporting/__init__.py +1 -2
  17. castor_extractor/visualization/salesforce_reporting/client/__init__.py +1 -2
  18. castor_extractor/visualization/salesforce_reporting/client/rest.py +7 -90
  19. castor_extractor/visualization/salesforce_reporting/extract.py +10 -8
  20. castor_extractor/visualization/tableau/assets.py +5 -0
  21. castor_extractor/visualization/tableau/client/client.py +10 -0
  22. castor_extractor/visualization/tableau/gql_fields.py +30 -9
  23. castor_extractor/warehouse/databricks/client.py +20 -3
  24. castor_extractor/warehouse/databricks/client_test.py +14 -0
  25. castor_extractor/warehouse/databricks/credentials.py +1 -4
  26. castor_extractor/warehouse/databricks/extract.py +3 -2
  27. castor_extractor/warehouse/databricks/format.py +5 -4
  28. castor_extractor/warehouse/salesforce/__init__.py +6 -0
  29. castor_extractor/warehouse/salesforce/client.py +112 -0
  30. castor_extractor/warehouse/salesforce/constants.py +2 -0
  31. castor_extractor/warehouse/salesforce/extract.py +111 -0
  32. castor_extractor/warehouse/salesforce/format.py +67 -0
  33. castor_extractor/warehouse/salesforce/format_test.py +32 -0
  34. castor_extractor/warehouse/salesforce/soql.py +45 -0
  35. castor_extractor-0.16.4.dist-info/LICENCE +86 -0
  36. {castor_extractor-0.16.1.dist-info → castor_extractor-0.16.4.dist-info}/METADATA +2 -3
  37. {castor_extractor-0.16.1.dist-info → castor_extractor-0.16.4.dist-info}/RECORD +39 -27
  38. {castor_extractor-0.16.1.dist-info → castor_extractor-0.16.4.dist-info}/WHEEL +1 -1
  39. {castor_extractor-0.16.1.dist-info → castor_extractor-0.16.4.dist-info}/entry_points.txt +2 -1
  40. castor_extractor/visualization/salesforce_reporting/client/constants.py +0 -2
  41. castor_extractor/visualization/salesforce_reporting/client/credentials.py +0 -33

castor_extractor/visualization/salesforce_reporting/client/__init__.py
@@ -1,2 +1 @@
-from .credentials import SalesforceCredentials
-from .rest import SalesforceClient
+from .rest import SalesforceReportingClient

castor_extractor/visualization/salesforce_reporting/client/rest.py
@@ -1,13 +1,8 @@
 import logging
-import os
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
-
-import requests
-from requests import Response
+from typing import Dict, Iterator, List, Optional

+from ....utils.salesforce import SalesforceBaseClient
 from ..assets import SalesforceReportingAsset
-from .constants import DEFAULT_API_VERSION, DEFAULT_PAGINATION_LIMIT
-from .credentials import SalesforceCredentials
 from .soql import queries

 logger = logging.getLogger(__name__)
@@ -19,89 +14,11 @@ REQUIRING_URL_ASSETS = (
 )


-class SalesforceClient:
+class SalesforceReportingClient(SalesforceBaseClient):
     """
-    Salesforce API client.
-    https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/intro_rest.htm
+    Salesforce Reporting API client
     """

-    api_version = DEFAULT_API_VERSION
-    pagination_limit = DEFAULT_PAGINATION_LIMIT
-
-    def __init__(
-        self,
-        credentials: SalesforceCredentials,
-        instance_url: str,
-    ):
-        self.credentials = credentials
-        self.instance_url = instance_url
-        self._token = self._access_token()
-
-    def _access_token(self) -> Tuple[str, str]:
-        url = f"{self.instance_url}/services/oauth2/token"
-        response = self._call(
-            url, "POST", data=self.credentials.token_request_payload()
-        )
-        return response["access_token"]
-
-    def _header(self) -> Dict:
-        return {"Authorization": f"Bearer {self._token}"}
-
-    @staticmethod
-    def _call(
-        url: str,
-        method: str = "GET",
-        *,
-        header: Optional[Dict] = None,
-        params: Optional[Dict] = None,
-        data: Optional[Dict] = None,
-        processor: Optional[Callable] = None,
-    ) -> Any:
-        logger.debug(f"Calling {method} on {url}")
-        result = requests.request(
-            method,
-            url,
-            headers=header,
-            params=params,
-            data=data,
-        )
-        result.raise_for_status()
-
-        if processor:
-            return processor(result)
-
-        return result.json()
-
-    @staticmethod
-    def _query_processor(response: Response) -> Tuple[dict, Optional[str]]:
-        results = response.json()
-        return results["records"], results.get("nextRecordsUrl")
-
-    def _query_all(self, query: str) -> Iterator[Dict]:
-        """
-        Run a SOQL query over salesforce API.
-
-        more: https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/dome_query.htm
-        """
-        url = f"{self.instance_url}/services/data/v{self.api_version}/query"
-        records, next_page = self._call(
-            url,
-            params={"q": query},
-            processor=self._query_processor,
-            header=self._header(),
-        )
-        yield from records
-
-        page_count = 0
-        while next_page and page_count <= self.pagination_limit:
-            logger.info(f"querying page {page_count}")
-            url = f"{self.instance_url}{next_page}"
-            records, next_page = self._call(
-                url, processor=self._query_processor, header=self._header()
-            )
-            yield from records
-            page_count += 1
-
     def _get_asset_url(
         self, asset_type: SalesforceReportingAsset, asset: dict
     ) -> Optional[str]:
@@ -111,15 +28,15 @@ class SalesforceClient:

         if asset_type == SalesforceReportingAsset.DASHBOARDS:
             path = f"lightning/r/Dashboard/{asset['Id']}/view"
-            return os.path.join(self.instance_url, path)
+            return self.build_url(self._host, path)

         if asset_type == SalesforceReportingAsset.FOLDERS:
             path = asset["attributes"]["url"].lstrip("/")
-            return os.path.join(self.instance_url, path)
+            return self.build_url(self._host, path)

         if asset_type == SalesforceReportingAsset.REPORTS:
             path = f"lightning/r/Report/{asset['Id']}/view"
-            return os.path.join(self.instance_url, path)
+            return self.build_url(self._host, path)

         return None

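Note: the body of SalesforceBaseClient is not part of the hunks rendered here; it lives in the new castor_extractor/utils/salesforce/client.py (+84 lines, file 7 in the list above). From its call sites in this diff (build_url and _host here; _call, _query_processor, query_url, tooling_url and _has_reached_pagination_limit in the warehouse client further down), a rough sketch of its interface can be inferred. Only the names below are attested by the diff; every body is an assumption:

class SalesforceBaseClient:
    """Hypothetical sketch of the shared base client, inferred from usage."""

    def __init__(self, credentials):
        # Assumed: the host now comes from the credentials (base_url),
        # replacing the separate instance_url argument removed above;
        # token acquisition presumably mirrors the removed _access_token.
        self.credentials = credentials
        self._host = credentials.base_url

    @staticmethod
    def build_url(host: str, path: str) -> str:
        # Replaces the os.path.join(instance_url, path) calls removed above;
        # os.path.join is the wrong tool for URLs (platform-dependent separators).
        return f"{host.rstrip('/')}/{path.lstrip('/')}"

    @property
    def query_url(self) -> str:
        # Used by the warehouse SalesforceClient._query_all below;
        # the API version baked in here is unknown.
        return self.build_url(self._host, "services/data/vXX.0/query")

    @property
    def tooling_url(self) -> str:
        # Used by SalesforceClient.fetch_fields below.
        return self.build_url(self._host, "services/data/vXX.0/tooling/query")

    def _has_reached_pagination_limit(self, page_number: int) -> bool:
        raise NotImplementedError  # actual limit value unknown

    def _call(self, url, method="GET", *, params=None, data=None, processor=None):
        # Presumably the same shape as the _call removed from rest.py above.
        raise NotImplementedError

    @staticmethod
    def _query_processor(response):
        # Same contract as the removed helper: (records, nextRecordsUrl).
        results = response.json()
        return results["records"], results.get("nextRecordsUrl")
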
castor_extractor/visualization/salesforce_reporting/extract.py
@@ -10,14 +10,15 @@ from ...utils import (
     write_json,
     write_summary,
 )
+from ...utils.salesforce import SalesforceCredentials
 from .assets import SalesforceReportingAsset
-from .client import SalesforceClient, SalesforceCredentials
+from .client import SalesforceReportingClient

 logger = logging.getLogger(__name__)


 def iterate_all_data(
-    client: SalesforceClient,
+    client: SalesforceReportingClient,
 ) -> Iterable[Tuple[str, Union[list, dict]]]:
     """Iterate over the extracted data from Salesforce"""

@@ -30,10 +31,10 @@ def iterate_all_data(
 def extract_all(
     username: str,
     password: str,
-    consumer_key: str,
-    consumer_secret: str,
+    client_id: str,
+    client_secret: str,
     security_token: str,
-    instance_url: str,
+    base_url: str,
     output_directory: Optional[str] = None,
 ) -> None:
     """
@@ -44,11 +45,12 @@ def extract_all(
     creds = SalesforceCredentials(
         username=username,
         password=password,
-        consumer_key=consumer_key,
-        consumer_secret=consumer_secret,
+        client_id=client_id,
+        client_secret=client_secret,
         security_token=security_token,
+        base_url=base_url,
     )
-    client = SalesforceClient(credentials=creds, instance_url=instance_url)
+    client = SalesforceReportingClient(credentials=creds)
     ts = current_timestamp()

     for key, data in iterate_all_data(client):

castor_extractor/visualization/tableau/assets.py
@@ -12,10 +12,12 @@ class TableauAsset(ExternalAsset):
     CUSTOM_SQL_TABLE = "custom_sql_tables"
     CUSTOM_SQL_QUERY = "custom_sql_queries"
     DASHBOARD = "dashboards"
+    DASHBOARD_SHEET = "dashboards_sheets"
     DATASOURCE = "datasources"
     FIELD = "fields"
     PROJECT = "projects"
     PUBLISHED_DATASOURCE = "published_datasources"
+    SHEET = "sheets"
     USAGE = "views"
     USER = "users"
     WORKBOOK = "workbooks"
@@ -25,7 +27,9 @@ class TableauAsset(ExternalAsset):
     def optional(cls) -> Set["TableauAsset"]:
         return {
             TableauAsset.DASHBOARD,
+            TableauAsset.DASHBOARD_SHEET,
             TableauAsset.FIELD,
+            TableauAsset.SHEET,
             TableauAsset.PUBLISHED_DATASOURCE,
         }

@@ -42,4 +46,5 @@ class TableauGraphqlAsset(Enum):
     DASHBOARD = "dashboards"
     DATASOURCE = "datasources"
     GROUP_FIELD = "groupFields"
+    SHEETS = "sheets"
     WORKBOOK_TO_DATASOURCE = "workbooks"

castor_extractor/visualization/tableau/client/client.py
@@ -173,6 +173,13 @@ class ApiClient:
             TableauAsset.DASHBOARD,
         )

+    def _fetch_sheets(self) -> SerializedAsset:
+        """Fetches sheets"""
+
+        return self._fetch_paginated_objects(
+            TableauAsset.SHEET,
+        )
+
     def _fetch_paginated_objects(self, asset: TableauAsset) -> SerializedAsset:
         """Fetches paginated objects"""

@@ -203,6 +210,9 @@ class ApiClient:
         if asset == TableauAsset.PUBLISHED_DATASOURCE:
             assets = self._fetch_published_datasources()

+        if asset == TableauAsset.SHEET:
+            assets = self._fetch_sheets()
+
         if asset == TableauAsset.USAGE:
             assets = self._fetch_usages(self._safe_mode)


castor_extractor/visualization/tableau/gql_fields.py
@@ -111,15 +111,15 @@ class GQLQueryFields(Enum):
     """

     DASHBOARDS: str = """
-    id
-    name
-    path
-    tags {
-        name
-    }
-    workbook {
-        luid # to retrieve the parent
-    }
+        id
+        name
+        path
+        tags {
+            name
+        }
+        workbook {
+            luid # to retrieve the parent
+        }
     """

     DATASOURCE: str = """
@@ -160,6 +160,21 @@ class GQLQueryFields(Enum):
         role
     """

+    SHEET: str = """
+        containedInDashboards {
+            id
+        }
+        id
+        index
+        name
+        upstreamFields{
+            name
+        }
+        workbook {
+            luid
+        }
+    """
+
     WORKBOOK_TO_DATASOURCE: str = """
         luid
         id
@@ -219,6 +234,12 @@ QUERY_FIELDS: Dict[TableauAsset, QueryInfo] = {
             OBJECT_TYPE: TableauGraphqlAsset.GROUP_FIELD,
         },
     ],
+    TableauAsset.SHEET: [
+        {
+            FIELDS: GQLQueryFields.SHEET,
+            OBJECT_TYPE: TableauGraphqlAsset.SHEETS,
+        },
+    ],
     TableauAsset.WORKBOOK_TO_DATASOURCE: [
         {
             FIELDS: GQLQueryFields.WORKBOOK_TO_DATASOURCE,
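
These enum members are raw GraphQL field blocks that the extractor interpolates into queries against Tableau's Metadata API. The QUERY_FIELDS entry added above pairs GQLQueryFields.SHEET with the "sheets" object type, so the assembled query is presumably along these lines (the real query builder and any pagination arguments it adds are not visible in this diff):

# Hypothetical assembly; the extractor's actual query builder is not shown here.
object_type = TableauGraphqlAsset.SHEETS.value  # "sheets"
fields = GQLQueryFields.SHEET.value
query = f"query {{ {object_type} {{ {fields} }} }}"
# roughly: query { sheets { containedInDashboards { id } id index name
#                           upstreamFields { name } workbook { luid } } }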

castor_extractor/warehouse/databricks/client.py
@@ -31,7 +31,7 @@ class DatabricksClient(APIClient):
         db_allowed: Optional[Set[str]] = None,
         db_blocked: Optional[Set[str]] = None,
     ):
-        super().__init__(credentials)
+        super().__init__(host=credentials.host, token=credentials.token)
         self._db_allowed = db_allowed
         self._db_blocked = db_blocked
         self.formatter = DatabricksFormatter()
@@ -87,15 +87,32 @@
             content.get("tables", []), schema
         )

-    def tables_and_columns(self, schemas: List[dict]) -> TablesColumns:
+    @staticmethod
+    def _match_table_with_user(table: dict, user_id_by_email: dict) -> dict:
+        table_owner_email = table.get("owner_email")
+        if not table_owner_email:
+            return table
+        owner_external_id = user_id_by_email.get(table_owner_email)
+        if not owner_external_id:
+            return table
+        return {**table, "owner_external_id": owner_external_id}
+
+    def tables_and_columns(
+        self, schemas: List[dict], users: List[dict]
+    ) -> TablesColumns:
         """
         Get the databricks tables & columns leveraging the unity catalog API
         """
         tables: List[dict] = []
         columns: List[dict] = []
+        user_id_by_email = {user.get("email"): user.get("id") for user in users}
         for schema in schemas:
             t_to_add, c_to_add = self._tables_columns_of_schema(schema)
-            tables.extend(t_to_add)
+            t_with_owner = [
+                self._match_table_with_user(table, user_id_by_email)
+                for table in t_to_add
+            ]
+            tables.extend(t_with_owner)
             columns.extend(c_to_add)
         return tables, columns


castor_extractor/warehouse/databricks/client_test.py
@@ -64,3 +64,17 @@
     assert client._keep_catalog("staging")
     assert not client._keep_catalog("dev")
     assert not client._keep_catalog("something_unknown")
+
+
+def test_DatabricksClient__match_table_with_user():
+    client = MockDatabricksClient()
+    users_by_email = {"bob@castordoc.com": 3}
+
+    table = {"id": 1, "owner_email": "bob@castordoc.com"}
+    table_with_owner = client._match_table_with_user(table, users_by_email)
+
+    assert table_with_owner == {**table, "owner_external_id": 3}
+
+    table_without_owner = {"id": 1, "owner_email": None}
+    actual = client._match_table_with_user(table_without_owner, users_by_email)
+    assert actual == table_without_owner

castor_extractor/warehouse/databricks/credentials.py
@@ -25,7 +25,4 @@ def to_credentials(params: dict) -> DatabricksCredentials:
     """extract Databricks credentials"""
     host = params.get("host") or from_env(_HOST)
     token = params.get("token") or from_env(_TOKEN)
-    return DatabricksCredentials(
-        host=host,
-        token=token,
-    )
+    return DatabricksCredentials(host=host, token=token)

castor_extractor/warehouse/databricks/extract.py
@@ -43,7 +43,7 @@ class DatabricksExtractionProcessor:
         self._storage = storage
         self._skip_existing = skip_existing

-    def _should_not_reextract(self, asset_group) -> bool:
+    def _should_not_reextract(self, asset_group: WarehouseAssetGroup) -> bool:
         """helper function to determine whether we need to extract"""
         if not self._skip_existing:
             return False
@@ -82,7 +82,8 @@ class DatabricksExtractionProcessor:

         del databases

-        tables, columns = self._client.tables_and_columns(schemas)
+        users = self._client.users()
+        tables, columns = self._client.tables_and_columns(schemas, users)

         location = self._storage.put(WarehouseAsset.TABLE.value, tables)
         catalog_locations[WarehouseAsset.TABLE.value] = location

castor_extractor/warehouse/databricks/format.py
@@ -19,10 +19,11 @@ def _to_datetime_or_none(time_ms: Optional[int]) -> Optional[datetime]:

 def _table_payload(schema: dict, table: dict) -> dict:
     return {
+        "description": table.get("comment"),
         "id": table["table_id"],
+        "owner_email": table.get("owner"),
         "schema_id": f"{schema['id']}",
         "table_name": table["name"],
-        "description": table.get("comment"),
         "tags": [],
         "type": table.get("table_type"),
     }
@@ -30,12 +31,12 @@ def _table_payload(schema: dict, table: dict) -> dict:

 def _column_payload(table: dict, column: dict) -> dict:
     return {
-        "id": f"`{table['id']}`.`{column['name']}`",
         "column_name": column["name"],
-        "table_id": table["id"],
-        "description": column.get("comment"),
         "data_type": column["type_name"],
+        "description": column.get("comment"),
+        "id": f"`{table['id']}`.`{column['name']}`",
         "ordinal_position": column["position"],
+        "table_id": table["id"],
     }


castor_extractor/warehouse/salesforce/__init__.py (new file)
@@ -0,0 +1,6 @@
+from .client import SalesforceClient
+from .extract import (
+    SALESFORCE_ASSETS,
+    SalesforceExtractionProcessor,
+    extract_all,
+)

castor_extractor/warehouse/salesforce/client.py (new file)
@@ -0,0 +1,112 @@
+import logging
+from typing import Dict, Iterator, List
+
+from tqdm import tqdm  # type: ignore
+
+from ...utils.salesforce import SalesforceBaseClient, SalesforceCredentials
+from .format import SalesforceFormatter
+from .soql import SOBJECT_FIELDS_QUERY_TPL, SOBJECTS_QUERY_TPL
+
+logger = logging.getLogger(__name__)
+
+
+class SalesforceClient(SalesforceBaseClient):
+    """
+    Salesforce API client to extract sobjects
+    """
+
+    # Implicit (hard-coded in Salesforce) limitation when using SOQL of 2,000 rows
+    LIMIT_RECORDS_PER_PAGE = 2000
+
+    def __init__(self, credentials: SalesforceCredentials):
+        super().__init__(credentials)
+        self.formatter = SalesforceFormatter()
+
+    @staticmethod
+    def name() -> str:
+        return "Salesforce"
+
+    def _format_query(self, query_template: str, start_durable_id: str) -> str:
+        return query_template.format(
+            start_durable_id=start_durable_id,
+            limit=self.LIMIT_RECORDS_PER_PAGE,
+        )
+
+    def _next_records(
+        self, url: str, query_template: str, start_durable_id: str = "0000"
+    ) -> List[dict]:
+        query = self._format_query(
+            query_template, start_durable_id=start_durable_id
+        )
+        records, _ = self._call(
+            url, params={"q": query}, processor=self._query_processor
+        )
+        return records
+
+    def _is_last_page(self, records: List[dict]) -> bool:
+        return len(records) < self.LIMIT_RECORDS_PER_PAGE
+
+    def _should_query_next_page(
+        self, records: List[dict], page_number: int
+    ) -> bool:
+        return not (
+            self._is_last_page(records)
+            or self._has_reached_pagination_limit(page_number)
+        )
+
+    def _query_all(self, query_template: str) -> Iterator[dict]:
+        """
+        Run a SOQL query over salesforce API
+
+        Note, pagination is performed via a LIMIT in the SOQL query and requires
+        that ids are sorted. The SOQL query must support `limit` and
+        `start_durable_id` as parameters.
+        """
+        url = self.query_url
+        logger.info("querying page 0")
+        records = self._next_records(url, query_template)
+        yield from records
+
+        page_count = 1
+        while self._should_query_next_page(records, page_count):
+            logger.info(f"querying page {page_count}")
+            last_durable_id = records[-1]["DurableId"]
+            records = self._next_records(
+                url, query_template, start_durable_id=last_durable_id
+            )
+            yield from records
+            page_count += 1
+
+    def fetch_sobjects(self) -> List[dict]:
+        """Fetch all sobjects"""
+        logger.info("Extracting sobjects")
+        return list(self._query_all(SOBJECTS_QUERY_TPL))
+
+    def fetch_fields(self, sobject_name: str) -> List[dict]:
+        """Fetches fields of a given sobject"""
+        query = SOBJECT_FIELDS_QUERY_TPL.format(
+            entity_definition_id=sobject_name
+        )
+        response = self._call(self.tooling_url, params={"q": query})
+        return response["records"]
+
+    def tables(self) -> List[dict]:
+        """
+        Get Salesforce sobjects as tables
+        """
+        sobjects = self.fetch_sobjects()
+        logger.info(f"Extracted {len(sobjects)} sobjects")
+        return self.formatter.tables(sobjects)
+
+    def columns(
+        self, sobject_names: List[str], show_progress: bool = True
+    ) -> List[dict]:
+        """
+        Get salesforce sobject fields as columns
+        show_progress: optionally deactivate the tqdm progress bar
+        """
+        sobject_fields: Dict[str, List[dict]] = dict()
+        for sobject_name in tqdm(sobject_names, disable=not show_progress):
+            fields = self.fetch_fields(sobject_name)
+            sobject_fields[sobject_name] = fields
+        return self.formatter.columns(sobject_fields)
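
The SOQL templates used here (SOBJECTS_QUERY_TPL and SOBJECT_FIELDS_QUERY_TPL) live in the new soql.py (+45 lines, file 34 in the list above) and are not rendered in this diff. Given that _format_query fills in start_durable_id and limit, and that _query_all reads DurableId off the last record of each page, a compatible sobjects template would have roughly this keyset-pagination shape (the object and field names are assumptions, based on Salesforce's EntityDefinition metadata object):

# Hypothetical reconstruction of a template compatible with _query_all above.
SOBJECTS_QUERY_TPL = """
    SELECT DurableId, QualifiedApiName, Label
    FROM EntityDefinition
    WHERE DurableId > '{start_durable_id}'
    ORDER BY DurableId
    LIMIT {limit}
"""

Paging on a sorted unique id, rather than OFFSET or nextRecordsUrl, works around the 2,000-row cap referenced by the LIMIT_RECORDS_PER_PAGE comment: each page restarts the query just past the last DurableId already seen.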

castor_extractor/warehouse/salesforce/constants.py (new file)
@@ -0,0 +1,2 @@
+DATABASE_NAME = "salesforce"
+SCHEMA_NAME = "schema"
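
DATABASE_NAME and SCHEMA_NAME suggest that every sobject is filed under a single pseudo database and schema on the catalog side. The formatter that consumes them (format.py, +67 lines, file 32 in the list above) is not rendered in this diff; by analogy with the Databricks _table_payload shown earlier, its table payloads presumably look something like the following, where every key and field choice is an assumption:

# Hypothetical sketch, mirroring the Databricks payload shape shown above.
def _table_payload(sobject: dict) -> dict:
    return {
        "description": sobject.get("Label"),
        "id": sobject["DurableId"],
        "schema_id": f"{DATABASE_NAME}.{SCHEMA_NAME}",
        "table_name": sobject["QualifiedApiName"],
        "tags": [],
        "type": "TABLE",
    }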

castor_extractor/warehouse/salesforce/extract.py (new file)
@@ -0,0 +1,111 @@
+import logging
+from typing import Dict, List, Tuple
+
+from ...utils import AbstractStorage, LocalStorage, write_summary
+from ...utils.salesforce import to_credentials
+from ..abstract import (
+    SupportedAssets,
+    WarehouseAsset,
+    WarehouseAssetGroup,
+    common_args,
+)
+from .client import SalesforceClient
+
+logger = logging.getLogger(__name__)
+
+
+Paths = Dict[str, str]
+
+SALESFORCE_CATALOG_ASSETS: Tuple[WarehouseAsset, ...] = (
+    WarehouseAsset.TABLE,
+    WarehouseAsset.COLUMN,
+)
+
+SALESFORCE_ASSETS: SupportedAssets = {
+    WarehouseAssetGroup.CATALOG: SALESFORCE_CATALOG_ASSETS
+}
+
+
+class SalesforceExtractionProcessor:
+    """Salesforce API-based extraction management - warehouse part"""
+
+    def __init__(
+        self,
+        client: SalesforceClient,
+        storage: AbstractStorage,
+        skip_existing: bool = False,
+    ):
+        self._client = client
+        self._storage = storage
+        self._skip_existing = skip_existing
+
+    def _should_extract(self) -> bool:
+        """helper function to determine whether we need to extract"""
+        if not self._skip_existing:
+            return True
+
+        for asset in SALESFORCE_CATALOG_ASSETS:
+            if not self._storage.exists(asset.value):
+                return True
+
+        logger.info("Skipped, files for catalog already exist")
+        return False
+
+    def _existing_group_paths(self) -> Paths:
+        return {
+            a.value: self._storage.path(a.value)
+            for a in SALESFORCE_CATALOG_ASSETS
+        }
+
+    def extract_catalog(self, show_progress: bool = True) -> Paths:
+        """
+        Extract the following catalog assets: tables and columns
+        and return the locations of the extracted data
+        """
+        if not self._should_extract():
+            return self._existing_group_paths()
+
+        catalog_locations: Paths = dict()
+
+        tables = self._client.tables()
+        location = self._storage.put(WarehouseAsset.TABLE.value, tables)
+        catalog_locations[WarehouseAsset.TABLE.value] = location
+        logger.info(f"Extracted {len(tables)} tables to {location}")
+
+        table_names = [t["table_name"] for t in tables]
+        columns = self._client.columns(table_names, show_progress)
+        location = self._storage.put(WarehouseAsset.COLUMN.value, columns)
+        catalog_locations[WarehouseAsset.COLUMN.value] = location
+        logger.info(f"Extracted {len(columns)} columns to {location}")
+        return catalog_locations
+
+    def extract_role(self) -> Paths:
+        """extract no users and return the empty file location"""
+        users: List[dict] = []
+        location = self._storage.put(WarehouseAsset.USER.value, users)
+        logger.info(f"Extracted {len(users)} users to {location}")
+        return {WarehouseAsset.USER.value: location}
+
+
+def extract_all(**kwargs) -> None:
+    """
+    Extract all assets from Salesforce and store the results in CSV files
+    """
+    output_directory, skip_existing = common_args(kwargs)
+
+    client = SalesforceClient(credentials=to_credentials(kwargs))
+    storage = LocalStorage(directory=output_directory)
+    extractor = SalesforceExtractionProcessor(
+        client=client,
+        storage=storage,
+        skip_existing=skip_existing,
+    )
+
+    extractor.extract_catalog()
+    extractor.extract_role()
+
+    write_summary(
+        output_directory,
+        storage.stored_at_ts,
+        client_name=client.name(),
+    )
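
Together with the new command module (castor_extractor/commands/extract_salesforce.py, +43 lines) and the extra console script registered in entry_points.txt, this gives a warehouse-side Salesforce extraction that can presumably also be driven directly from Python. The keyword names below are inferred from the SalesforceCredentials fields visible earlier in this diff and from common_args; treat them as assumptions, not a confirmed signature:

from castor_extractor.warehouse.salesforce import extract_all

extract_all(
    username="alice@example.com",
    password="********",
    client_id="<connected-app consumer key>",
    client_secret="<connected-app consumer secret>",
    security_token="<salesforce security token>",
    base_url="https://mycompany.my.salesforce.com",
    output="/tmp/castor",  # directory for the CSV files; keyword name assumed
)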