castor-extractor 0.16.1__py3-none-any.whl → 0.16.4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of castor-extractor was flagged as possibly problematic.
- CHANGELOG.md +12 -0
- castor_extractor/commands/extract_databricks.py +3 -0
- castor_extractor/commands/extract_salesforce.py +43 -0
- castor_extractor/commands/extract_salesforce_reporting.py +6 -6
- castor_extractor/utils/client/api.py +36 -27
- castor_extractor/utils/salesforce/__init__.py +3 -0
- castor_extractor/utils/salesforce/client.py +84 -0
- castor_extractor/utils/salesforce/client_test.py +21 -0
- castor_extractor/utils/salesforce/constants.py +13 -0
- castor_extractor/utils/salesforce/credentials.py +65 -0
- castor_extractor/{visualization/salesforce_reporting/client → utils/salesforce}/credentials_test.py +3 -2
- castor_extractor/visualization/domo/client/client.py +1 -1
- castor_extractor/visualization/powerbi/client/constants.py +3 -3
- castor_extractor/visualization/powerbi/client/rest.py +14 -8
- castor_extractor/visualization/powerbi/client/rest_test.py +61 -27
- castor_extractor/visualization/salesforce_reporting/__init__.py +1 -2
- castor_extractor/visualization/salesforce_reporting/client/__init__.py +1 -2
- castor_extractor/visualization/salesforce_reporting/client/rest.py +7 -90
- castor_extractor/visualization/salesforce_reporting/extract.py +10 -8
- castor_extractor/visualization/tableau/assets.py +5 -0
- castor_extractor/visualization/tableau/client/client.py +10 -0
- castor_extractor/visualization/tableau/gql_fields.py +30 -9
- castor_extractor/warehouse/databricks/client.py +20 -3
- castor_extractor/warehouse/databricks/client_test.py +14 -0
- castor_extractor/warehouse/databricks/credentials.py +1 -4
- castor_extractor/warehouse/databricks/extract.py +3 -2
- castor_extractor/warehouse/databricks/format.py +5 -4
- castor_extractor/warehouse/salesforce/__init__.py +6 -0
- castor_extractor/warehouse/salesforce/client.py +112 -0
- castor_extractor/warehouse/salesforce/constants.py +2 -0
- castor_extractor/warehouse/salesforce/extract.py +111 -0
- castor_extractor/warehouse/salesforce/format.py +67 -0
- castor_extractor/warehouse/salesforce/format_test.py +32 -0
- castor_extractor/warehouse/salesforce/soql.py +45 -0
- castor_extractor-0.16.4.dist-info/LICENCE +86 -0
- {castor_extractor-0.16.1.dist-info → castor_extractor-0.16.4.dist-info}/METADATA +2 -3
- {castor_extractor-0.16.1.dist-info → castor_extractor-0.16.4.dist-info}/RECORD +39 -27
- {castor_extractor-0.16.1.dist-info → castor_extractor-0.16.4.dist-info}/WHEEL +1 -1
- {castor_extractor-0.16.1.dist-info → castor_extractor-0.16.4.dist-info}/entry_points.txt +2 -1
- castor_extractor/visualization/salesforce_reporting/client/constants.py +0 -2
- castor_extractor/visualization/salesforce_reporting/client/credentials.py +0 -33
castor_extractor/visualization/salesforce_reporting/client/__init__.py
@@ -1,2 +1 @@
-from .
-from .rest import SalesforceClient
+from .rest import SalesforceReportingClient
castor_extractor/visualization/salesforce_reporting/client/rest.py
@@ -1,13 +1,8 @@
 import logging
-import
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
-
-import requests
-from requests import Response
+from typing import Dict, Iterator, List, Optional

+from ....utils.salesforce import SalesforceBaseClient
 from ..assets import SalesforceReportingAsset
-from .constants import DEFAULT_API_VERSION, DEFAULT_PAGINATION_LIMIT
-from .credentials import SalesforceCredentials
 from .soql import queries

 logger = logging.getLogger(__name__)
@@ -19,89 +14,11 @@ REQUIRING_URL_ASSETS = (
 )


-class
+class SalesforceReportingClient(SalesforceBaseClient):
     """
-    Salesforce API client
-    https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/intro_rest.htm
+    Salesforce Reporting API client
     """

-    api_version = DEFAULT_API_VERSION
-    pagination_limit = DEFAULT_PAGINATION_LIMIT
-
-    def __init__(
-        self,
-        credentials: SalesforceCredentials,
-        instance_url: str,
-    ):
-        self.credentials = credentials
-        self.instance_url = instance_url
-        self._token = self._access_token()
-
-    def _access_token(self) -> Tuple[str, str]:
-        url = f"{self.instance_url}/services/oauth2/token"
-        response = self._call(
-            url, "POST", data=self.credentials.token_request_payload()
-        )
-        return response["access_token"]
-
-    def _header(self) -> Dict:
-        return {"Authorization": f"Bearer {self._token}"}
-
-    @staticmethod
-    def _call(
-        url: str,
-        method: str = "GET",
-        *,
-        header: Optional[Dict] = None,
-        params: Optional[Dict] = None,
-        data: Optional[Dict] = None,
-        processor: Optional[Callable] = None,
-    ) -> Any:
-        logger.debug(f"Calling {method} on {url}")
-        result = requests.request(
-            method,
-            url,
-            headers=header,
-            params=params,
-            data=data,
-        )
-        result.raise_for_status()
-
-        if processor:
-            return processor(result)
-
-        return result.json()
-
-    @staticmethod
-    def _query_processor(response: Response) -> Tuple[dict, Optional[str]]:
-        results = response.json()
-        return results["records"], results.get("nextRecordsUrl")
-
-    def _query_all(self, query: str) -> Iterator[Dict]:
-        """
-        Run a SOQL query over salesforce API.
-
-        more: https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/dome_query.htm
-        """
-        url = f"{self.instance_url}/services/data/v{self.api_version}/query"
-        records, next_page = self._call(
-            url,
-            params={"q": query},
-            processor=self._query_processor,
-            header=self._header(),
-        )
-        yield from records
-
-        page_count = 0
-        while next_page and page_count <= self.pagination_limit:
-            logger.info(f"querying page {page_count}")
-            url = f"{self.instance_url}{next_page}"
-            records, next_page = self._call(
-                url, processor=self._query_processor, header=self._header()
-            )
-            yield from records
-            page_count += 1
-
     def _get_asset_url(
         self, asset_type: SalesforceReportingAsset, asset: dict
     ) -> Optional[str]:
@@ -111,15 +28,15 @@ class SalesforceClient:

         if asset_type == SalesforceReportingAsset.DASHBOARDS:
             path = f"lightning/r/Dashboard/{asset['Id']}/view"
-            return
+            return self.build_url(self._host, path)

         if asset_type == SalesforceReportingAsset.FOLDERS:
             path = asset["attributes"]["url"].lstrip("/")
-            return
+            return self.build_url(self._host, path)

         if asset_type == SalesforceReportingAsset.REPORTS:
             path = f"lightning/r/Report/{asset['Id']}/view"
-            return
+            return self.build_url(self._host, path)

         return None

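For context on the pagination logic removed above: the Salesforce REST /query endpoint returns records page by page and includes a nextRecordsUrl pointer whenever more pages remain, which is exactly what the old _query_processor unpacked. A minimal sketch of that response shape, with illustrative values only:

# Hypothetical body of a /services/data/vXX.X/query response; the removed
# _query_processor kept only "records" and the optional "nextRecordsUrl".
body = {
    "totalSize": 4300,
    "done": False,
    "records": [{"Id": "00O...", "Name": "Pipeline report"}],
    "nextRecordsUrl": "/services/data/v52.0/query/01g...-2000",
}
records, next_page = body["records"], body.get("nextRecordsUrl")

The replacement base class, SalesforceBaseClient in utils/salesforce (added in this release, not expanded here), presumably centralizes this plumbing for both the reporting and warehouse extractors.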
castor_extractor/visualization/salesforce_reporting/extract.py
@@ -10,14 +10,15 @@ from ...utils import (
     write_json,
     write_summary,
 )
+from ...utils.salesforce import SalesforceCredentials
 from .assets import SalesforceReportingAsset
-from .client import
+from .client import SalesforceReportingClient

 logger = logging.getLogger(__name__)


 def iterate_all_data(
-    client:
+    client: SalesforceReportingClient,
 ) -> Iterable[Tuple[str, Union[list, dict]]]:
     """Iterate over the extracted data from Salesforce"""

@@ -30,10 +31,10 @@ def iterate_all_data(
 def extract_all(
     username: str,
     password: str,
-
-
+    client_id: str,
+    client_secret: str,
     security_token: str,
-
+    base_url: str,
     output_directory: Optional[str] = None,
 ) -> None:
     """
@@ -44,11 +45,12 @@ def extract_all(
     creds = SalesforceCredentials(
         username=username,
         password=password,
-
-
+        client_id=client_id,
+        client_secret=client_secret,
         security_token=security_token,
+        base_url=base_url,
     )
-    client =
+    client = SalesforceReportingClient(credentials=creds)
     ts = current_timestamp()

     for key, data in iterate_all_data(client):
castor_extractor/visualization/tableau/assets.py
@@ -12,10 +12,12 @@ class TableauAsset(ExternalAsset):
     CUSTOM_SQL_TABLE = "custom_sql_tables"
     CUSTOM_SQL_QUERY = "custom_sql_queries"
     DASHBOARD = "dashboards"
+    DASHBOARD_SHEET = "dashboards_sheets"
     DATASOURCE = "datasources"
     FIELD = "fields"
     PROJECT = "projects"
     PUBLISHED_DATASOURCE = "published_datasources"
+    SHEET = "sheets"
     USAGE = "views"
     USER = "users"
     WORKBOOK = "workbooks"
@@ -25,7 +27,9 @@ class TableauAsset(ExternalAsset):
     def optional(cls) -> Set["TableauAsset"]:
         return {
             TableauAsset.DASHBOARD,
+            TableauAsset.DASHBOARD_SHEET,
             TableauAsset.FIELD,
+            TableauAsset.SHEET,
             TableauAsset.PUBLISHED_DATASOURCE,
         }

@@ -42,4 +46,5 @@ class TableauGraphqlAsset(Enum):
     DASHBOARD = "dashboards"
     DATASOURCE = "datasources"
     GROUP_FIELD = "groupFields"
+    SHEETS = "sheets"
     WORKBOOK_TO_DATASOURCE = "workbooks"
castor_extractor/visualization/tableau/client/client.py
@@ -173,6 +173,13 @@ class ApiClient:
             TableauAsset.DASHBOARD,
         )

+    def _fetch_sheets(self) -> SerializedAsset:
+        """Fetches sheets"""
+
+        return self._fetch_paginated_objects(
+            TableauAsset.SHEET,
+        )
+
     def _fetch_paginated_objects(self, asset: TableauAsset) -> SerializedAsset:
         """Fetches paginated objects"""

@@ -203,6 +210,9 @@ class ApiClient:
         if asset == TableauAsset.PUBLISHED_DATASOURCE:
             assets = self._fetch_published_datasources()

+        if asset == TableauAsset.SHEET:
+            assets = self._fetch_sheets()
+
         if asset == TableauAsset.USAGE:
             assets = self._fetch_usages(self._safe_mode)

castor_extractor/visualization/tableau/gql_fields.py
@@ -111,15 +111,15 @@ class GQLQueryFields(Enum):
     """

     DASHBOARDS: str = """
-
-
-
-
-
-
-
-
-
+        id
+        name
+        path
+        tags {
+            name
+        }
+        workbook {
+            luid # to retrieve the parent
+        }
     """

     DATASOURCE: str = """
@@ -160,6 +160,21 @@ class GQLQueryFields(Enum):
         role
     """

+    SHEET: str = """
+        containedInDashboards {
+            id
+        }
+        id
+        index
+        name
+        upstreamFields{
+            name
+        }
+        workbook {
+            luid
+        }
+    """
+
     WORKBOOK_TO_DATASOURCE: str = """
         luid
         id
@@ -219,6 +234,12 @@ QUERY_FIELDS: Dict[TableauAsset, QueryInfo] = {
             OBJECT_TYPE: TableauGraphqlAsset.GROUP_FIELD,
         },
     ],
+    TableauAsset.SHEET: [
+        {
+            FIELDS: GQLQueryFields.SHEET,
+            OBJECT_TYPE: TableauGraphqlAsset.SHEETS,
+        },
+    ],
     TableauAsset.WORKBOOK_TO_DATASOURCE: [
         {
             FIELDS: GQLQueryFields.WORKBOOK_TO_DATASOURCE,
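Taken together, the SHEET field list and the new QUERY_FIELDS entry describe the GraphQL selection the client sends to Tableau's Metadata API for the sheets object type. The query assembly itself is outside this diff; assuming a plain, unpaginated wrapper, the assembled query would look roughly like:

# Hypothetical assembled query for TableauAsset.SHEET; the real client may add
# pagination arguments that this diff does not show.
SHEETS_QUERY = """
{
  sheets {
    containedInDashboards { id }
    id
    index
    name
    upstreamFields { name }
    workbook { luid }
  }
}
"""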
castor_extractor/warehouse/databricks/client.py
@@ -31,7 +31,7 @@ class DatabricksClient(APIClient):
         db_allowed: Optional[Set[str]] = None,
         db_blocked: Optional[Set[str]] = None,
     ):
-        super().__init__(credentials)
+        super().__init__(host=credentials.host, token=credentials.token)
         self._db_allowed = db_allowed
         self._db_blocked = db_blocked
         self.formatter = DatabricksFormatter()
@@ -87,15 +87,32 @@ class DatabricksClient(APIClient):
             content.get("tables", []), schema
         )

-
+    @staticmethod
+    def _match_table_with_user(table: dict, user_id_by_email: dict) -> dict:
+        table_owner_email = table.get("owner_email")
+        if not table_owner_email:
+            return table
+        owner_external_id = user_id_by_email.get(table_owner_email)
+        if not owner_external_id:
+            return table
+        return {**table, "owner_external_id": owner_external_id}
+
+    def tables_and_columns(
+        self, schemas: List[dict], users: List[dict]
+    ) -> TablesColumns:
         """
         Get the databricks tables & columns leveraging the unity catalog API
         """
         tables: List[dict] = []
         columns: List[dict] = []
+        user_id_by_email = {user.get("email"): user.get("id") for user in users}
         for schema in schemas:
             t_to_add, c_to_add = self._tables_columns_of_schema(schema)
-
+            t_with_owner = [
+                self._match_table_with_user(table, user_id_by_email)
+                for table in t_to_add
+            ]
+            tables.extend(t_with_owner)
             columns.extend(c_to_add)
         return tables, columns

castor_extractor/warehouse/databricks/client_test.py
@@ -64,3 +64,17 @@ def test_DatabricksClient__keep_catalog():
     assert client._keep_catalog("staging")
     assert not client._keep_catalog("dev")
     assert not client._keep_catalog("something_unknown")
+
+
+def test_DatabricksClient__match_table_with_user():
+    client = MockDatabricksClient()
+    users_by_email = {"bob@castordoc.com": 3}
+
+    table = {"id": 1, "owner_email": "bob@castordoc.com"}
+    table_with_owner = client._match_table_with_user(table, users_by_email)
+
+    assert table_with_owner == {**table, "owner_external_id": 3}
+
+    table_without_owner = {"id": 1, "owner_email": None}
+    actual = client._match_table_with_user(table_without_owner, users_by_email)
+    assert actual == table_without_owner
castor_extractor/warehouse/databricks/credentials.py
@@ -25,7 +25,4 @@ def to_credentials(params: dict) -> DatabricksCredentials:
     """extract Databricks credentials"""
     host = params.get("host") or from_env(_HOST)
     token = params.get("token") or from_env(_TOKEN)
-    return DatabricksCredentials(
-        host=host,
-        token=token,
-    )
+    return DatabricksCredentials(host=host, token=token)
castor_extractor/warehouse/databricks/extract.py
@@ -43,7 +43,7 @@ class DatabricksExtractionProcessor:
         self._storage = storage
         self._skip_existing = skip_existing

-    def _should_not_reextract(self, asset_group) -> bool:
+    def _should_not_reextract(self, asset_group: WarehouseAssetGroup) -> bool:
         """helper function to determine whether we need to extract"""
         if not self._skip_existing:
             return False
@@ -82,7 +82,8 @@ class DatabricksExtractionProcessor:

         del databases

-
+        users = self._client.users()
+        tables, columns = self._client.tables_and_columns(schemas, users)

         location = self._storage.put(WarehouseAsset.TABLE.value, tables)
         catalog_locations[WarehouseAsset.TABLE.value] = location
castor_extractor/warehouse/databricks/format.py
@@ -19,10 +19,11 @@ def _to_datetime_or_none(time_ms: Optional[int]) -> Optional[datetime]:

 def _table_payload(schema: dict, table: dict) -> dict:
     return {
+        "description": table.get("comment"),
         "id": table["table_id"],
+        "owner_email": table.get("owner"),
         "schema_id": f"{schema['id']}",
         "table_name": table["name"],
-        "description": table.get("comment"),
         "tags": [],
         "type": table.get("table_type"),
     }
@@ -30,12 +31,12 @@ def _table_payload(schema: dict, table: dict) -> dict:

 def _column_payload(table: dict, column: dict) -> dict:
     return {
-        "id": f"`{table['id']}`.`{column['name']}`",
         "column_name": column["name"],
-        "table_id": table["id"],
-        "description": column.get("comment"),
         "data_type": column["type_name"],
+        "description": column.get("comment"),
+        "id": f"`{table['id']}`.`{column['name']}`",
         "ordinal_position": column["position"],
+        "table_id": table["id"],
     }


castor_extractor/warehouse/salesforce/client.py (new file)
@@ -0,0 +1,112 @@
+import logging
+from typing import Dict, Iterator, List
+
+from tqdm import tqdm  # type: ignore
+
+from ...utils.salesforce import SalesforceBaseClient, SalesforceCredentials
+from .format import SalesforceFormatter
+from .soql import SOBJECT_FIELDS_QUERY_TPL, SOBJECTS_QUERY_TPL
+
+logger = logging.getLogger(__name__)
+
+
+class SalesforceClient(SalesforceBaseClient):
+    """
+    Salesforce API client to extract sobjects
+    """
+
+    # Implicit (hard-coded in Salesforce) limitation when using SOQL of 2,000 rows
+    LIMIT_RECORDS_PER_PAGE = 2000
+
+    def __init__(self, credentials: SalesforceCredentials):
+        super().__init__(credentials)
+        self.formatter = SalesforceFormatter()
+
+    @staticmethod
+    def name() -> str:
+        return "Salesforce"
+
+    def _format_query(self, query_template: str, start_durable_id: str) -> str:
+        return query_template.format(
+            start_durable_id=start_durable_id,
+            limit=self.LIMIT_RECORDS_PER_PAGE,
+        )
+
+    def _next_records(
+        self, url: str, query_template: str, start_durable_id: str = "0000"
+    ) -> List[dict]:
+        query = self._format_query(
+            query_template, start_durable_id=start_durable_id
+        )
+        records, _ = self._call(
+            url, params={"q": query}, processor=self._query_processor
+        )
+        return records
+
+    def _is_last_page(self, records: List[dict]) -> bool:
+        return len(records) < self.LIMIT_RECORDS_PER_PAGE
+
+    def _should_query_next_page(
+        self, records: List[dict], page_number: int
+    ) -> bool:
+        return not (
+            self._is_last_page(records)
+            or self._has_reached_pagination_limit(page_number)
+        )
+
+    def _query_all(self, query_template: str) -> Iterator[dict]:
+        """
+        Run a SOQL query over salesforce API
+
+        Note, pagination is performed via a LIMIT in the SOQL query and requires
+        that ids are sorted. The SOQL query must support `limit` and
+        `start_durable_id` as parameters.
+        """
+        url = self.query_url
+        logger.info("querying page 0")
+        records = self._next_records(url, query_template)
+        yield from records
+
+        page_count = 1
+        while self._should_query_next_page(records, page_count):
+            logger.info(f"querying page {page_count}")
+            last_durable_id = records[-1]["DurableId"]
+            records = self._next_records(
+                url, query_template, start_durable_id=last_durable_id
+            )
+            yield from records
+            page_count += 1
+
+    def fetch_sobjects(self) -> List[dict]:
+        """Fetch all sobjects"""
+        logger.info("Extracting sobjects")
+        return list(self._query_all(SOBJECTS_QUERY_TPL))
+
+    def fetch_fields(self, sobject_name: str) -> List[dict]:
+        """Fetches fields of a given sobject"""
+        query = SOBJECT_FIELDS_QUERY_TPL.format(
+            entity_definition_id=sobject_name
+        )
+        response = self._call(self.tooling_url, params={"q": query})
+        return response["records"]
+
+    def tables(self) -> List[dict]:
+        """
+        Get Salesforce sobjects as tables
+        """
+        sobjects = self.fetch_sobjects()
+        logger.info(f"Extracted {len(sobjects)} sobjects")
+        return self.formatter.tables(sobjects)
+
+    def columns(
+        self, sobject_names: List[str], show_progress: bool = True
+    ) -> List[dict]:
+        """
+        Get salesforce sobject fields as columns
+        show_progress: optionally deactivate the tqdm progress bar
+        """
+        sobject_fields: Dict[str, List[dict]] = dict()
+        for sobject_name in tqdm(sobject_names, disable=not show_progress):
+            fields = self.fetch_fields(sobject_name)
+            sobject_fields[sobject_name] = fields
+        return self.formatter.columns(sobject_fields)
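The SOQL templates imported above live in the new soql.py (+45 lines in the file list, not expanded in this diff). Given the contract described in _query_all (keyset pagination on DurableId with a formatted LIMIT), a compatible SOBJECTS_QUERY_TPL would need to look something like this sketch; the selected fields are assumptions, not the released query:

# Hypothetical template honouring the {start_durable_id} and {limit}
# placeholders; EntityDefinition is the standard object that exposes sobjects,
# and it is paginated on DurableId because it does not support OFFSET.
SOBJECTS_QUERY_TPL = """
SELECT DurableId, QualifiedApiName, Label
FROM EntityDefinition
WHERE DurableId > '{start_durable_id}'
ORDER BY DurableId
LIMIT {limit}
"""

The "0000" default for start_durable_id works because typical DurableId values sort after it, so the first page starts from the beginning.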
castor_extractor/warehouse/salesforce/extract.py (new file)
@@ -0,0 +1,111 @@
+import logging
+from typing import Dict, List, Tuple
+
+from ...utils import AbstractStorage, LocalStorage, write_summary
+from ...utils.salesforce import to_credentials
+from ..abstract import (
+    SupportedAssets,
+    WarehouseAsset,
+    WarehouseAssetGroup,
+    common_args,
+)
+from .client import SalesforceClient
+
+logger = logging.getLogger(__name__)
+
+
+Paths = Dict[str, str]
+
+SALESFORCE_CATALOG_ASSETS: Tuple[WarehouseAsset, ...] = (
+    WarehouseAsset.TABLE,
+    WarehouseAsset.COLUMN,
+)
+
+SALESFORCE_ASSETS: SupportedAssets = {
+    WarehouseAssetGroup.CATALOG: SALESFORCE_CATALOG_ASSETS
+}
+
+
+class SalesforceExtractionProcessor:
+    """Salesforce API-based extraction management - warehouse part"""
+
+    def __init__(
+        self,
+        client: SalesforceClient,
+        storage: AbstractStorage,
+        skip_existing: bool = False,
+    ):
+        self._client = client
+        self._storage = storage
+        self._skip_existing = skip_existing
+
+    def _should_extract(self) -> bool:
+        """helper function to determine whether we need to extract"""
+        if not self._skip_existing:
+            return True
+
+        for asset in SALESFORCE_CATALOG_ASSETS:
+            if not self._storage.exists(asset.value):
+                return True
+
+        logger.info("Skipped, files for catalog already exist")
+        return False
+
+    def _existing_group_paths(self) -> Paths:
+        return {
+            a.value: self._storage.path(a.value)
+            for a in SALESFORCE_CATALOG_ASSETS
+        }
+
+    def extract_catalog(self, show_progress: bool = True) -> Paths:
+        """
+        Extract the following catalog assets: tables and columns
+        and return the locations of the extracted data
+        """
+        if not self._should_extract():
+            return self._existing_group_paths()
+
+        catalog_locations: Paths = dict()
+
+        tables = self._client.tables()
+        location = self._storage.put(WarehouseAsset.TABLE.value, tables)
+        catalog_locations[WarehouseAsset.TABLE.value] = location
+        logger.info(f"Extracted {len(tables)} tables to {location}")
+
+        table_names = [t["table_name"] for t in tables]
+        columns = self._client.columns(table_names, show_progress)
+        location = self._storage.put(WarehouseAsset.COLUMN.value, columns)
+        catalog_locations[WarehouseAsset.COLUMN.value] = location
+        logger.info(f"Extracted {len(columns)} columns to {location}")
+        return catalog_locations
+
+    def extract_role(self) -> Paths:
+        """extract no users and return the empty file location"""
+        users: List[dict] = []
+        location = self._storage.put(WarehouseAsset.USER.value, users)
+        logger.info(f"Extracted {len(users)} users to {location}")
+        return {WarehouseAsset.USER.value: location}
+
+
+def extract_all(**kwargs) -> None:
+    """
+    Extract all assets from Salesforce and store the results in CSV files
+    """
+    output_directory, skip_existing = common_args(kwargs)
+
+    client = SalesforceClient(credentials=to_credentials(kwargs))
+    storage = LocalStorage(directory=output_directory)
+    extractor = SalesforceExtractionProcessor(
+        client=client,
+        storage=storage,
+        skip_existing=skip_existing,
+    )
+
+    extractor.extract_catalog()
+    extractor.extract_role()
+
+    write_summary(
+        output_directory,
+        storage.stored_at_ts,
+        client_name=client.name(),
+    )
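Together with the new commands/extract_salesforce.py and the extra line in entry_points.txt (both listed above but not expanded here), this presumably wires the warehouse extraction into a CLI command. As a library call, a minimal invocation could look like the sketch below; the keyword names are assumptions inferred from the SalesforceCredentials fields seen in the reporting diff and from common_args:

# Hypothetical direct call; to_credentials(kwargs) is expected to pick up the
# credential fields and common_args(kwargs) the output settings.
from castor_extractor.warehouse.salesforce import extract_all

extract_all(
    username="user@example.com",
    password="...",
    client_id="...",
    client_secret="...",
    security_token="...",
    base_url="https://mycompany.my.salesforce.com",
    output_directory="/tmp/castor",
    skip_existing=False,
)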