castor-extractor 0.19.0__py3-none-any.whl → 0.19.6__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of castor-extractor might be problematic.
- CHANGELOG.md +29 -2
- castor_extractor/file_checker/templates/generic_warehouse.py +1 -1
- castor_extractor/knowledge/notion/client/client.py +44 -80
- castor_extractor/knowledge/notion/client/client_test.py +9 -4
- castor_extractor/knowledge/notion/client/constants.py +1 -0
- castor_extractor/knowledge/notion/client/endpoints.py +1 -1
- castor_extractor/knowledge/notion/client/pagination.py +9 -5
- castor_extractor/quality/soda/assets.py +1 -1
- castor_extractor/quality/soda/client/client.py +30 -83
- castor_extractor/quality/soda/client/credentials.py +0 -11
- castor_extractor/quality/soda/client/endpoints.py +3 -6
- castor_extractor/quality/soda/client/pagination.py +25 -0
- castor_extractor/utils/__init__.py +13 -2
- castor_extractor/utils/client/__init__.py +14 -0
- castor_extractor/utils/client/api/__init__.py +5 -0
- castor_extractor/utils/client/api/auth.py +76 -0
- castor_extractor/utils/client/api/auth_test.py +49 -0
- castor_extractor/utils/client/api/client.py +153 -0
- castor_extractor/utils/client/api/client_test.py +47 -0
- castor_extractor/utils/client/api/pagination.py +83 -0
- castor_extractor/utils/client/api/pagination_test.py +51 -0
- castor_extractor/utils/{safe_request_test.py → client/api/safe_request_test.py} +4 -1
- castor_extractor/utils/client/api/utils.py +9 -0
- castor_extractor/utils/client/api/utils_test.py +16 -0
- castor_extractor/utils/collection.py +34 -2
- castor_extractor/utils/collection_test.py +17 -3
- castor_extractor/utils/pager/__init__.py +0 -1
- castor_extractor/utils/retry.py +44 -0
- castor_extractor/utils/retry_test.py +26 -1
- castor_extractor/utils/salesforce/client.py +44 -49
- castor_extractor/utils/salesforce/client_test.py +2 -2
- castor_extractor/utils/salesforce/pagination.py +33 -0
- castor_extractor/visualization/domo/client/client.py +10 -5
- castor_extractor/visualization/domo/client/credentials.py +1 -1
- castor_extractor/visualization/domo/client/endpoints.py +19 -7
- castor_extractor/visualization/looker/api/credentials.py +1 -1
- castor_extractor/visualization/metabase/client/api/client.py +26 -11
- castor_extractor/visualization/metabase/client/api/credentials.py +1 -1
- castor_extractor/visualization/metabase/client/db/credentials.py +1 -1
- castor_extractor/visualization/mode/client/credentials.py +1 -1
- castor_extractor/visualization/qlik/client/engine/credentials.py +1 -1
- castor_extractor/visualization/salesforce_reporting/client/rest.py +4 -3
- castor_extractor/visualization/sigma/client/client.py +106 -111
- castor_extractor/visualization/sigma/client/credentials.py +11 -1
- castor_extractor/visualization/sigma/client/endpoints.py +1 -1
- castor_extractor/visualization/sigma/client/pagination.py +22 -18
- castor_extractor/visualization/tableau/tests/unit/rest_api/auth_test.py +0 -1
- castor_extractor/visualization/tableau/tests/unit/rest_api/credentials_test.py +0 -3
- castor_extractor/visualization/tableau_revamp/assets.py +11 -0
- castor_extractor/visualization/tableau_revamp/client/client.py +71 -151
- castor_extractor/visualization/tableau_revamp/client/client_metadata_api.py +95 -0
- castor_extractor/visualization/tableau_revamp/client/client_rest_api.py +128 -0
- castor_extractor/visualization/tableau_revamp/client/client_tsc.py +66 -0
- castor_extractor/visualization/tableau_revamp/client/{tsc_fields.py → rest_fields.py} +15 -2
- castor_extractor/visualization/tableau_revamp/constants.py +0 -2
- castor_extractor/visualization/tableau_revamp/extract.py +5 -11
- castor_extractor/warehouse/databricks/api_client.py +239 -0
- castor_extractor/warehouse/databricks/api_client_test.py +15 -0
- castor_extractor/warehouse/databricks/client.py +37 -490
- castor_extractor/warehouse/databricks/client_test.py +1 -99
- castor_extractor/warehouse/databricks/endpoints.py +28 -0
- castor_extractor/warehouse/databricks/lineage.py +141 -0
- castor_extractor/warehouse/databricks/lineage_test.py +34 -0
- castor_extractor/warehouse/databricks/pagination.py +22 -0
- castor_extractor/warehouse/databricks/sql_client.py +90 -0
- castor_extractor/warehouse/databricks/utils.py +44 -1
- castor_extractor/warehouse/databricks/utils_test.py +58 -1
- castor_extractor/warehouse/mysql/client.py +0 -2
- castor_extractor/warehouse/salesforce/client.py +12 -59
- castor_extractor/warehouse/salesforce/pagination.py +34 -0
- castor_extractor/warehouse/sqlserver/client.py +0 -1
- castor_extractor-0.19.6.dist-info/METADATA +903 -0
- {castor_extractor-0.19.0.dist-info → castor_extractor-0.19.6.dist-info}/RECORD +77 -60
- castor_extractor/utils/client/api.py +0 -87
- castor_extractor/utils/client/api_test.py +0 -24
- castor_extractor/utils/pager/pager_on_token.py +0 -52
- castor_extractor/utils/pager/pager_on_token_test.py +0 -73
- castor_extractor/visualization/sigma/client/client_test.py +0 -54
- castor_extractor-0.19.0.dist-info/METADATA +0 -207
- /castor_extractor/utils/{safe_request.py → client/api/safe_request.py} +0 -0
- {castor_extractor-0.19.0.dist-info → castor_extractor-0.19.6.dist-info}/LICENCE +0 -0
- {castor_extractor-0.19.0.dist-info → castor_extractor-0.19.6.dist-info}/WHEEL +0 -0
- {castor_extractor-0.19.0.dist-info → castor_extractor-0.19.6.dist-info}/entry_points.txt +0 -0
castor_extractor/warehouse/databricks/client_test.py:

```diff
@@ -1,14 +1,7 @@
-from datetime import date
 from unittest.mock import Mock, patch

-from freezegun import freeze_time
-
-from ..abstract.time_filter import TimeFilter
 from .client import (
     DatabricksClient,
-    DatabricksCredentials,
-    LineageLinks,
-    _day_hour_to_epoch_ms,
 )
 from .test_constants import (
     CLOSER_DATE,
@@ -18,74 +11,12 @@ from .test_constants import (
 )


-def test__day_hour_to_epoch_ms():
-    _day_hour_to_epoch_ms(date(2023, 2, 14), 14) == 1644847200000
-
-
-@freeze_time("2023-7-4")
-def test_DatabricksClient__hourly_time_filters():
-    credentials = DatabricksCredentials(
-        host="carthago",
-        token="delenda",
-        http_host="est",
-    )
-    client = DatabricksClient(credentials)
-
-    # default is yesterday
-    default_filters = [f for f in client._hourly_time_filters(None)]
-
-    assert len(default_filters) == 24  # number of hours in a day
-
-    first = default_filters[0]
-    start = first["filter_by"]["query_start_time_range"]["start_time_ms"]
-    last = default_filters[-1]
-    end = last["filter_by"]["query_start_time_range"]["end_time_ms"]
-    assert start == 1688342400000  # July 3, 2023 12:00:00 AM GMT
-    assert end == 1688428800000  # July 4, 2023 12:00:00 AM GMT
-
-    # custom time (from execution_date in DAG for example)
-    time_filter = TimeFilter(day=date(2020, 10, 15))
-    custom_filters = [f for f in client._hourly_time_filters(time_filter)]
-
-    assert len(custom_filters) == 24
-
-    first = custom_filters[0]
-    start = first["filter_by"]["query_start_time_range"]["start_time_ms"]
-    last = custom_filters[-1]
-    end = last["filter_by"]["query_start_time_range"]["end_time_ms"]
-    assert start == 1602720000000  # Oct 15, 2020 12:00:00 AM
-    assert end == 1602806400000  # Oct 16, 2020 12:00:00 AM
-
-    # hourly extraction: note that hour_min == hour_max
-    hourly = TimeFilter(day=date(2023, 4, 14), hour_min=4, hour_max=4)
-    hourly_filters = [f for f in client._hourly_time_filters(hourly)]
-    expected_hourly = [
-        {
-            "filter_by": {
-                "query_start_time_range": {
-                    "end_time_ms": 1681448400000,  # April 14, 2023 5:00:00 AM
-                    "start_time_ms": 1681444800000,  # April 14, 2023 4:00:00 AM
-                }
-            }
-        }
-    ]
-    assert hourly_filters == expected_hourly
-
-
 class MockDatabricksClient(DatabricksClient):
     def __init__(self):
         self._db_allowed = ["prd", "staging"]
         self._db_blocked = ["dev"]


-def test_DatabricksClient__keep_catalog():
-    client = MockDatabricksClient()
-    assert client._keep_catalog("prd")
-    assert client._keep_catalog("staging")
-    assert not client._keep_catalog("dev")
-    assert not client._keep_catalog("something_unknown")
-
-
 def test_DatabricksClient__get_user_mapping():
     client = MockDatabricksClient()
     users = [
@@ -120,7 +51,7 @@ def test_DatabricksClient__match_table_with_user():


 @patch(
-    "source.packages.extractor.castor_extractor.warehouse.databricks.client.
+    "source.packages.extractor.castor_extractor.warehouse.databricks.client.DatabricksAPIClient._get",
     side_effect=TABLE_LINEAGE_SIDE_EFFECT,
 )
 def test_DatabricksClient_table_lineage(mock_get):
@@ -141,32 +72,3 @@ def test_DatabricksClient_table_lineage(mock_get):
     }
     assert expected_link_1 in lineage
     assert expected_link_2 in lineage
-
-
-def test_LineageLinks_add():
-    links = LineageLinks()
-    timestamped_link = ("parent", "child", None)
-    expected_key = ("parent", "child")
-
-    links.add(timestamped_link)
-
-    assert expected_key in links.lineage
-    assert links.lineage[expected_key] is None
-
-    # we replace None by an actual timestamp
-    timestamped_link = ("parent", "child", OLDER_DATE)
-    links.add(timestamped_link)
-    assert expected_key in links.lineage
-    assert links.lineage[expected_key] == OLDER_DATE
-
-    # we update with the more recent timestamp
-    timestamped_link = ("parent", "child", CLOSER_DATE)
-    links.add(timestamped_link)
-    assert expected_key in links.lineage
-    assert links.lineage[expected_key] == CLOSER_DATE
-
-    # we keep the more recent timestamp
-    timestamped_link = ("parent", "child", OLDER_DATE)
-    links.add(timestamped_link)
-    assert expected_key in links.lineage
-    assert links.lineage[expected_key] == CLOSER_DATE
```
castor_extractor/warehouse/databricks/endpoints.py (new file):

```diff
@@ -0,0 +1,28 @@
+class DatabricksEndpointFactory:
+    @classmethod
+    def tables(cls):
+        return "api/2.1/unity-catalog/tables"
+
+    @classmethod
+    def schemas(cls):
+        return "api/2.1/unity-catalog/schemas"
+
+    @classmethod
+    def databases(cls):
+        return "api/2.1/unity-catalog/catalogs"
+
+    @classmethod
+    def table_lineage(cls):
+        return "api/2.0/lineage-tracking/table-lineage"
+
+    @classmethod
+    def column_lineage(cls):
+        return "api/2.0/lineage-tracking/column-lineage"
+
+    @classmethod
+    def queries(cls):
+        return "api/2.0/sql/history/queries"
+
+    @classmethod
+    def users(cls):
+        return "api/2.0/preview/scim/v2/Users"
```
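The factory only returns paths relative to the workspace; whichever client consumes it still has to prefix them with the workspace host. A minimal sketch of that join (the `full_url` helper and the host value are illustrative, not part of this package):

```python
from urllib.parse import urljoin

# Hypothetical glue: combine a workspace host with a relative endpoint such as
# DatabricksEndpointFactory.tables() == "api/2.1/unity-catalog/tables".
def full_url(host: str, endpoint: str) -> str:
    return urljoin(host.rstrip("/") + "/", endpoint)

assert full_url(
    "https://example.cloud.databricks.com",
    "api/2.1/unity-catalog/tables",
) == "https://example.cloud.databricks.com/api/2.1/unity-catalog/tables"
```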
castor_extractor/warehouse/databricks/lineage.py (new file):

```diff
@@ -0,0 +1,141 @@
+from typing import Dict, List, Set, Tuple, cast
+
+from .types import Link, Ostr, OTimestampedLink, TimestampedLink
+
+
+class LineageLinks:
+    """
+    helper class that handles lineage deduplication and filtering
+    """
+
+    def __init__(self):
+        self.lineage: Dict[Link, Ostr] = dict()
+
+    def add(self, timestamped_link: TimestampedLink) -> None:
+        """
+        keep the most recent lineage link, adding to `self.lineage`
+        """
+        parent, child, timestamp = timestamped_link
+        link = (parent, child)
+        if not self.lineage.get(link):
+            self.lineage[link] = timestamp
+            return
+
+        if not timestamp:
+            return
+        # keep most recent link; cast for mypy
+        recent = max(cast(str, self.lineage[link]), cast(str, timestamp))
+        self.lineage[link] = recent
+
+
+def _to_table_path(table: dict) -> Ostr:
+    if table.get("name"):
+        return f"{table['catalog_name']}.{table['schema_name']}.{table['name']}"
+    return None
+
+
+def _to_column_path(column: dict) -> Ostr:
+    if column.get("name"):
+        return f"{column['catalog_name']}.{column['schema_name']}.{column['table_name']}.{column['name']}"
+    return None
+
+
+def _link(path_from: Ostr, path_to: Ostr, timestamp: Ostr) -> OTimestampedLink:
+    """exclude missing path and self-lineage"""
+    if (not path_from) or (not path_to):
+        return None
+    is_self_lineage = path_from.lower() == path_to.lower()
+    if is_self_lineage:
+        return None
+    return path_from, path_to, timestamp
+
+
+def single_table_lineage_links(
+    table_path: str, single_table_lineage: dict
+) -> List[TimestampedLink]:
+    """
+    process databricks lineage API response for a given table
+    returns a list of (parent, child, timestamp)
+
+    Note: in `upstreams` or `downstreams` we only care about `tableInfo`,
+    we could also have `notebookInfos` or `fileInfo`
+    """
+    links: List[OTimestampedLink] = []
+    # add parent:
+    for link in single_table_lineage.get("upstreams", []):
+        parent = link.get("tableInfo", {})
+        parent_path = _to_table_path(parent)
+        timestamp: Ostr = parent.get("lineage_timestamp")
+        links.append(_link(parent_path, table_path, timestamp))
+
+    # add children:
+    for link in single_table_lineage.get("downstreams", []):
+        child = link.get("tableInfo", {})
+        child_path = _to_table_path(child)
+        timestamp = child.get("lineage_timestamp")
+        links.append(_link(table_path, child_path, timestamp))
+
+    return list(filter(None, links))
+
+
+def single_column_lineage_links(
+    column_path: str, single_column_lineage: dict
+) -> List[TimestampedLink]:
+    """
+    process databricks lineage API response for a given table
+    returns a list of (parent, child, timestamp)
+
+    Note: in `upstreams` or `downstreams` we only care about `tableInfo`,
+    we could also have `notebookInfos` or `fileInfo`
+    """
+    links: List[OTimestampedLink] = []
+    # add parent:
+    for link in single_column_lineage.get("upstream_cols", []):
+        parent_path = _to_column_path(link)
+        timestamp: Ostr = link.get("lineage_timestamp")
+        links.append(_link(parent_path, column_path, timestamp))
+
+    # add children:
+    for link in single_column_lineage.get("downstream_cols", []):
+        child_path = _to_column_path(link)
+        timestamp = link.get("lineage_timestamp")
+        links.append(_link(column_path, child_path, timestamp))
+
+    return list(filter(None, links))
+
+
+def paths_for_column_lineage(
+    tables: List[dict], columns: List[dict], table_lineage: List[dict]
+) -> List[Tuple[str, str]]:
+    """
+    helper providing a list of candidate columns to look lineage for:
+    we only look for column lineage where there is table lineage
+    """
+    # mapping between table id and its path db.schema.table
+    # table["schema_id"] follows the pattern `db.schema`
+    mapping = {
+        table["id"]: ".".join([table["schema_id"], table["table_name"]])
+        for table in tables
+    }
+
+    tables_with_lineage: Set[str] = set()
+    for t in table_lineage:
+        tables_with_lineage.add(t["parent_path"])
+        tables_with_lineage.add(t["child_path"])
+
+    paths_to_return: List[Tuple[str, str]] = []
+    for column in columns:
+        table_path = mapping[column["table_id"]]
+        if table_path not in tables_with_lineage:
+            continue
+        column_ = (table_path, column["column_name"])
+        paths_to_return.append(column_)
+
+    return paths_to_return
+
+
+def deduplicate_lineage(lineages: List[TimestampedLink]) -> dict:
+    deduplicated_lineage = LineageLinks()
+    for timestamped_link in lineages:
+        deduplicated_lineage.add(timestamped_link)
+    return deduplicated_lineage.lineage
```
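As an illustration of how these helpers compose, here is a made-up lineage payload run through `single_table_lineage_links` and `deduplicate_lineage` (the import path assumes the installed package layout; the payload values are invented):

```python
from castor_extractor.warehouse.databricks.lineage import (
    deduplicate_lineage,
    single_table_lineage_links,
)

# Invented response shaped like the lineage-tracking API: one upstream table,
# one downstream table, and one downstream file (ignored: no `tableInfo`).
payload = {
    "upstreams": [
        {
            "tableInfo": {
                "catalog_name": "prd",
                "schema_name": "raw",
                "name": "orders_raw",
                "lineage_timestamp": "2023-07-01",
            }
        },
    ],
    "downstreams": [
        {
            "tableInfo": {
                "catalog_name": "prd",
                "schema_name": "sales",
                "name": "orders_daily",
                "lineage_timestamp": "2023-07-02",
            }
        },
        {"fileInfo": {"path": "dbfs:/tmp/export.csv"}},
    ],
}

links = single_table_lineage_links("prd.sales.orders", payload)
assert deduplicate_lineage(links) == {
    ("prd.raw.orders_raw", "prd.sales.orders"): "2023-07-01",
    ("prd.sales.orders", "prd.sales.orders_daily"): "2023-07-02",
}
```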
castor_extractor/warehouse/databricks/lineage_test.py (new file):

```diff
@@ -0,0 +1,34 @@
+from .lineage import LineageLinks
+from .test_constants import (
+    CLOSER_DATE,
+    OLDER_DATE,
+)
+
+
+def test_LineageLinks_add():
+    links = LineageLinks()
+    timestamped_link = ("parent", "child", None)
+    expected_key = ("parent", "child")
+
+    links.add(timestamped_link)
+
+    assert expected_key in links.lineage
+    assert links.lineage[expected_key] is None
+
+    # we replace None by an actual timestamp
+    timestamped_link = ("parent", "child", OLDER_DATE)
+    links.add(timestamped_link)
+    assert expected_key in links.lineage
+    assert links.lineage[expected_key] == OLDER_DATE
+
+    # we update with the more recent timestamp
+    timestamped_link = ("parent", "child", CLOSER_DATE)
+    links.add(timestamped_link)
+    assert expected_key in links.lineage
+    assert links.lineage[expected_key] == CLOSER_DATE
+
+    # we keep the more recent timestamp
+    timestamped_link = ("parent", "child", OLDER_DATE)
+    links.add(timestamped_link)
+    assert expected_key in links.lineage
+    assert links.lineage[expected_key] == CLOSER_DATE
```
castor_extractor/warehouse/databricks/pagination.py (new file):

```diff
@@ -0,0 +1,22 @@
+from typing import List, Optional
+
+from pydantic import Field
+
+from ...utils import PaginationModel
+
+DATABRICKS_PAGE_SIZE = 100
+
+
+class DatabricksPagination(PaginationModel):
+    next_page_token: Optional[str] = None
+    has_next_page: bool = False
+    res: List[dict] = Field(default_factory=list)
+
+    def is_last(self) -> bool:
+        return not (self.has_next_page and self.next_page_token)
+
+    def next_page_payload(self) -> dict:
+        return {"page_token": self.next_page_token}
+
+    def page_results(self) -> list:
+        return self.res
```
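`fetch_all_pages` itself is not part of this excerpt (it lives in the reworked `castor_extractor/utils` client package), but the contract above suggests a driver loop roughly like the sketch below. This is a guess at the shape, not the library's implementation:

```python
from typing import Callable, Iterator

def fetch_all_pages_sketch(request: Callable[..., dict], pagination_model) -> Iterator[dict]:
    """Hedged sketch: call `request` until the pagination model says we are done."""
    payload: dict = {}
    while True:
        raw = request(**payload)            # raw JSON response from the API
        page = pagination_model(**raw)      # validated by the pydantic model
        yield from page.page_results()      # e.g. DatabricksPagination.res
        if page.is_last():                  # no has_next_page / next_page_token
            return
        payload = page.next_page_payload()  # e.g. {"page_token": "..."}
```

Used with `DatabricksPagination`, each iteration would feed the `page_token` back to the queries or unity-catalog endpoints until the token runs out.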
castor_extractor/warehouse/databricks/sql_client.py (new file):

```diff
@@ -0,0 +1,90 @@
+import logging
+from collections import defaultdict
+from enum import Enum
+from typing import Optional
+
+from databricks import sql  # type: ignore
+
+from .credentials import DatabricksCredentials
+from .format import TagMapping
+from .utils import build_path, tag_label
+
+logger = logging.getLogger(__name__)
+
+_INFORMATION_SCHEMA_SQL = "SELECT * FROM system.information_schema"
+
+
+class TagEntity(Enum):
+    """Entities that can be tagged in Databricks"""
+
+    COLUMN = "COLUMN"
+    TABLE = "TABLE"
+
+
+class DatabricksSQLClient:
+    def __init__(
+        self,
+        credentials: DatabricksCredentials,
+        has_table_tags: bool = False,
+        has_column_tags: bool = False,
+    ):
+        self._http_path = credentials.http_path
+        self._has_table_tags = has_table_tags
+        self._has_column_tags = has_column_tags
+        self._host = credentials.host
+        self._token = credentials.token
+
+    def execute_sql(
+        self,
+        query: str,
+        params: Optional[dict] = None,
+    ):
+        """
+        Execute a SQL query on Databricks system tables and return the results.
+        https://docs.databricks.com/en/dev-tools/python-sql-connector.html
+
+        //!\\ credentials.http_path is required in order to run SQL queries
+        """
+        assert self._http_path, "HTTP_PATH is required to run SQL queries"
+        with sql.connect(
+            server_hostname=self._host,
+            http_path=self._http_path,
+            access_token=self._token,
+        ) as connection:
+            with connection.cursor() as cursor:
+                cursor.execute(query, params)
+                return cursor.fetchall()
+
+    def _needs_extraction(self, entity: TagEntity) -> bool:
+        if entity == TagEntity.TABLE:
+            return self._has_table_tags
+        if entity == TagEntity.COLUMN:
+            return self._has_column_tags
+        raise AssertionError(f"Entity not supported: {entity}")
+
+    def get_tags_mapping(self, entity: TagEntity) -> TagMapping:
+        """
+        Fetch tags of the given entity and build a mapping:
+        { path: list[tags] }
+
+        https://docs.databricks.com/en/sql/language-manual/information-schema/table_tags.html
+        https://docs.databricks.com/en/sql/language-manual/information-schema/column_tags.html
+        """
+        if not self._needs_extraction(entity):
+            # extracting tags require additional credentials (http_path)
+            return dict()
+
+        table = f"{entity.value.lower()}_tags"
+        query = f"{_INFORMATION_SCHEMA_SQL}.{table}"
+        result = self.execute_sql(query)
+        mapping = defaultdict(list)
+        for row in result:
+            dict_row = row.asDict()
+            keys = ["catalog_name", "schema_name", "table_name"]
+            if entity == TagEntity.COLUMN:
+                keys.append("column_name")
+            path = build_path(dict_row, keys)
+            label = tag_label(dict_row)
+            mapping[path].append(label)
+
+        return mapping
```
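A short usage sketch of the SQL client; the workspace values are placeholders, and the keyword arguments to `DatabricksCredentials` are assumed from the attributes read above (`host`, `token`, `http_path`), since the credentials module is not shown in this diff:

```python
from castor_extractor.warehouse.databricks.credentials import DatabricksCredentials
from castor_extractor.warehouse.databricks.sql_client import (
    DatabricksSQLClient,
    TagEntity,
)

# Placeholder values; http_path is what gates tag extraction via SQL.
credentials = DatabricksCredentials(
    host="example.cloud.databricks.com",
    token="dapi-xxxx",
    http_path="/sql/1.0/warehouses/abc123",
)
client = DatabricksSQLClient(credentials, has_table_tags=True, has_column_tags=True)

# Expected shape (assuming dot-joined paths): {"catalog.schema.table": ["tag_name:tag_value", ...]}
table_tags = client.get_tags_mapping(TagEntity.TABLE)
```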
castor_extractor/warehouse/databricks/utils.py:

```diff
@@ -1,4 +1,16 @@
-from
+from datetime import date
+from typing import Dict, Iterable, List, Optional
+
+from ...utils import at_midnight
+from ..abstract import TimeFilter
+
+_DEFAULT_HOUR_MIN = 0
+_DEFAULT_HOUR_MAX = 23
+_NUM_HOURS_IN_A_DAY = 24
+
+
+def _day_hour_to_epoch_ms(day: date, hour: int) -> int:
+    return int(at_midnight(day).timestamp() * 1000) + (hour * 3600 * 1000)


 def build_path(
@@ -25,3 +37,34 @@ def tag_label(row: Dict) -> str:
     if not tag_value:
         return tag_name
     return f"{tag_name}:{tag_value}"
+
+
+def _time_filter_payload(start_time_ms: int, end_time_ms: int) -> dict:
+    return {
+        "filter_by": {
+            "query_start_time_range": {
+                "end_time_ms": end_time_ms,
+                "start_time_ms": start_time_ms,
+            }
+        }
+    }
+
+
+def hourly_time_filters(time_filter: Optional[TimeFilter]) -> Iterable[dict]:
+    """time filters to retrieve Databricks' queries: 1h duration each"""
+    # define an explicit time window
+    if not time_filter:
+        time_filter = TimeFilter.default()
+
+    assert time_filter  # for mypy
+
+    hour_min = time_filter.hour_min
+    hour_max = time_filter.hour_max
+    day = time_filter.day
+    if hour_min is None or hour_max is None:  # fallback to an entire day
+        hour_min, hour_max = _DEFAULT_HOUR_MIN, _DEFAULT_HOUR_MAX
+
+    for index in range(hour_min, min(hour_max + 1, _NUM_HOURS_IN_A_DAY)):
+        start_time_ms = _day_hour_to_epoch_ms(day, index)
+        end_time_ms = _day_hour_to_epoch_ms(day, index + 1)
+        yield _time_filter_payload(start_time_ms, end_time_ms)
```
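A quick sanity check of the epoch arithmetic in `_day_hour_to_epoch_ms`, assuming `at_midnight` anchors the day at 00:00 UTC; this reproduces the value asserted by the updated test below:

```python
from datetime import datetime, timezone

# 2023-02-14 00:00:00 UTC in epoch milliseconds, plus 14 hours of milliseconds
midnight_ms = int(datetime(2023, 2, 14, tzinfo=timezone.utc).timestamp() * 1000)
assert midnight_ms + 14 * 3600 * 1000 == 1676383200000
```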
castor_extractor/warehouse/databricks/utils_test.py:

```diff
@@ -1,4 +1,14 @@
-from
+from datetime import date
+
+from freezegun import freeze_time
+
+from ..abstract import TimeFilter
+from .utils import (
+    _day_hour_to_epoch_ms,
+    build_path,
+    hourly_time_filters,
+    tag_label,
+)


 def test_build_path():
@@ -23,3 +33,50 @@ def test_tag_label():
         "tag_value": "fou",
     }
     assert tag_label(row) == "fi:fou"
+
+
+def test__day_hour_to_epoch_ms():
+    assert _day_hour_to_epoch_ms(date(2023, 2, 14), 14) == 1676383200000
+
+
+@freeze_time("2023-7-4")
+def test_hourly_time_filters():
+    # default is yesterday
+    default_filters = [f for f in hourly_time_filters(None)]
+
+    assert len(default_filters) == 24  # number of hours in a day
+
+    first = default_filters[0]
+    start = first["filter_by"]["query_start_time_range"]["start_time_ms"]
+    last = default_filters[-1]
+    end = last["filter_by"]["query_start_time_range"]["end_time_ms"]
+    assert start == 1688342400000  # July 3, 2023 12:00:00 AM GMT
+    assert end == 1688428800000  # July 4, 2023 12:00:00 AM GMT
+
+    # custom time (from execution_date in DAG for example)
+    time_filter = TimeFilter(day=date(2020, 10, 15))
+    custom_filters = [f for f in hourly_time_filters(time_filter)]
+
+    assert len(custom_filters) == 24
+
+    first = custom_filters[0]
+    start = first["filter_by"]["query_start_time_range"]["start_time_ms"]
+    last = custom_filters[-1]
+    end = last["filter_by"]["query_start_time_range"]["end_time_ms"]
+    assert start == 1602720000000  # Oct 15, 2020 12:00:00 AM
+    assert end == 1602806400000  # Oct 16, 2020 12:00:00 AM
+
+    # hourly extraction: note that hour_min == hour_max
+    hourly = TimeFilter(day=date(2023, 4, 14), hour_min=4, hour_max=4)
+    hourly_filters = [f for f in hourly_time_filters(hourly)]
+    expected_hourly = [
+        {
+            "filter_by": {
+                "query_start_time_range": {
+                    "end_time_ms": 1681448400000,  # April 14, 2023 5:00:00 AM
+                    "start_time_ms": 1681444800000,  # April 14, 2023 4:00:00 AM
+                }
+            }
+        }
+    ]
+    assert hourly_filters == expected_hourly
```
castor_extractor/warehouse/salesforce/client.py:

```diff
@@ -1,14 +1,16 @@
 import logging
-from
+from functools import partial
+from typing import Dict, List, Optional, Tuple

 from tqdm import tqdm  # type: ignore

+from ...utils import fetch_all_pages
 from ...utils.salesforce import SalesforceBaseClient, SalesforceCredentials
 from .format import SalesforceFormatter
+from .pagination import SalesforceSQLPagination, format_sobject_query
 from .soql import (
     DESCRIPTION_QUERY_TPL,
     SOBJECT_FIELDS_QUERY_TPL,
-    SOBJECTS_QUERY_TPL,
 )

 logger = logging.getLogger(__name__)
@@ -19,9 +21,6 @@ class SalesforceClient(SalesforceBaseClient):
     Salesforce API client to extract sobjects
     """

-    # Implicit (hard-coded in Salesforce) limitation when using SOQL of 2,000 rows
-    LIMIT_RECORDS_PER_PAGE = 2000
-
     def __init__(self, credentials: SalesforceCredentials):
         super().__init__(credentials)
         self.formatter = SalesforceFormatter()
@@ -30,74 +29,28 @@ class SalesforceClient(SalesforceBaseClient):
     def name() -> str:
         return "Salesforce"

-    def _format_query(self, query_template: str, start_durable_id: str) -> str:
-        return query_template.format(
-            start_durable_id=start_durable_id,
-            limit=self.LIMIT_RECORDS_PER_PAGE,
-        )
-
-    def _next_records(
-        self, url: str, query_template: str, start_durable_id: str = "0000"
-    ) -> List[dict]:
-        query = self._format_query(
-            query_template, start_durable_id=start_durable_id
-        )
-        records, _ = self._call(
-            url, params={"q": query}, processor=self._query_processor
-        )
-        return records
-
-    def _is_last_page(self, records: List[dict]) -> bool:
-        return len(records) < self.LIMIT_RECORDS_PER_PAGE
-
-    def _should_query_next_page(
-        self, records: List[dict], page_number: int
-    ) -> bool:
-        return not (
-            self._is_last_page(records)
-            or self._has_reached_pagination_limit(page_number)
-        )
-
-    def _query_all(self, query_template: str) -> Iterator[dict]:
-        """
-        Run a SOQL query over salesforce API
-
-        Note, pagination is performed via a LIMIT in the SOQL query and requires
-        that ids are sorted. The SOQL query must support `limit` and
-        `start_durable_id` as parameters.
-        """
-        url = self.query_url
-        logger.info("querying page 0")
-        records = self._next_records(url, query_template)
-        yield from records
-
-        page_count = 1
-        while self._should_query_next_page(records, page_count):
-            logger.info(f"querying page {page_count}")
-            last_durable_id = records[-1]["DurableId"]
-            records = self._next_records(
-                url, query_template, start_durable_id=last_durable_id
-            )
-            yield from records
-            page_count += 1
-
     def fetch_sobjects(self) -> List[dict]:
         """Fetch all sobjects"""
         logger.info("Extracting sobjects")
-
+        query = format_sobject_query()
+        request_ = partial(
+            self._get, endpoint=self.query_endpoint, params={"q": query}
+        )
+        results = fetch_all_pages(request_, SalesforceSQLPagination)
+        return list(results)

     def fetch_fields(self, sobject_name: str) -> List[dict]:
         """Fetches fields of a given sobject"""
         query = SOBJECT_FIELDS_QUERY_TPL.format(
             entity_definition_id=sobject_name
         )
-        response = self.
+        response = self._get(self.tooling_endpoint, params={"q": query})
         return response["records"]

     def fetch_description(self, table_name: str) -> Optional[str]:
         """Retrieve description of a table"""
         query = DESCRIPTION_QUERY_TPL.format(table_name=table_name)
-        response = self.
+        response = self._get(self.tooling_endpoint, params={"q": query})
         if not response["records"]:
             return None
         return response["records"][0]["Description"]
```
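The hand-rolled SOQL pagination (a LIMIT plus a `start_durable_id` cursor) is replaced by `fetch_all_pages` with a `SalesforceSQLPagination` model defined in the new `pagination.py`, which this excerpt does not show. Purely as a hedged sketch, this is what such a model could look like under the `is_last` / `next_page_payload` / `page_results` contract seen in the Databricks pagination model, using the standard Salesforce query-response keys (`records`, `done`, `nextRecordsUrl`):

```python
from typing import List, Optional

from pydantic import BaseModel, Field

class SalesforceSQLPaginationSketch(BaseModel):
    """Sketch only: not the model shipped in castor_extractor."""

    records: List[dict] = Field(default_factory=list)
    done: bool = True
    nextRecordsUrl: Optional[str] = None

    def is_last(self) -> bool:
        return self.done or not self.nextRecordsUrl

    def next_page_payload(self) -> dict:
        # Assumption: the follow-up request is driven by the `nextRecordsUrl` value.
        return {"endpoint": self.nextRecordsUrl}

    def page_results(self) -> list:
        return self.records
```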