castor-extractor 0.18.5__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (69)
  1. CHANGELOG.md +48 -1
  2. castor_extractor/commands/extract_looker.py +3 -3
  3. castor_extractor/commands/extract_metabase_api.py +1 -1
  4. castor_extractor/commands/extract_metabase_db.py +1 -1
  5. castor_extractor/commands/extract_notion.py +16 -0
  6. castor_extractor/commands/file_check.py +5 -2
  7. castor_extractor/commands/upload.py +5 -3
  8. castor_extractor/knowledge/__init__.py +0 -0
  9. castor_extractor/knowledge/notion/__init__.py +3 -0
  10. castor_extractor/knowledge/notion/assets.py +9 -0
  11. castor_extractor/knowledge/notion/client/__init__.py +2 -0
  12. castor_extractor/knowledge/notion/client/client.py +145 -0
  13. castor_extractor/knowledge/notion/client/client_test.py +67 -0
  14. castor_extractor/knowledge/notion/client/constants.py +3 -0
  15. castor_extractor/knowledge/notion/client/credentials.py +16 -0
  16. castor_extractor/knowledge/notion/client/endpoints.py +18 -0
  17. castor_extractor/knowledge/notion/client/pagination.py +16 -0
  18. castor_extractor/knowledge/notion/extract.py +59 -0
  19. castor_extractor/quality/__init__.py +0 -0
  20. castor_extractor/quality/soda/__init__.py +2 -0
  21. castor_extractor/quality/soda/assets.py +8 -0
  22. castor_extractor/quality/soda/client/__init__.py +1 -0
  23. castor_extractor/quality/soda/client/client.py +99 -0
  24. castor_extractor/quality/soda/client/credentials.py +28 -0
  25. castor_extractor/quality/soda/client/endpoints.py +13 -0
  26. castor_extractor/types.py +1 -3
  27. castor_extractor/uploader/upload.py +0 -1
  28. castor_extractor/utils/__init__.py +2 -0
  29. castor_extractor/utils/argument_parser_test.py +0 -1
  30. castor_extractor/utils/client/api.py +29 -11
  31. castor_extractor/utils/client/api_test.py +9 -1
  32. castor_extractor/utils/object_test.py +1 -1
  33. castor_extractor/utils/pager/pager.py +1 -1
  34. castor_extractor/utils/pager/pager_on_id.py +11 -6
  35. castor_extractor/utils/safe_request.py +5 -3
  36. castor_extractor/utils/safe_request_test.py +1 -3
  37. castor_extractor/utils/string_test.py +1 -1
  38. castor_extractor/utils/time.py +11 -0
  39. castor_extractor/visualization/domo/client/client.py +2 -3
  40. castor_extractor/visualization/looker/api/client.py +35 -0
  41. castor_extractor/visualization/looker/api/extraction_parameters.py +2 -1
  42. castor_extractor/visualization/looker/extract.py +2 -2
  43. castor_extractor/visualization/metabase/assets.py +3 -1
  44. castor_extractor/visualization/metabase/extract.py +20 -8
  45. castor_extractor/visualization/mode/client/client.py +1 -1
  46. castor_extractor/visualization/powerbi/client/constants.py +1 -1
  47. castor_extractor/visualization/powerbi/client/rest.py +5 -15
  48. castor_extractor/visualization/qlik/client/engine/client.py +36 -5
  49. castor_extractor/visualization/qlik/client/engine/constants.py +1 -0
  50. castor_extractor/visualization/qlik/client/engine/error.py +18 -1
  51. castor_extractor/visualization/salesforce_reporting/client/soql.py +3 -1
  52. castor_extractor/visualization/tableau/extract.py +40 -16
  53. castor_extractor/visualization/tableau_revamp/client/client.py +2 -5
  54. castor_extractor/visualization/tableau_revamp/extract.py +3 -2
  55. castor_extractor/warehouse/bigquery/client.py +41 -6
  56. castor_extractor/warehouse/bigquery/extract.py +1 -0
  57. castor_extractor/warehouse/bigquery/query.py +23 -9
  58. castor_extractor/warehouse/bigquery/types.py +1 -2
  59. castor_extractor/warehouse/databricks/client.py +54 -35
  60. castor_extractor/warehouse/databricks/client_test.py +44 -31
  61. castor_extractor/warehouse/salesforce/client.py +28 -3
  62. castor_extractor/warehouse/salesforce/format.py +1 -1
  63. castor_extractor/warehouse/salesforce/format_test.py +1 -2
  64. castor_extractor/warehouse/salesforce/soql.py +6 -1
  65. {castor_extractor-0.18.5.dist-info → castor_extractor-0.19.0.dist-info}/METADATA +4 -4
  66. {castor_extractor-0.18.5.dist-info → castor_extractor-0.19.0.dist-info}/RECORD +69 -50
  67. {castor_extractor-0.18.5.dist-info → castor_extractor-0.19.0.dist-info}/entry_points.txt +1 -0
  68. {castor_extractor-0.18.5.dist-info → castor_extractor-0.19.0.dist-info}/LICENCE +0 -0
  69. {castor_extractor-0.18.5.dist-info → castor_extractor-0.19.0.dist-info}/WHEEL +0 -0

castor_extractor/visualization/tableau/extract.py
@@ -26,46 +26,70 @@ def iterate_all_data(
     yield TableauAsset.USER, deep_serialize(client.fetch(TableauAsset.USER))
 
     logger.info("Extracting WORKBOOK from Tableau API")
-    yield TableauAsset.WORKBOOK, deep_serialize(
-        client.fetch(TableauAsset.WORKBOOK),
+    yield (
+        TableauAsset.WORKBOOK,
+        deep_serialize(
+            client.fetch(TableauAsset.WORKBOOK),
+        ),
     )
 
     logger.info("Extracting DASHBOARD from Tableau API")
-    yield TableauAsset.DASHBOARD, deep_serialize(
-        client.fetch(TableauAsset.DASHBOARD),
+    yield (
+        TableauAsset.DASHBOARD,
+        deep_serialize(
+            client.fetch(TableauAsset.DASHBOARD),
+        ),
     )
 
     logger.info("Extracting PUBLISHED DATASOURCE from Tableau API")
-    yield TableauAsset.PUBLISHED_DATASOURCE, deep_serialize(
-        client.fetch(TableauAsset.PUBLISHED_DATASOURCE),
+    yield (
+        TableauAsset.PUBLISHED_DATASOURCE,
+        deep_serialize(
+            client.fetch(TableauAsset.PUBLISHED_DATASOURCE),
+        ),
     )
 
     logger.info("Extracting PROJECT from Tableau API")
-    yield TableauAsset.PROJECT, deep_serialize(
-        client.fetch(TableauAsset.PROJECT),
+    yield (
+        TableauAsset.PROJECT,
+        deep_serialize(
+            client.fetch(TableauAsset.PROJECT),
+        ),
     )
 
     logger.info("Extracting USAGE from Tableau API")
     yield TableauAsset.USAGE, deep_serialize(client.fetch(TableauAsset.USAGE))
 
     logger.info("Extracting WORKBOOK_TO_DATASOURCE from Tableau API")
-    yield TableauAsset.WORKBOOK_TO_DATASOURCE, deep_serialize(
-        client.fetch(TableauAsset.WORKBOOK_TO_DATASOURCE),
+    yield (
+        TableauAsset.WORKBOOK_TO_DATASOURCE,
+        deep_serialize(
+            client.fetch(TableauAsset.WORKBOOK_TO_DATASOURCE),
+        ),
     )
 
     logger.info("Extracting DATASOURCE from Tableau API")
-    yield TableauAsset.DATASOURCE, deep_serialize(
-        client.fetch(TableauAsset.DATASOURCE),
+    yield (
+        TableauAsset.DATASOURCE,
+        deep_serialize(
+            client.fetch(TableauAsset.DATASOURCE),
+        ),
     )
 
     logger.info("Extracting CUSTOM_SQL_TABLE from Tableau API")
-    yield TableauAsset.CUSTOM_SQL_TABLE, deep_serialize(
-        client.fetch(TableauAsset.CUSTOM_SQL_TABLE),
+    yield (
+        TableauAsset.CUSTOM_SQL_TABLE,
+        deep_serialize(
+            client.fetch(TableauAsset.CUSTOM_SQL_TABLE),
+        ),
     )
 
     logger.info("Extracting CUSTOM_SQL_QUERY from Tableau API")
-    yield TableauAsset.CUSTOM_SQL_QUERY, deep_serialize(
-        client.fetch(TableauAsset.CUSTOM_SQL_QUERY),
+    yield (
+        TableauAsset.CUSTOM_SQL_QUERY,
+        deep_serialize(
+            client.fetch(TableauAsset.CUSTOM_SQL_QUERY),
+        ),
     )
 
     logger.info("Extracting FIELD from Tableau API")

castor_extractor/visualization/tableau_revamp/client/client.py
@@ -210,7 +210,6 @@ class TableauRevampClient:
         self,
         asset: TableauRevampAsset,
     ) -> SerializedAsset:
-
         if asset == TableauRevampAsset.DATASOURCE:
             data = TSC.Pager(self._server.datasources)
 
@@ -280,13 +279,11 @@ class TableauRevampClient:
 
         return _enrich_workbooks_with_tsc(workbooks, workbook_projects)
 
-    def fetch(
-        self,
-        asset: TableauRevampAsset,
-    ) -> SerializedAsset:
+    def fetch(self, asset: TableauRevampAsset) -> SerializedAsset:
         """
         Extract the given Tableau Asset
         """
+
         if asset == TableauRevampAsset.DATASOURCE:
             # both APIs are required to extract datasources
             return self._fetch_datasources()

castor_extractor/visualization/tableau_revamp/extract.py
@@ -23,8 +23,9 @@ def iterate_all_data(
     """Iterate over the extracted Data from Tableau"""
 
     logger.info("Extracting USER from Tableau API")
-    yield TableauRevampAsset.USER, deep_serialize(
-        client.fetch(TableauRevampAsset.USER)
+    yield (
+        TableauRevampAsset.USER,
+        deep_serialize(client.fetch(TableauRevampAsset.USER)),
     )
 
 

castor_extractor/warehouse/bigquery/client.py
@@ -1,13 +1,14 @@
+import itertools
 import logging
-from typing import List, Optional, Set, Tuple
+from typing import List, Optional, Set
 
-from google.api_core.exceptions import Forbidden
-from google.api_core.page_iterator import Iterator as PageIterator
+from google.api_core.exceptions import Forbidden  # type: ignore
 from google.cloud.bigquery import Client as GoogleCloudClient  # type: ignore
 from google.cloud.bigquery.dataset import Dataset  # type: ignore
 from google.oauth2.service_account import Credentials  # type: ignore
 
 from ...utils import SqlalchemyClient, retry
+from .types import SetTwoString
 
 logger = logging.getLogger(__name__)
 
@@ -117,16 +118,50 @@ class BigQueryClient(SqlalchemyClient):
         ]
         return self._projects
 
-    def get_regions(self) -> Set[Tuple[str, str]]:
+    def get_regions(self) -> SetTwoString:
         """
-        Returns distinct (project_id, region) available for the given GCP client
+        Returns (project_id, region) available for the given GCP client
+        - Loops trough projects -> datasets -> region
+        - Returns distinct values
+        Example:
+            project_A
+                -> dataset_1:US
+            project_B
+                -> empty
+            project_C
+                -> dataset_2:EU
+                -> dataset_3:EU
+        Will return:
+            { (p_A, US), (p_C, EU) }
         """
         return {
             (ds.project, ds._properties["location"])
             for ds in self._list_datasets()
         }
 
-    def get_datasets(self) -> Set[Tuple[str, str]]:
+    def get_extended_regions(self) -> SetTwoString:
+        """
+        Returns all combinations of (project_id, region) for the given client
+        - Fetch all projects
+        - Fetch all regions (cross projects)
+        - Returns a combination of the two lists
+        Example:
+            project_A
+                -> dataset_1:US
+            project_B
+                -> empty
+            project_C
+                -> dataset_2:EU
+                -> dataset_3:EU
+        Will return:
+            { (p_A, EU), (p_A, US), (p_B, EU), (p_B, US), (p_C, EU), (p_C, US) }
+        """
+        projects = self.get_projects()
+        regions = {ds._properties["location"] for ds in self._list_datasets()}
+        combinations = itertools.product(projects, regions)
+        return set(combinations)
+
+    def get_datasets(self) -> SetTwoString:
         """
         Returns distinct (project_id, dataset_id) available for the given GCP client
         """

castor_extractor/warehouse/bigquery/extract.py
@@ -68,6 +68,7 @@ def extract_all(**kwargs) -> None:
     query_builder = BigQueryQueryBuilder(
         regions=client.get_regions(),
         datasets=client.get_datasets(),
+        extended_regions=client.get_extended_regions(),
     )
 
     storage = LocalStorage(directory=output_directory)

castor_extractor/warehouse/bigquery/query.py
@@ -2,18 +2,16 @@ import logging
 from typing import List, Optional
 
 from ..abstract import (
-    QUERIES_DIR,
     AbstractQueryBuilder,
     ExtractionQuery,
     TimeFilter,
     WarehouseAsset,
 )
-
-# Those queries must be formatted with {region}
-from .types import IterTwoString
+from .types import SetTwoString
 
 logger = logging.getLogger(__name__)
 
+# Those queries must be formatted with {region}
 REGION_REQUIRED = (
     WarehouseAsset.COLUMN,
     WarehouseAsset.DATABASE,
@@ -23,6 +21,11 @@ REGION_REQUIRED = (
     WarehouseAsset.USER,
 )
 
+# Some clients use empty projects (no datasets) to run their SQL queries
+# The extended regions is a combination of all regions with all projects
+# It allows to extract those queries which were left apart before
+EXTENDED_REGION_REQUIRED = (WarehouseAsset.QUERY,)
+
 # Those queries must be formatted with {dataset}
 DATASET_REQUIRED = (WarehouseAsset.VIEW_DDL,)
 
@@ -38,7 +41,7 @@ SHARDED_ASSETS = (WarehouseAsset.TABLE, WarehouseAsset.COLUMN)
 SHARDED_FILE_PATH = "cte/sharded.sql"
 
 
-def _database_formatted(datasets: IterTwoString) -> str:
+def _database_formatted(datasets: SetTwoString) -> str:
     databases = {db for _, db in datasets}
     if not databases:
         # when no datasets are provided condition should pass
@@ -55,10 +58,11 @@ class BigQueryQueryBuilder(AbstractQueryBuilder):
 
     def __init__(
         self,
-        regions: IterTwoString,
-        datasets: IterTwoString,
+        regions: SetTwoString,
+        datasets: SetTwoString,
         time_filter: Optional[TimeFilter] = None,
         sync_tags: Optional[bool] = False,
+        extended_regions: Optional[SetTwoString] = None,
     ):
         super().__init__(
             time_filter=time_filter,
@@ -67,6 +71,7 @@ class BigQueryQueryBuilder(AbstractQueryBuilder):
         self._regions = regions
         self._datasets = datasets
         self._sync_tags = sync_tags
+        self._extended_regions = extended_regions or regions
 
     @staticmethod
     def _format(query: ExtractionQuery, values: dict) -> ExtractionQuery:
@@ -97,6 +102,13 @@ class BigQueryQueryBuilder(AbstractQueryBuilder):
         sharded_statement = self._load_from_file(SHARDED_FILE_PATH)
         return statement.format(sharded_statement=sharded_statement)
 
+    def _get_regions(self, asset: WarehouseAsset) -> SetTwoString:
+        return (
+            self._extended_regions
+            if asset in EXTENDED_REGION_REQUIRED
+            else self._regions
+        )
+
     def build(self, asset: WarehouseAsset) -> List[ExtractionQuery]:
         """
         It would be easier to stitch data directly in the query statement (UNION ALL).
@@ -110,12 +122,14 @@ class BigQueryQueryBuilder(AbstractQueryBuilder):
         query = super().build_default(asset)
 
         if asset in REGION_REQUIRED:
+            regions = self._get_regions(asset)
+
             logger.info(
-                f"\tWill run queries with following region params: {self._regions}",
+                f"\tWill run queries with following region params: {regions}",
             )
             return [
                 self._format(query, {"project": project, "region": region})
-                for project, region in self._regions
+                for project, region in regions
            ]
 
        if asset in DATASET_REQUIRED:
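
As the new comment explains, only the QUERY asset uses the extended project/region combinations; every other region-scoped asset keeps the original list, and when extended_regions is not passed the builder falls back to regions. A rough sketch of that selection rule, with simplified stand-ins for the real classes (Asset and Builder below are illustrative, not the package's types):

from enum import Enum
from typing import Optional, Set, Tuple

SetTwoString = Set[Tuple[str, str]]


class Asset(Enum):  # stand-in for WarehouseAsset
    QUERY = "query"
    COLUMN = "column"


EXTENDED_REGION_REQUIRED = (Asset.QUERY,)


class Builder:  # stand-in for BigQueryQueryBuilder's region handling
    def __init__(
        self,
        regions: SetTwoString,
        extended_regions: Optional[SetTwoString] = None,
    ):
        self._regions = regions
        self._extended_regions = extended_regions or regions  # fallback

    def _get_regions(self, asset: Asset) -> SetTwoString:
        # QUERY gets the project x region combinations, everything else the observed pairs
        if asset in EXTENDED_REGION_REQUIRED:
            return self._extended_regions
        return self._regions


builder = Builder(
    regions={("project_A", "US")},
    extended_regions={("project_A", "US"), ("project_B", "US")},
)
assert builder._get_regions(Asset.QUERY) == {("project_A", "US"), ("project_B", "US")}
assert builder._get_regions(Asset.COLUMN) == {("project_A", "US")}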

castor_extractor/warehouse/bigquery/types.py
@@ -1,5 +1,4 @@
-from typing import Iterable, Set, Tuple
+from typing import Set, Tuple
 
 SetString = Set[str]
 SetTwoString = Set[Tuple[str, str]]
-IterTwoString = Iterable[Tuple[str, str]]

castor_extractor/warehouse/databricks/client.py
@@ -4,7 +4,7 @@ from concurrent.futures import ThreadPoolExecutor
 from datetime import date
 from enum import Enum
 from functools import partial
-from typing import Any, Dict, List, Optional, Set, Tuple, cast
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, cast
 
 import requests
 from databricks import sql  # type: ignore
@@ -28,9 +28,13 @@ from .utils import build_path, tag_label
 
 logger = logging.getLogger(__name__)
 
-_DATABRICKS_CLIENT_TIMEOUT = 60
+_DATABRICKS_CLIENT_TIMEOUT = 90
+_DEFAULT_HOUR_MIN = 0
+_DEFAULT_HOUR_MAX = 23
 _MAX_NUMBER_OF_LINEAGE_ERRORS = 1000
+_MAX_NUMBER_OF_QUERY_ERRORS = 1000
 _MAX_THREADS = 10
+_NUM_HOURS_IN_A_DAY = 24
 _RETRY_ATTEMPTS = 3
 _RETRY_BASE_MS = 1000
 _RETRY_EXCEPTIONS = [
@@ -40,7 +44,8 @@ _WORKSPACE_ID_HEADER = "X-Databricks-Org-Id"
 
 _INFORMATION_SCHEMA_SQL = "SELECT * FROM system.information_schema"
 
-safe_params = SafeMode((BaseException,), _MAX_NUMBER_OF_LINEAGE_ERRORS)
+safe_lineage_params = SafeMode((BaseException,), _MAX_NUMBER_OF_LINEAGE_ERRORS)
+safe_query_params = SafeMode((BaseException,), _MAX_NUMBER_OF_QUERY_ERRORS)
 
 
 class TagEntity(Enum):
@@ -112,7 +117,7 @@ class DatabricksClient(APIClient):
         Execute a SQL query on Databricks system tables and return the results.
         https://docs.databricks.com/en/dev-tools/python-sql-connector.html
 
-        /!\ credentials.http_path is required in order to run SQL queries
+        //!\\ credentials.http_path is required in order to run SQL queries
         """
         assert self._http_path, "HTTP_PATH is required to run SQL queries"
         with sql.connect(
@@ -261,7 +266,6 @@ class DatabricksClient(APIClient):
         table_tags = self._get_tags_mapping(TagEntity.TABLE)
         column_tags = self._get_tags_mapping(TagEntity.COLUMN)
         for schema in schemas:
-
             t_to_add, c_to_add = self._tables_columns_of_schema(
                 schema=schema,
                 table_tags=table_tags,
@@ -325,7 +329,7 @@ class DatabricksClient(APIClient):
 
         return list(filter(None, links))
 
-    @safe_mode(safe_params, lambda: [])
+    @safe_mode(safe_lineage_params, lambda: [])
     @retry(
         exceptions=_RETRY_EXCEPTIONS,
         max_retries=_RETRY_ATTEMPTS,
@@ -421,7 +425,7 @@ class DatabricksClient(APIClient):
 
         return list(filter(None, links))
 
-    @safe_mode(safe_params, lambda: [])
+    @safe_mode(safe_lineage_params, lambda: [])
     @retry(
         exceptions=_RETRY_EXCEPTIONS,
         max_retries=_RETRY_ATTEMPTS,
@@ -468,8 +472,20 @@ class DatabricksClient(APIClient):
         return self.formatter.format_lineage(deduplicated)
 
     @staticmethod
-    def _time_filter(time_filter: Optional[TimeFilter]) -> dict:
-        """time filter to retrieve Databricks' queries"""
+    def _time_filter_payload(start_time_ms: int, end_time_ms: int) -> dict:
+        return {
+            "filter_by": {
+                "query_start_time_range": {
+                    "end_time_ms": end_time_ms,
+                    "start_time_ms": start_time_ms,
+                }
+            }
+        }
+
+    def _hourly_time_filters(
+        self, time_filter: Optional[TimeFilter]
+    ) -> Iterable[dict]:
+        """time filters to retrieve Databricks' queries: 1h duration each"""
         # define an explicit time window
         if not time_filter:
             time_filter = TimeFilter.default()
@@ -479,22 +495,13 @@
         hour_min = time_filter.hour_min
         hour_max = time_filter.hour_max
         day = time_filter.day
-        if hour_min is not None and hour_max is not None:  # specific window
-            start_time_ms = _day_hour_to_epoch_ms(day, hour_min)
-            # note: in practice, hour_min == hour_max (hourly query ingestion)
-            end_time_ms = _day_hour_to_epoch_ms(day, hour_max + 1)
-        else:  # fallback to an extraction of the entire day
-            start_time_ms = _day_to_epoch_ms(day)
-            end_time_ms = _day_to_epoch_ms(date_after(day, 1))
+        if hour_min is None or hour_max is None:  # fallback to an entire day
+            hour_min, hour_max = _DEFAULT_HOUR_MIN, _DEFAULT_HOUR_MAX
 
-        return {
-            "filter_by": {
-                "query_start_time_range": {
-                    "end_time_ms": end_time_ms,
-                    "start_time_ms": start_time_ms,
-                }
-            }
-        }
+        for index in range(hour_min, min(hour_max + 1, _NUM_HOURS_IN_A_DAY)):
+            start_time_ms = _day_hour_to_epoch_ms(day, index)
+            end_time_ms = _day_hour_to_epoch_ms(day, index + 1)
+            yield self._time_filter_payload(start_time_ms, end_time_ms)
 
     def query_payload(
         self,
@@ -507,10 +514,11 @@
         if page_token:
             payload: Dict[str, Any] = {"page_token": page_token}
         else:
-            if time_range_filter:
-                payload = {**time_range_filter}
-            else:
-                payload = self._time_filter(None)  # default to yesterday
+            if not time_range_filter:
+                # should never happen.
+                # `time_range_filter` optional to leverage functiontools.partial
+                raise ValueError("Time range not specified")
+            payload = {**time_range_filter}
         if max_results:
             payload["max_results"] = max_results
         return payload
@@ -532,18 +540,29 @@
         content = self.get(path=path, payload=payload)
         return content if content else {}
 
-    def queries(self, time_filter: Optional[TimeFilter] = None) -> List[dict]:
-        """get all queries"""
-        # add a time filter (by default: yesterday)
-        time_range_filter = self._time_filter(time_filter)
+    @safe_mode(safe_query_params, lambda: [])
+    @retry(
+        exceptions=_RETRY_EXCEPTIONS,
+        max_retries=_RETRY_ATTEMPTS,
+        base_ms=_RETRY_BASE_MS,
+    )
+    def _queries(self, filter_: dict) -> List[dict]:
+        """helper to retrieve queries using a given time filter"""
         _time_filtered_scroll_queries = partial(
             self._scroll_queries,
-            time_range_filter=time_range_filter,
+            time_range_filter=filter_,
         )
-
         # retrieve all queries using pagination
-        raw_queries = PagerOnToken(_time_filtered_scroll_queries).all()
+        return PagerOnToken(_time_filtered_scroll_queries).all()
+
+    def queries(self, time_filter: Optional[TimeFilter] = None) -> List[dict]:
+        """get all queries, hour per hour"""
+        time_range_filters = self._hourly_time_filters(time_filter)
 
+        raw_queries = []
+        for _filter in time_range_filters:
+            hourly = self._queries(_filter)
+            raw_queries.extend(hourly)
         return self.formatter.format_query(raw_queries)
 
     def users(self) -> List[dict]:
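
The new _hourly_time_filters replaces a single day-wide filter with one payload per hour, so each Databricks query-history request stays small and the safe/retry decorators isolate failures to a single hour. A standalone sketch of the windowing arithmetic, with epoch-millisecond boundaries computed from the standard library (the helper names below are illustrative, not the package's):

from datetime import date, datetime, timezone
from typing import Iterator


def day_hour_to_epoch_ms(day: date, hour: int) -> int:
    # hour may be 24: that is simply midnight of the following day in epoch ms
    midnight = datetime(day.year, day.month, day.day, tzinfo=timezone.utc)
    return int(midnight.timestamp() * 1000) + hour * 3_600_000


def hourly_windows(day: date, hour_min: int = 0, hour_max: int = 23) -> Iterator[dict]:
    # one {start, end} pair per hour, mirroring the filter_by payload shape
    for hour in range(hour_min, min(hour_max + 1, 24)):
        yield {
            "filter_by": {
                "query_start_time_range": {
                    "start_time_ms": day_hour_to_epoch_ms(day, hour),
                    "end_time_ms": day_hour_to_epoch_ms(day, hour + 1),
                }
            }
        }


windows = list(hourly_windows(date(2023, 7, 3)))
assert len(windows) == 24
assert windows[0]["filter_by"]["query_start_time_range"]["start_time_ms"] == 1688342400000
assert windows[-1]["filter_by"]["query_start_time_range"]["end_time_ms"] == 1688428800000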

castor_extractor/warehouse/databricks/client_test.py
@@ -4,7 +4,12 @@ from unittest.mock import Mock, patch
 from freezegun import freeze_time
 
 from ..abstract.time_filter import TimeFilter
-from .client import DatabricksClient, LineageLinks, _day_hour_to_epoch_ms
+from .client import (
+    DatabricksClient,
+    DatabricksCredentials,
+    LineageLinks,
+    _day_hour_to_epoch_ms,
+)
 from .test_constants import (
     CLOSER_DATE,
     MOCK_TABLES_FOR_TABLE_LINEAGE,
@@ -18,45 +23,53 @@ def test__day_hour_to_epoch_ms():
 
 
 @freeze_time("2023-7-4")
-def test_DatabricksClient__time_filter():
+def test_DatabricksClient__hourly_time_filters():
+    credentials = DatabricksCredentials(
+        host="carthago",
+        token="delenda",
+        http_host="est",
+    )
+    client = DatabricksClient(credentials)
+
     # default is yesterday
-    default_time_filter = None
-    default_filter = DatabricksClient._time_filter(default_time_filter)
-    expected_default = {
-        "filter_by": {
-            "query_start_time_range": {
-                "end_time_ms": 1688428800000,  # July 4, 2023 12:00:00 AM GMT
-                "start_time_ms": 1688342400000,  # July 3, 2023 12:00:00 AM GMT
-            }
-        }
-    }
-    assert default_filter == expected_default
+    default_filters = [f for f in client._hourly_time_filters(None)]
+
+    assert len(default_filters) == 24  # number of hours in a day
+
+    first = default_filters[0]
+    start = first["filter_by"]["query_start_time_range"]["start_time_ms"]
+    last = default_filters[-1]
+    end = last["filter_by"]["query_start_time_range"]["end_time_ms"]
+    assert start == 1688342400000  # July 3, 2023 12:00:00 AM GMT
+    assert end == 1688428800000  # July 4, 2023 12:00:00 AM GMT
 
     # custom time (from execution_date in DAG for example)
     time_filter = TimeFilter(day=date(2020, 10, 15))
-    custom_filter = DatabricksClient._time_filter(time_filter)
-    expected_custom = {
-        "filter_by": {
-            "query_start_time_range": {
-                "end_time_ms": 1602806400000,  # October 16, 2020 12:00:00 AM
-                "start_time_ms": 1602720000000,  # October 15, 2020 12:00:00 AM
-            }
-        }
-    }
-    assert custom_filter == expected_custom
+    custom_filters = [f for f in client._hourly_time_filters(time_filter)]
+
+    assert len(custom_filters) == 24
+
+    first = custom_filters[0]
+    start = first["filter_by"]["query_start_time_range"]["start_time_ms"]
+    last = custom_filters[-1]
+    end = last["filter_by"]["query_start_time_range"]["end_time_ms"]
+    assert start == 1602720000000  # Oct 15, 2020 12:00:00 AM
+    assert end == 1602806400000  # Oct 16, 2020 12:00:00 AM
 
     # hourly extraction: note that hour_min == hour_max
     hourly = TimeFilter(day=date(2023, 4, 14), hour_min=4, hour_max=4)
-    hourly_filter = DatabricksClient._time_filter(hourly)
-    expected_hourly = {
-        "filter_by": {
-            "query_start_time_range": {
-                "end_time_ms": 1681448400000,  # April 14, 2023 5:00:00 AM
-                "start_time_ms": 1681444800000,  # April 14, 2023 4:00:00 AM
+    hourly_filters = [f for f in client._hourly_time_filters(hourly)]
+    expected_hourly = [
+        {
+            "filter_by": {
+                "query_start_time_range": {
+                    "end_time_ms": 1681448400000,  # April 14, 2023 5:00:00 AM
+                    "start_time_ms": 1681444800000,  # April 14, 2023 4:00:00 AM
+                }
             }
         }
-    }
-    assert hourly_filter == expected_hourly
+    ]
+    assert hourly_filters == expected_hourly
 
 
 class MockDatabricksClient(DatabricksClient):

castor_extractor/warehouse/salesforce/client.py
@@ -1,11 +1,15 @@
 import logging
-from typing import Dict, Iterator, List, Tuple
+from typing import Dict, Iterator, List, Optional, Tuple
 
 from tqdm import tqdm  # type: ignore
 
 from ...utils.salesforce import SalesforceBaseClient, SalesforceCredentials
 from .format import SalesforceFormatter
-from .soql import SOBJECT_FIELDS_QUERY_TPL, SOBJECTS_QUERY_TPL
+from .soql import (
+    DESCRIPTION_QUERY_TPL,
+    SOBJECT_FIELDS_QUERY_TPL,
+    SOBJECTS_QUERY_TPL,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -90,13 +94,34 @@ class SalesforceClient(SalesforceBaseClient):
         response = self._call(self.tooling_url, params={"q": query})
         return response["records"]
 
+    def fetch_description(self, table_name: str) -> Optional[str]:
+        """Retrieve description of a table"""
+        query = DESCRIPTION_QUERY_TPL.format(table_name=table_name)
+        response = self._call(self.tooling_url, params={"q": query})
+        if not response["records"]:
+            return None
+        return response["records"][0]["Description"]
+
+    def add_table_descriptions(self, sobjects: List[dict]) -> List[dict]:
+        """
+        Add table descriptions.
+        We use the tooling API which does not handle well the LIMIT in SOQL
+        so we have to retrieve descriptions individually
+        """
+        described_sobjects = []
+        for sobject in sobjects:
+            description = self.fetch_description(sobject["QualifiedApiName"])
+            described_sobjects.append({**sobject, "Description": description})
+        return described_sobjects
+
     def tables(self) -> List[dict]:
         """
         Get Salesforce sobjects as tables
         """
         sobjects = self.fetch_sobjects()
         logger.info(f"Extracted {len(sobjects)} sobjects")
-        return list(self.formatter.tables(sobjects))
+        described_sobjects = self.add_table_descriptions(sobjects)
+        return list(self.formatter.tables(described_sobjects))
 
     def columns(
         self, sobject_names: List[Tuple[str, str]], show_progress: bool = True
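
As the add_table_descriptions docstring notes, the Tooling API is queried once per sobject (LIMIT behaves poorly there) and the result is merged back into each record. A small sketch of that merge step with a stubbed lookup in place of the API call (the data and the stub below are illustrative only):

from typing import List, Optional

_FAKE_DESCRIPTIONS = {"Account": "Customer accounts", "Custom__c": None}


def fetch_description(table_name: str) -> Optional[str]:
    # stand-in for the per-table Tooling API call
    return _FAKE_DESCRIPTIONS.get(table_name)


def add_table_descriptions(sobjects: List[dict]) -> List[dict]:
    # one lookup per sobject, merged into a copy of the record
    return [
        {**sobject, "Description": fetch_description(sobject["QualifiedApiName"])}
        for sobject in sobjects
    ]


tables = add_table_descriptions(
    [{"QualifiedApiName": "Account"}, {"QualifiedApiName": "Custom__c"}]
)
assert tables[0]["Description"] == "Customer accounts"
assert tables[1]["Description"] is None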

castor_extractor/warehouse/salesforce/format.py
@@ -44,7 +44,7 @@ def _to_table_payload(sobject: dict, table_name: str) -> dict:
         "label": sobject["Label"],
         "schema_id": SCHEMA_NAME,
         "table_name": table_name,
-        "description": "",
+        "description": sobject.get("Description"),
         "tags": [],
         "type": "TABLE",
     }

castor_extractor/warehouse/salesforce/format_test.py
@@ -19,7 +19,6 @@ def _example_sobjects() -> Tuple[Dict[str, str], ...]:
 
 
 def test__field_description():
-
     field = {}
     assert _field_description(field) == ""
 
@@ -59,7 +58,7 @@ def test__merge_label_and_api_name():
         "label": "foo",
         "schema_id": SCHEMA_NAME,
         "table_name": expected_name,
-        "description": "",
+        "description": None,
         "tags": [],
         "type": "TABLE",
     }

castor_extractor/warehouse/salesforce/soql.py
@@ -1,3 +1,9 @@
+DESCRIPTION_QUERY_TPL = """
+SELECT Description
+FROM EntityDefinition
+WHERE QualifiedApiName = '{table_name}'
+"""
+
 SOBJECTS_QUERY_TPL = """
 SELECT
     DeveloperName,
@@ -13,7 +19,6 @@ SOBJECTS_QUERY_TPL = """
     LIMIT {limit}
 """
 
-
 SOBJECT_FIELDS_QUERY_TPL = """
 SELECT
     DataType,
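
DESCRIPTION_QUERY_TPL is an ordinary str.format template, so filling it only requires the qualified API name of the sobject; for example (the table name below is made up):

DESCRIPTION_QUERY_TPL = """
SELECT Description
FROM EntityDefinition
WHERE QualifiedApiName = '{table_name}'
"""

# yields the SOQL sent to the Tooling API for a single table
print(DESCRIPTION_QUERY_TPL.format(table_name="Account"))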