castor-extractor 0.18.5__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of castor-extractor might be problematic.
- CHANGELOG.md +48 -1
- castor_extractor/commands/extract_looker.py +3 -3
- castor_extractor/commands/extract_metabase_api.py +1 -1
- castor_extractor/commands/extract_metabase_db.py +1 -1
- castor_extractor/commands/extract_notion.py +16 -0
- castor_extractor/commands/file_check.py +5 -2
- castor_extractor/commands/upload.py +5 -3
- castor_extractor/knowledge/__init__.py +0 -0
- castor_extractor/knowledge/notion/__init__.py +3 -0
- castor_extractor/knowledge/notion/assets.py +9 -0
- castor_extractor/knowledge/notion/client/__init__.py +2 -0
- castor_extractor/knowledge/notion/client/client.py +145 -0
- castor_extractor/knowledge/notion/client/client_test.py +67 -0
- castor_extractor/knowledge/notion/client/constants.py +3 -0
- castor_extractor/knowledge/notion/client/credentials.py +16 -0
- castor_extractor/knowledge/notion/client/endpoints.py +18 -0
- castor_extractor/knowledge/notion/client/pagination.py +16 -0
- castor_extractor/knowledge/notion/extract.py +59 -0
- castor_extractor/quality/__init__.py +0 -0
- castor_extractor/quality/soda/__init__.py +2 -0
- castor_extractor/quality/soda/assets.py +8 -0
- castor_extractor/quality/soda/client/__init__.py +1 -0
- castor_extractor/quality/soda/client/client.py +99 -0
- castor_extractor/quality/soda/client/credentials.py +28 -0
- castor_extractor/quality/soda/client/endpoints.py +13 -0
- castor_extractor/types.py +1 -3
- castor_extractor/uploader/upload.py +0 -1
- castor_extractor/utils/__init__.py +2 -0
- castor_extractor/utils/argument_parser_test.py +0 -1
- castor_extractor/utils/client/api.py +29 -11
- castor_extractor/utils/client/api_test.py +9 -1
- castor_extractor/utils/object_test.py +1 -1
- castor_extractor/utils/pager/pager.py +1 -1
- castor_extractor/utils/pager/pager_on_id.py +11 -6
- castor_extractor/utils/safe_request.py +5 -3
- castor_extractor/utils/safe_request_test.py +1 -3
- castor_extractor/utils/string_test.py +1 -1
- castor_extractor/utils/time.py +11 -0
- castor_extractor/visualization/domo/client/client.py +2 -3
- castor_extractor/visualization/looker/api/client.py +35 -0
- castor_extractor/visualization/looker/api/extraction_parameters.py +2 -1
- castor_extractor/visualization/looker/extract.py +2 -2
- castor_extractor/visualization/metabase/assets.py +3 -1
- castor_extractor/visualization/metabase/extract.py +20 -8
- castor_extractor/visualization/mode/client/client.py +1 -1
- castor_extractor/visualization/powerbi/client/constants.py +1 -1
- castor_extractor/visualization/powerbi/client/rest.py +5 -15
- castor_extractor/visualization/qlik/client/engine/client.py +36 -5
- castor_extractor/visualization/qlik/client/engine/constants.py +1 -0
- castor_extractor/visualization/qlik/client/engine/error.py +18 -1
- castor_extractor/visualization/salesforce_reporting/client/soql.py +3 -1
- castor_extractor/visualization/tableau/extract.py +40 -16
- castor_extractor/visualization/tableau_revamp/client/client.py +2 -5
- castor_extractor/visualization/tableau_revamp/extract.py +3 -2
- castor_extractor/warehouse/bigquery/client.py +41 -6
- castor_extractor/warehouse/bigquery/extract.py +1 -0
- castor_extractor/warehouse/bigquery/query.py +23 -9
- castor_extractor/warehouse/bigquery/types.py +1 -2
- castor_extractor/warehouse/databricks/client.py +54 -35
- castor_extractor/warehouse/databricks/client_test.py +44 -31
- castor_extractor/warehouse/salesforce/client.py +28 -3
- castor_extractor/warehouse/salesforce/format.py +1 -1
- castor_extractor/warehouse/salesforce/format_test.py +1 -2
- castor_extractor/warehouse/salesforce/soql.py +6 -1
- {castor_extractor-0.18.5.dist-info → castor_extractor-0.19.0.dist-info}/METADATA +4 -4
- {castor_extractor-0.18.5.dist-info → castor_extractor-0.19.0.dist-info}/RECORD +69 -50
- {castor_extractor-0.18.5.dist-info → castor_extractor-0.19.0.dist-info}/entry_points.txt +1 -0
- {castor_extractor-0.18.5.dist-info → castor_extractor-0.19.0.dist-info}/LICENCE +0 -0
- {castor_extractor-0.18.5.dist-info → castor_extractor-0.19.0.dist-info}/WHEEL +0 -0
castor_extractor/visualization/tableau/extract.py
@@ -26,46 +26,70 @@ def iterate_all_data(
     yield TableauAsset.USER, deep_serialize(client.fetch(TableauAsset.USER))

     logger.info("Extracting WORKBOOK from Tableau API")
-    yield
-
+    yield (
+        TableauAsset.WORKBOOK,
+        deep_serialize(
+            client.fetch(TableauAsset.WORKBOOK),
+        ),
     )

     logger.info("Extracting DASHBOARD from Tableau API")
-    yield
-
+    yield (
+        TableauAsset.DASHBOARD,
+        deep_serialize(
+            client.fetch(TableauAsset.DASHBOARD),
+        ),
     )

     logger.info("Extracting PUBLISHED DATASOURCE from Tableau API")
-    yield
-
+    yield (
+        TableauAsset.PUBLISHED_DATASOURCE,
+        deep_serialize(
+            client.fetch(TableauAsset.PUBLISHED_DATASOURCE),
+        ),
     )

     logger.info("Extracting PROJECT from Tableau API")
-    yield
-
+    yield (
+        TableauAsset.PROJECT,
+        deep_serialize(
+            client.fetch(TableauAsset.PROJECT),
+        ),
     )

     logger.info("Extracting USAGE from Tableau API")
     yield TableauAsset.USAGE, deep_serialize(client.fetch(TableauAsset.USAGE))

     logger.info("Extracting WORKBOOK_TO_DATASOURCE from Tableau API")
-    yield
-
+    yield (
+        TableauAsset.WORKBOOK_TO_DATASOURCE,
+        deep_serialize(
+            client.fetch(TableauAsset.WORKBOOK_TO_DATASOURCE),
+        ),
     )

     logger.info("Extracting DATASOURCE from Tableau API")
-    yield
-
+    yield (
+        TableauAsset.DATASOURCE,
+        deep_serialize(
+            client.fetch(TableauAsset.DATASOURCE),
+        ),
     )

     logger.info("Extracting CUSTOM_SQL_TABLE from Tableau API")
-    yield
-
+    yield (
+        TableauAsset.CUSTOM_SQL_TABLE,
+        deep_serialize(
+            client.fetch(TableauAsset.CUSTOM_SQL_TABLE),
+        ),
     )

     logger.info("Extracting CUSTOM_SQL_QUERY from Tableau API")
-    yield
-
+    yield (
+        TableauAsset.CUSTOM_SQL_QUERY,
+        deep_serialize(
+            client.fetch(TableauAsset.CUSTOM_SQL_QUERY),
+        ),
     )

     logger.info("Extracting FIELD from Tableau API")
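The hunk above only reflows each yield into an explicit parenthesised tuple. A minimal, self-contained sketch of the (asset, payload) generator pattern it follows; every name here is a stand-in for illustration, not the package's API:

```python
from enum import Enum
from typing import Iterator, List, Tuple


class Asset(Enum):  # hypothetical stand-in for TableauAsset
    USER = "user"
    WORKBOOK = "workbook"


def iterate_all_data(fetch) -> Iterator[Tuple[Asset, List[dict]]]:
    # one-line form, kept when it fits on a single line
    yield Asset.USER, fetch(Asset.USER)
    # multi-line tuple form, as introduced by the reformatting above
    yield (
        Asset.WORKBOOK,
        fetch(Asset.WORKBOOK),
    )


for asset, payload in iterate_all_data(lambda a: [{"asset": a.value}]):
    print(asset.name, payload)
```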
castor_extractor/visualization/tableau_revamp/client/client.py
@@ -210,7 +210,6 @@ class TableauRevampClient:
         self,
         asset: TableauRevampAsset,
     ) -> SerializedAsset:
-
         if asset == TableauRevampAsset.DATASOURCE:
             data = TSC.Pager(self._server.datasources)

@@ -280,13 +279,11 @@ class TableauRevampClient:

         return _enrich_workbooks_with_tsc(workbooks, workbook_projects)

-    def fetch(
-        self,
-        asset: TableauRevampAsset,
-    ) -> SerializedAsset:
+    def fetch(self, asset: TableauRevampAsset) -> SerializedAsset:
         """
         Extract the given Tableau Asset
         """
+
         if asset == TableauRevampAsset.DATASOURCE:
             # both APIs are required to extract datasources
             return self._fetch_datasources()
castor_extractor/visualization/tableau_revamp/extract.py
@@ -23,8 +23,9 @@ def iterate_all_data(
     """Iterate over the extracted Data from Tableau"""

     logger.info("Extracting USER from Tableau API")
-    yield
-
+    yield (
+        TableauRevampAsset.USER,
+        deep_serialize(client.fetch(TableauRevampAsset.USER)),
     )

castor_extractor/warehouse/bigquery/client.py
@@ -1,13 +1,14 @@
+import itertools
 import logging
-from typing import List, Optional, Set
+from typing import List, Optional, Set

-from google.api_core.exceptions import Forbidden
-from google.api_core.page_iterator import Iterator as PageIterator
+from google.api_core.exceptions import Forbidden  # type: ignore
 from google.cloud.bigquery import Client as GoogleCloudClient  # type: ignore
 from google.cloud.bigquery.dataset import Dataset  # type: ignore
 from google.oauth2.service_account import Credentials  # type: ignore

 from ...utils import SqlalchemyClient, retry
+from .types import SetTwoString

 logger = logging.getLogger(__name__)

@@ -117,16 +118,50 @@ class BigQueryClient(SqlalchemyClient):
         ]
         return self._projects

-    def get_regions(self) ->
+    def get_regions(self) -> SetTwoString:
         """
-        Returns
+        Returns (project_id, region) available for the given GCP client
+        - Loops trough projects -> datasets -> region
+        - Returns distinct values
+        Example:
+            project_A
+              -> dataset_1:US
+            project_B
+              -> empty
+            project_C
+              -> dataset_2:EU
+              -> dataset_3:EU
+        Will return:
+            { (p_A, US), (p_C, EU) }
         """
         return {
             (ds.project, ds._properties["location"])
             for ds in self._list_datasets()
         }

-    def
+    def get_extended_regions(self) -> SetTwoString:
+        """
+        Returns all combinations of (project_id, region) for the given client
+        - Fetch all projects
+        - Fetch all regions (cross projects)
+        - Returns a combination of the two lists
+        Example:
+            project_A
+              -> dataset_1:US
+            project_B
+              -> empty
+            project_C
+              -> dataset_2:EU
+              -> dataset_3:EU
+        Will return:
+            { (p_A, EU), (p_A, US), (p_B, EU), (p_B, US), (p_C, EU), (p_C, US) }
+        """
+        projects = self.get_projects()
+        regions = {ds._properties["location"] for ds in self._list_datasets()}
+        combinations = itertools.product(projects, regions)
+        return set(combinations)
+
+    def get_datasets(self) -> SetTwoString:
         """
         Returns distinct (project_id, dataset_id) available for the given GCP client
         """
castor_extractor/warehouse/bigquery/query.py
@@ -2,18 +2,16 @@ import logging
 from typing import List, Optional

 from ..abstract import (
-    QUERIES_DIR,
     AbstractQueryBuilder,
     ExtractionQuery,
     TimeFilter,
     WarehouseAsset,
 )
-
-# Those queries must be formatted with {region}
-from .types import IterTwoString
+from .types import SetTwoString

 logger = logging.getLogger(__name__)

+# Those queries must be formatted with {region}
 REGION_REQUIRED = (
     WarehouseAsset.COLUMN,
     WarehouseAsset.DATABASE,
@@ -23,6 +21,11 @@ REGION_REQUIRED = (
     WarehouseAsset.USER,
 )

+# Some clients use empty projects (no datasets) to run their SQL queries
+# The extended regions is a combination of all regions with all projects
+# It allows to extract those queries which were left apart before
+EXTENDED_REGION_REQUIRED = (WarehouseAsset.QUERY,)
+
 # Those queries must be formatted with {dataset}
 DATASET_REQUIRED = (WarehouseAsset.VIEW_DDL,)

@@ -38,7 +41,7 @@ SHARDED_ASSETS = (WarehouseAsset.TABLE, WarehouseAsset.COLUMN)
 SHARDED_FILE_PATH = "cte/sharded.sql"


-def _database_formatted(datasets:
+def _database_formatted(datasets: SetTwoString) -> str:
     databases = {db for _, db in datasets}
     if not databases:
         # when no datasets are provided condition should pass
@@ -55,10 +58,11 @@ class BigQueryQueryBuilder(AbstractQueryBuilder):

     def __init__(
         self,
-        regions:
-        datasets:
+        regions: SetTwoString,
+        datasets: SetTwoString,
         time_filter: Optional[TimeFilter] = None,
         sync_tags: Optional[bool] = False,
+        extended_regions: Optional[SetTwoString] = None,
     ):
         super().__init__(
             time_filter=time_filter,
@@ -67,6 +71,7 @@ class BigQueryQueryBuilder(AbstractQueryBuilder):
         self._regions = regions
         self._datasets = datasets
         self._sync_tags = sync_tags
+        self._extended_regions = extended_regions or regions

     @staticmethod
     def _format(query: ExtractionQuery, values: dict) -> ExtractionQuery:
@@ -97,6 +102,13 @@ class BigQueryQueryBuilder(AbstractQueryBuilder):
         sharded_statement = self._load_from_file(SHARDED_FILE_PATH)
         return statement.format(sharded_statement=sharded_statement)

+    def _get_regions(self, asset: WarehouseAsset) -> SetTwoString:
+        return (
+            self._extended_regions
+            if asset in EXTENDED_REGION_REQUIRED
+            else self._regions
+        )
+
     def build(self, asset: WarehouseAsset) -> List[ExtractionQuery]:
         """
         It would be easier to stitch data directly in the query statement (UNION ALL).
@@ -110,12 +122,14 @@ class BigQueryQueryBuilder(AbstractQueryBuilder):
         query = super().build_default(asset)

         if asset in REGION_REQUIRED:
+            regions = self._get_regions(asset)
+
             logger.info(
-                f"\tWill run queries with following region params: {
+                f"\tWill run queries with following region params: {regions}",
             )
             return [
                 self._format(query, {"project": project, "region": region})
-                for project, region in
+                for project, region in regions
             ]

         if asset in DATASET_REQUIRED:
castor_extractor/warehouse/databricks/client.py
@@ -4,7 +4,7 @@ from concurrent.futures import ThreadPoolExecutor
 from datetime import date
 from enum import Enum
 from functools import partial
-from typing import Any, Dict, List, Optional, Set, Tuple, cast
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, cast

 import requests
 from databricks import sql  # type: ignore
@@ -28,9 +28,13 @@ from .utils import build_path, tag_label

 logger = logging.getLogger(__name__)

-_DATABRICKS_CLIENT_TIMEOUT =
+_DATABRICKS_CLIENT_TIMEOUT = 90
+_DEFAULT_HOUR_MIN = 0
+_DEFAULT_HOUR_MAX = 23
 _MAX_NUMBER_OF_LINEAGE_ERRORS = 1000
+_MAX_NUMBER_OF_QUERY_ERRORS = 1000
 _MAX_THREADS = 10
+_NUM_HOURS_IN_A_DAY = 24
 _RETRY_ATTEMPTS = 3
 _RETRY_BASE_MS = 1000
 _RETRY_EXCEPTIONS = [
@@ -40,7 +44,8 @@ _WORKSPACE_ID_HEADER = "X-Databricks-Org-Id"

 _INFORMATION_SCHEMA_SQL = "SELECT * FROM system.information_schema"

-
+safe_lineage_params = SafeMode((BaseException,), _MAX_NUMBER_OF_LINEAGE_ERRORS)
+safe_query_params = SafeMode((BaseException,), _MAX_NUMBER_OF_QUERY_ERRORS)


 class TagEntity(Enum):
@@ -112,7 +117,7 @@ class DatabricksClient(APIClient):
         Execute a SQL query on Databricks system tables and return the results.
         https://docs.databricks.com/en/dev-tools/python-sql-connector.html

-
+        //!\\ credentials.http_path is required in order to run SQL queries
         """
         assert self._http_path, "HTTP_PATH is required to run SQL queries"
         with sql.connect(
@@ -261,7 +266,6 @@ class DatabricksClient(APIClient):
         table_tags = self._get_tags_mapping(TagEntity.TABLE)
         column_tags = self._get_tags_mapping(TagEntity.COLUMN)
         for schema in schemas:
-
             t_to_add, c_to_add = self._tables_columns_of_schema(
                 schema=schema,
                 table_tags=table_tags,
@@ -325,7 +329,7 @@ class DatabricksClient(APIClient):

         return list(filter(None, links))

-    @safe_mode(
+    @safe_mode(safe_lineage_params, lambda: [])
     @retry(
         exceptions=_RETRY_EXCEPTIONS,
         max_retries=_RETRY_ATTEMPTS,
@@ -421,7 +425,7 @@ class DatabricksClient(APIClient):

         return list(filter(None, links))

-    @safe_mode(
+    @safe_mode(safe_lineage_params, lambda: [])
     @retry(
         exceptions=_RETRY_EXCEPTIONS,
         max_retries=_RETRY_ATTEMPTS,
@@ -468,8 +472,20 @@ class DatabricksClient(APIClient):
         return self.formatter.format_lineage(deduplicated)

     @staticmethod
-    def
-
+    def _time_filter_payload(start_time_ms: int, end_time_ms: int) -> dict:
+        return {
+            "filter_by": {
+                "query_start_time_range": {
+                    "end_time_ms": end_time_ms,
+                    "start_time_ms": start_time_ms,
+                }
+            }
+        }
+
+    def _hourly_time_filters(
+        self, time_filter: Optional[TimeFilter]
+    ) -> Iterable[dict]:
+        """time filters to retrieve Databricks' queries: 1h duration each"""
         # define an explicit time window
         if not time_filter:
             time_filter = TimeFilter.default()
@@ -479,22 +495,13 @@ class DatabricksClient(APIClient):
         hour_min = time_filter.hour_min
         hour_max = time_filter.hour_max
         day = time_filter.day
-        if hour_min is
-
-            # note: in practice, hour_min == hour_max (hourly query ingestion)
-            end_time_ms = _day_hour_to_epoch_ms(day, hour_max + 1)
-        else:  # fallback to an extraction of the entire day
-            start_time_ms = _day_to_epoch_ms(day)
-            end_time_ms = _day_to_epoch_ms(date_after(day, 1))
+        if hour_min is None or hour_max is None:  # fallback to an entire day
+            hour_min, hour_max = _DEFAULT_HOUR_MIN, _DEFAULT_HOUR_MAX

-
-
-
-
-                    "start_time_ms": start_time_ms,
-                }
-            }
-        }
+        for index in range(hour_min, min(hour_max + 1, _NUM_HOURS_IN_A_DAY)):
+            start_time_ms = _day_hour_to_epoch_ms(day, index)
+            end_time_ms = _day_hour_to_epoch_ms(day, index + 1)
+            yield self._time_filter_payload(start_time_ms, end_time_ms)

     def query_payload(
         self,
@@ -507,10 +514,11 @@ class DatabricksClient(APIClient):
         if page_token:
             payload: Dict[str, Any] = {"page_token": page_token}
         else:
-            if time_range_filter:
-
-
-
+            if not time_range_filter:
+                # should never happen.
+                # `time_range_filter` optional to leverage functiontools.partial
+                raise ValueError("Time range not specified")
+            payload = {**time_range_filter}
         if max_results:
             payload["max_results"] = max_results
         return payload
@@ -532,18 +540,29 @@ class DatabricksClient(APIClient):
         content = self.get(path=path, payload=payload)
         return content if content else {}

-
-
-
-
+    @safe_mode(safe_query_params, lambda: [])
+    @retry(
+        exceptions=_RETRY_EXCEPTIONS,
+        max_retries=_RETRY_ATTEMPTS,
+        base_ms=_RETRY_BASE_MS,
+    )
+    def _queries(self, filter_: dict) -> List[dict]:
+        """helper to retrieve queries using a given time filter"""
         _time_filtered_scroll_queries = partial(
             self._scroll_queries,
-            time_range_filter=
+            time_range_filter=filter_,
         )
-
         # retrieve all queries using pagination
-
+        return PagerOnToken(_time_filtered_scroll_queries).all()
+
+    def queries(self, time_filter: Optional[TimeFilter] = None) -> List[dict]:
+        """get all queries, hour per hour"""
+        time_range_filters = self._hourly_time_filters(time_filter)

+        raw_queries = []
+        for _filter in time_range_filters:
+            hourly = self._queries(_filter)
+            raw_queries.extend(hourly)
         return self.formatter.format_query(raw_queries)

     def users(self) -> List[dict]:
castor_extractor/warehouse/databricks/client_test.py
@@ -4,7 +4,12 @@ from unittest.mock import Mock, patch
 from freezegun import freeze_time

 from ..abstract.time_filter import TimeFilter
-from .client import
+from .client import (
+    DatabricksClient,
+    DatabricksCredentials,
+    LineageLinks,
+    _day_hour_to_epoch_ms,
+)
 from .test_constants import (
     CLOSER_DATE,
     MOCK_TABLES_FOR_TABLE_LINEAGE,
@@ -18,45 +23,53 @@ def test__day_hour_to_epoch_ms():


 @freeze_time("2023-7-4")
-def
+def test_DatabricksClient__hourly_time_filters():
+    credentials = DatabricksCredentials(
+        host="carthago",
+        token="delenda",
+        http_host="est",
+    )
+    client = DatabricksClient(credentials)
+
     # default is yesterday
-
-
-
-
-
-
-
-
-
-
-    assert default_filter == expected_default
+    default_filters = [f for f in client._hourly_time_filters(None)]
+
+    assert len(default_filters) == 24  # number of hours in a day
+
+    first = default_filters[0]
+    start = first["filter_by"]["query_start_time_range"]["start_time_ms"]
+    last = default_filters[-1]
+    end = last["filter_by"]["query_start_time_range"]["end_time_ms"]
+    assert start == 1688342400000  # July 3, 2023 12:00:00 AM GMT
+    assert end == 1688428800000  # July 4, 2023 12:00:00 AM GMT

     # custom time (from execution_date in DAG for example)
     time_filter = TimeFilter(day=date(2020, 10, 15))
-
-
-
-
-
-
-
-
-
-    assert
+    custom_filters = [f for f in client._hourly_time_filters(time_filter)]
+
+    assert len(custom_filters) == 24
+
+    first = custom_filters[0]
+    start = first["filter_by"]["query_start_time_range"]["start_time_ms"]
+    last = custom_filters[-1]
+    end = last["filter_by"]["query_start_time_range"]["end_time_ms"]
+    assert start == 1602720000000  # Oct 15, 2020 12:00:00 AM
+    assert end == 1602806400000  # Oct 16, 2020 12:00:00 AM

     # hourly extraction: note that hour_min == hour_max
     hourly = TimeFilter(day=date(2023, 4, 14), hour_min=4, hour_max=4)
-
-    expected_hourly =
-
-    "
-    "
-
+    hourly_filters = [f for f in client._hourly_time_filters(hourly)]
+    expected_hourly = [
+        {
+            "filter_by": {
+                "query_start_time_range": {
+                    "end_time_ms": 1681448400000,  # April 14, 2023 5:00:00 AM
+                    "start_time_ms": 1681444800000,  # April 14, 2023 4:00:00 AM
+                }
             }
         }
-
-    assert
+    ]
+    assert hourly_filters == expected_hourly


 class MockDatabricksClient(DatabricksClient):
castor_extractor/warehouse/salesforce/client.py
@@ -1,11 +1,15 @@
 import logging
-from typing import Dict, Iterator, List, Tuple
+from typing import Dict, Iterator, List, Optional, Tuple

 from tqdm import tqdm  # type: ignore

 from ...utils.salesforce import SalesforceBaseClient, SalesforceCredentials
 from .format import SalesforceFormatter
-from .soql import
+from .soql import (
+    DESCRIPTION_QUERY_TPL,
+    SOBJECT_FIELDS_QUERY_TPL,
+    SOBJECTS_QUERY_TPL,
+)

 logger = logging.getLogger(__name__)

@@ -90,13 +94,34 @@ class SalesforceClient(SalesforceBaseClient):
         response = self._call(self.tooling_url, params={"q": query})
         return response["records"]

+    def fetch_description(self, table_name: str) -> Optional[str]:
+        """Retrieve description of a table"""
+        query = DESCRIPTION_QUERY_TPL.format(table_name=table_name)
+        response = self._call(self.tooling_url, params={"q": query})
+        if not response["records"]:
+            return None
+        return response["records"][0]["Description"]
+
+    def add_table_descriptions(self, sobjects: List[dict]) -> List[dict]:
+        """
+        Add table descriptions.
+        We use the tooling API which does not handle well the LIMIT in SOQL
+        so we have to retrieve descriptions individually
+        """
+        described_sobjects = []
+        for sobject in sobjects:
+            description = self.fetch_description(sobject["QualifiedApiName"])
+            described_sobjects.append({**sobject, "Description": description})
+        return described_sobjects
+
     def tables(self) -> List[dict]:
         """
         Get Salesforce sobjects as tables
         """
         sobjects = self.fetch_sobjects()
         logger.info(f"Extracted {len(sobjects)} sobjects")
-
+        described_sobjects = self.add_table_descriptions(sobjects)
+        return list(self.formatter.tables(described_sobjects))

     def columns(
         self, sobject_names: List[Tuple[str, str]], show_progress: bool = True
castor_extractor/warehouse/salesforce/format_test.py
@@ -19,7 +19,6 @@ def _example_sobjects() -> Tuple[Dict[str, str], ...]:


 def test__field_description():
-
     field = {}
     assert _field_description(field) == ""

@@ -59,7 +58,7 @@ def test__merge_label_and_api_name():
         "label": "foo",
         "schema_id": SCHEMA_NAME,
         "table_name": expected_name,
-        "description":
+        "description": None,
         "tags": [],
         "type": "TABLE",
     }
castor_extractor/warehouse/salesforce/soql.py
@@ -1,3 +1,9 @@
+DESCRIPTION_QUERY_TPL = """
+SELECT Description
+FROM EntityDefinition
+WHERE QualifiedApiName = '{table_name}'
+"""
+
 SOBJECTS_QUERY_TPL = """
 SELECT
     DeveloperName,
@@ -13,7 +19,6 @@ SOBJECTS_QUERY_TPL = """
 LIMIT {limit}
 """

-
 SOBJECT_FIELDS_QUERY_TPL = """
 SELECT
     DataType,