castor-extractor 0.19.4__py3-none-any.whl → 0.19.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of castor-extractor has been flagged as possibly problematic.

Files changed (40)
  1. CHANGELOG.md +13 -0
  2. castor_extractor/quality/soda/client/pagination.py +1 -1
  3. castor_extractor/utils/__init__.py +1 -0
  4. castor_extractor/utils/client/__init__.py +1 -1
  5. castor_extractor/utils/client/api/__init__.py +1 -1
  6. castor_extractor/utils/client/api/client.py +33 -7
  7. castor_extractor/utils/client/api/pagination.py +23 -6
  8. castor_extractor/utils/pager/__init__.py +0 -1
  9. castor_extractor/utils/salesforce/client.py +45 -50
  10. castor_extractor/utils/salesforce/client_test.py +2 -2
  11. castor_extractor/utils/salesforce/pagination.py +33 -0
  12. castor_extractor/visualization/metabase/client/api/client.py +30 -11
  13. castor_extractor/visualization/salesforce_reporting/client/rest.py +4 -3
  14. castor_extractor/visualization/sigma/client/client.py +2 -1
  15. castor_extractor/visualization/tableau_revamp/assets.py +8 -0
  16. castor_extractor/visualization/tableau_revamp/client/client.py +6 -1
  17. castor_extractor/warehouse/databricks/api_client.py +239 -0
  18. castor_extractor/warehouse/databricks/api_client_test.py +15 -0
  19. castor_extractor/warehouse/databricks/client.py +37 -489
  20. castor_extractor/warehouse/databricks/client_test.py +1 -99
  21. castor_extractor/warehouse/databricks/endpoints.py +28 -0
  22. castor_extractor/warehouse/databricks/lineage.py +141 -0
  23. castor_extractor/warehouse/databricks/lineage_test.py +34 -0
  24. castor_extractor/warehouse/databricks/pagination.py +22 -0
  25. castor_extractor/warehouse/databricks/sql_client.py +90 -0
  26. castor_extractor/warehouse/databricks/utils.py +44 -1
  27. castor_extractor/warehouse/databricks/utils_test.py +58 -1
  28. castor_extractor/warehouse/mysql/client.py +0 -3
  29. castor_extractor/warehouse/salesforce/client.py +12 -59
  30. castor_extractor/warehouse/salesforce/pagination.py +34 -0
  31. castor_extractor/warehouse/sqlserver/client.py +0 -2
  32. {castor_extractor-0.19.4.dist-info → castor_extractor-0.19.7.dist-info}/METADATA +14 -1
  33. {castor_extractor-0.19.4.dist-info → castor_extractor-0.19.7.dist-info}/RECORD +36 -31
  34. castor_extractor/utils/client/api_deprecated.py +0 -89
  35. castor_extractor/utils/client/api_deprecated_test.py +0 -18
  36. castor_extractor/utils/pager/pager_on_token.py +0 -52
  37. castor_extractor/utils/pager/pager_on_token_test.py +0 -73
  38. {castor_extractor-0.19.4.dist-info → castor_extractor-0.19.7.dist-info}/LICENCE +0 -0
  39. {castor_extractor-0.19.4.dist-info → castor_extractor-0.19.7.dist-info}/WHEEL +0 -0
  40. {castor_extractor-0.19.4.dist-info → castor_extractor-0.19.7.dist-info}/entry_points.txt +0 -0

castor_extractor/warehouse/databricks/client_test.py
@@ -1,14 +1,7 @@
-from datetime import date
 from unittest.mock import Mock, patch

-from freezegun import freeze_time
-
-from ..abstract.time_filter import TimeFilter
 from .client import (
     DatabricksClient,
-    DatabricksCredentials,
-    LineageLinks,
-    _day_hour_to_epoch_ms,
 )
 from .test_constants import (
     CLOSER_DATE,
@@ -18,74 +11,12 @@ from .test_constants import (
 )


-def test__day_hour_to_epoch_ms():
-    _day_hour_to_epoch_ms(date(2023, 2, 14), 14) == 1644847200000
-
-
-@freeze_time("2023-7-4")
-def test_DatabricksClient__hourly_time_filters():
-    credentials = DatabricksCredentials(
-        host="carthago",
-        token="delenda",
-        http_host="est",
-    )
-    client = DatabricksClient(credentials)
-
-    # default is yesterday
-    default_filters = [f for f in client._hourly_time_filters(None)]
-
-    assert len(default_filters) == 24  # number of hours in a day
-
-    first = default_filters[0]
-    start = first["filter_by"]["query_start_time_range"]["start_time_ms"]
-    last = default_filters[-1]
-    end = last["filter_by"]["query_start_time_range"]["end_time_ms"]
-    assert start == 1688342400000  # July 3, 2023 12:00:00 AM GMT
-    assert end == 1688428800000  # July 4, 2023 12:00:00 AM GMT
-
-    # custom time (from execution_date in DAG for example)
-    time_filter = TimeFilter(day=date(2020, 10, 15))
-    custom_filters = [f for f in client._hourly_time_filters(time_filter)]
-
-    assert len(custom_filters) == 24
-
-    first = custom_filters[0]
-    start = first["filter_by"]["query_start_time_range"]["start_time_ms"]
-    last = custom_filters[-1]
-    end = last["filter_by"]["query_start_time_range"]["end_time_ms"]
-    assert start == 1602720000000  # Oct 15, 2020 12:00:00 AM
-    assert end == 1602806400000  # Oct 16, 2020 12:00:00 AM
-
-    # hourly extraction: note that hour_min == hour_max
-    hourly = TimeFilter(day=date(2023, 4, 14), hour_min=4, hour_max=4)
-    hourly_filters = [f for f in client._hourly_time_filters(hourly)]
-    expected_hourly = [
-        {
-            "filter_by": {
-                "query_start_time_range": {
-                    "end_time_ms": 1681448400000,  # April 14, 2023 5:00:00 AM
-                    "start_time_ms": 1681444800000,  # April 14, 2023 4:00:00 AM
-                }
-            }
-        }
-    ]
-    assert hourly_filters == expected_hourly
-
-
 class MockDatabricksClient(DatabricksClient):
     def __init__(self):
         self._db_allowed = ["prd", "staging"]
         self._db_blocked = ["dev"]


-def test_DatabricksClient__keep_catalog():
-    client = MockDatabricksClient()
-    assert client._keep_catalog("prd")
-    assert client._keep_catalog("staging")
-    assert not client._keep_catalog("dev")
-    assert not client._keep_catalog("something_unknown")
-
-
 def test_DatabricksClient__get_user_mapping():
     client = MockDatabricksClient()
     users = [
@@ -120,7 +51,7 @@ def test_DatabricksClient__match_table_with_user():


 @patch(
-    "source.packages.extractor.castor_extractor.warehouse.databricks.client.DatabricksClient.get",
+    "source.packages.extractor.castor_extractor.warehouse.databricks.client.DatabricksAPIClient._get",
     side_effect=TABLE_LINEAGE_SIDE_EFFECT,
 )
 def test_DatabricksClient_table_lineage(mock_get):
@@ -141,32 +72,3 @@ def test_DatabricksClient_table_lineage(mock_get):
     }
     assert expected_link_1 in lineage
     assert expected_link_2 in lineage
-
-
-def test_LineageLinks_add():
-    links = LineageLinks()
-    timestamped_link = ("parent", "child", None)
-    expected_key = ("parent", "child")
-
-    links.add(timestamped_link)
-
-    assert expected_key in links.lineage
-    assert links.lineage[expected_key] is None
-
-    # we replace None by an actual timestamp
-    timestamped_link = ("parent", "child", OLDER_DATE)
-    links.add(timestamped_link)
-    assert expected_key in links.lineage
-    assert links.lineage[expected_key] == OLDER_DATE
-
-    # we update with the more recent timestamp
-    timestamped_link = ("parent", "child", CLOSER_DATE)
-    links.add(timestamped_link)
-    assert expected_key in links.lineage
-    assert links.lineage[expected_key] == CLOSER_DATE
-
-    # we keep the more recent timestamp
-    timestamped_link = ("parent", "child", OLDER_DATE)
-    links.add(timestamped_link)
-    assert expected_key in links.lineage
-    assert links.lineage[expected_key] == CLOSER_DATE

castor_extractor/warehouse/databricks/endpoints.py
@@ -0,0 +1,28 @@
+class DatabricksEndpointFactory:
+    @classmethod
+    def tables(cls):
+        return "api/2.1/unity-catalog/tables"
+
+    @classmethod
+    def schemas(cls):
+        return "api/2.1/unity-catalog/schemas"
+
+    @classmethod
+    def databases(cls):
+        return "api/2.1/unity-catalog/catalogs"
+
+    @classmethod
+    def table_lineage(cls):
+        return "api/2.0/lineage-tracking/table-lineage"
+
+    @classmethod
+    def column_lineage(cls):
+        return "api/2.0/lineage-tracking/column-lineage"
+
+    @classmethod
+    def queries(cls):
+        return "api/2.0/sql/history/queries"
+
+    @classmethod
+    def users(cls):
+        return "api/2.0/preview/scim/v2/Users"

castor_extractor/warehouse/databricks/lineage.py
@@ -0,0 +1,141 @@
+from typing import Dict, List, Set, Tuple, cast
+
+from .types import Link, Ostr, OTimestampedLink, TimestampedLink
+
+
+class LineageLinks:
+    """
+    helper class that handles lineage deduplication and filtering
+    """
+
+    def __init__(self):
+        self.lineage: Dict[Link, Ostr] = dict()
+
+    def add(self, timestamped_link: TimestampedLink) -> None:
+        """
+        keep the most recent lineage link, adding to `self.lineage`
+        """
+        parent, child, timestamp = timestamped_link
+        link = (parent, child)
+        if not self.lineage.get(link):
+            self.lineage[link] = timestamp
+            return
+
+        if not timestamp:
+            return
+        # keep most recent link; cast for mypy
+        recent = max(cast(str, self.lineage[link]), cast(str, timestamp))
+        self.lineage[link] = recent
+
+
+def _to_table_path(table: dict) -> Ostr:
+    if table.get("name"):
+        return f"{table['catalog_name']}.{table['schema_name']}.{table['name']}"
+    return None
+
+
+def _to_column_path(column: dict) -> Ostr:
+    if column.get("name"):
+        return f"{column['catalog_name']}.{column['schema_name']}.{column['table_name']}.{column['name']}"
+    return None
+
+
+def _link(path_from: Ostr, path_to: Ostr, timestamp: Ostr) -> OTimestampedLink:
+    """exclude missing path and self-lineage"""
+    if (not path_from) or (not path_to):
+        return None
+    is_self_lineage = path_from.lower() == path_to.lower()
+    if is_self_lineage:
+        return None
+    return path_from, path_to, timestamp
+
+
+def single_table_lineage_links(
+    table_path: str, single_table_lineage: dict
+) -> List[TimestampedLink]:
+    """
+    process databricks lineage API response for a given table
+    returns a list of (parent, child, timestamp)
+
+    Note: in `upstreams` or `downstreams` we only care about `tableInfo`,
+    we could also have `notebookInfos` or `fileInfo`
+    """
+    links: List[OTimestampedLink] = []
+    # add parent:
+    for link in single_table_lineage.get("upstreams", []):
+        parent = link.get("tableInfo", {})
+        parent_path = _to_table_path(parent)
+        timestamp: Ostr = parent.get("lineage_timestamp")
+        links.append(_link(parent_path, table_path, timestamp))
+
+    # add children:
+    for link in single_table_lineage.get("downstreams", []):
+        child = link.get("tableInfo", {})
+        child_path = _to_table_path(child)
+        timestamp = child.get("lineage_timestamp")
+        links.append(_link(table_path, child_path, timestamp))
+
+    return list(filter(None, links))
+
+
+def single_column_lineage_links(
+    column_path: str, single_column_lineage: dict
+) -> List[TimestampedLink]:
+    """
+    process databricks lineage API response for a given table
+    returns a list of (parent, child, timestamp)
+
+    Note: in `upstreams` or `downstreams` we only care about `tableInfo`,
+    we could also have `notebookInfos` or `fileInfo`
+    """
+    links: List[OTimestampedLink] = []
+    # add parent:
+    for link in single_column_lineage.get("upstream_cols", []):
+        parent_path = _to_column_path(link)
+        timestamp: Ostr = link.get("lineage_timestamp")
+        links.append(_link(parent_path, column_path, timestamp))
+
+    # add children:
+    for link in single_column_lineage.get("downstream_cols", []):
+        child_path = _to_column_path(link)
+        timestamp = link.get("lineage_timestamp")
+        links.append(_link(column_path, child_path, timestamp))
+
+    return list(filter(None, links))
+
+
+def paths_for_column_lineage(
+    tables: List[dict], columns: List[dict], table_lineage: List[dict]
+) -> List[Tuple[str, str]]:
+    """
+    helper providing a list of candidate columns to look lineage for:
+    we only look for column lineage where there is table lineage
+    """
+    # mapping between table id and its path db.schema.table
+    # table["schema_id"] follows the pattern `db.schema`
+    mapping = {
+        table["id"]: ".".join([table["schema_id"], table["table_name"]])
+        for table in tables
+    }
+
+    tables_with_lineage: Set[str] = set()
+    for t in table_lineage:
+        tables_with_lineage.add(t["parent_path"])
+        tables_with_lineage.add(t["child_path"])
+
+    paths_to_return: List[Tuple[str, str]] = []
+    for column in columns:
+        table_path = mapping[column["table_id"]]
+        if table_path not in tables_with_lineage:
+            continue
+        column_ = (table_path, column["column_name"])
+        paths_to_return.append(column_)
+
+    return paths_to_return
+
+
+def deduplicate_lineage(lineages: List[TimestampedLink]) -> dict:
+    deduplicated_lineage = LineageLinks()
+    for timestamped_link in lineages:
+        deduplicated_lineage.add(timestamped_link)
+    return deduplicated_lineage.lineage
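
For orientation, a minimal sketch of how these helpers compose, using a hand-written payload shaped like the `upstreams`/`downstreams`/`tableInfo` fields read above (the values are made up, not a real Databricks response):

    from castor_extractor.warehouse.databricks.lineage import (
        deduplicate_lineage,
        single_table_lineage_links,
    )

    # Hand-crafted payload mirroring the keys the parser reads.
    sample = {
        "upstreams": [
            {
                "tableInfo": {
                    "catalog_name": "prd",
                    "schema_name": "sales",
                    "name": "orders_raw",
                    "lineage_timestamp": "2024-01-02 03:04:05",
                }
            }
        ],
        # an entry with no usable tableInfo (e.g. a notebook) yields no link
        "downstreams": [{"tableInfo": {}}],
    }

    links = single_table_lineage_links("prd.sales.orders", sample)
    # [("prd.sales.orders_raw", "prd.sales.orders", "2024-01-02 03:04:05")]

    lineage = deduplicate_lineage(links)
    # {("prd.sales.orders_raw", "prd.sales.orders"): "2024-01-02 03:04:05"}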

castor_extractor/warehouse/databricks/lineage_test.py
@@ -0,0 +1,34 @@
+from .lineage import LineageLinks
+from .test_constants import (
+    CLOSER_DATE,
+    OLDER_DATE,
+)
+
+
+def test_LineageLinks_add():
+    links = LineageLinks()
+    timestamped_link = ("parent", "child", None)
+    expected_key = ("parent", "child")
+
+    links.add(timestamped_link)
+
+    assert expected_key in links.lineage
+    assert links.lineage[expected_key] is None
+
+    # we replace None by an actual timestamp
+    timestamped_link = ("parent", "child", OLDER_DATE)
+    links.add(timestamped_link)
+    assert expected_key in links.lineage
+    assert links.lineage[expected_key] == OLDER_DATE
+
+    # we update with the more recent timestamp
+    timestamped_link = ("parent", "child", CLOSER_DATE)
+    links.add(timestamped_link)
+    assert expected_key in links.lineage
+    assert links.lineage[expected_key] == CLOSER_DATE
+
+    # we keep the more recent timestamp
+    timestamped_link = ("parent", "child", OLDER_DATE)
+    links.add(timestamped_link)
+    assert expected_key in links.lineage
+    assert links.lineage[expected_key] == CLOSER_DATE

castor_extractor/warehouse/databricks/pagination.py
@@ -0,0 +1,22 @@
+from typing import List, Optional
+
+from pydantic import Field
+
+from ...utils import PaginationModel
+
+DATABRICKS_PAGE_SIZE = 100
+
+
+class DatabricksPagination(PaginationModel):
+    next_page_token: Optional[str] = None
+    has_next_page: bool = False
+    res: List[dict] = Field(default_factory=list)
+
+    def is_last(self) -> bool:
+        return not (self.has_next_page and self.next_page_token)
+
+    def next_page_payload(self) -> dict:
+        return {"page_token": self.next_page_token}
+
+    def page_results(self) -> list:
+        return self.res
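
DatabricksPagination implements the small contract the shared fetch_all_pages helper relies on: is_last() stops the loop, next_page_payload() supplies the parameters for the next request, page_results() exposes the current page's rows. A minimal sketch of that contract, assuming PaginationModel is a pydantic model built from the raw API response:

    page = DatabricksPagination(
        has_next_page=True,
        next_page_token="abc123",
        res=[{"name": "catalog_a"}, {"name": "catalog_b"}],
    )
    assert not page.is_last()                                    # more pages to fetch
    assert page.next_page_payload() == {"page_token": "abc123"}  # sent with the next request
    assert page.page_results() == [{"name": "catalog_a"}, {"name": "catalog_b"}]

    last = DatabricksPagination(has_next_page=False)
    assert last.is_last()  # a missing token or has_next_page=False ends the loop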

castor_extractor/warehouse/databricks/sql_client.py
@@ -0,0 +1,90 @@
+import logging
+from collections import defaultdict
+from enum import Enum
+from typing import Optional
+
+from databricks import sql  # type: ignore
+
+from .credentials import DatabricksCredentials
+from .format import TagMapping
+from .utils import build_path, tag_label
+
+logger = logging.getLogger(__name__)
+
+_INFORMATION_SCHEMA_SQL = "SELECT * FROM system.information_schema"
+
+
+class TagEntity(Enum):
+    """Entities that can be tagged in Databricks"""
+
+    COLUMN = "COLUMN"
+    TABLE = "TABLE"
+
+
+class DatabricksSQLClient:
+    def __init__(
+        self,
+        credentials: DatabricksCredentials,
+        has_table_tags: bool = False,
+        has_column_tags: bool = False,
+    ):
+        self._http_path = credentials.http_path
+        self._has_table_tags = has_table_tags
+        self._has_column_tags = has_column_tags
+        self._host = credentials.host
+        self._token = credentials.token
+
+    def execute_sql(
+        self,
+        query: str,
+        params: Optional[dict] = None,
+    ):
+        """
+        Execute a SQL query on Databricks system tables and return the results.
+        https://docs.databricks.com/en/dev-tools/python-sql-connector.html
+
+        //!\\ credentials.http_path is required in order to run SQL queries
+        """
+        assert self._http_path, "HTTP_PATH is required to run SQL queries"
+        with sql.connect(
+            server_hostname=self._host,
+            http_path=self._http_path,
+            access_token=self._token,
+        ) as connection:
+            with connection.cursor() as cursor:
+                cursor.execute(query, params)
+                return cursor.fetchall()
+
+    def _needs_extraction(self, entity: TagEntity) -> bool:
+        if entity == TagEntity.TABLE:
+            return self._has_table_tags
+        if entity == TagEntity.COLUMN:
+            return self._has_column_tags
+        raise AssertionError(f"Entity not supported: {entity}")
+
+    def get_tags_mapping(self, entity: TagEntity) -> TagMapping:
+        """
+        Fetch tags of the given entity and build a mapping:
+        { path: list[tags] }
+
+        https://docs.databricks.com/en/sql/language-manual/information-schema/table_tags.html
+        https://docs.databricks.com/en/sql/language-manual/information-schema/column_tags.html
+        """
+        if not self._needs_extraction(entity):
+            # extracting tags require additional credentials (http_path)
+            return dict()
+
+        table = f"{entity.value.lower()}_tags"
+        query = f"{_INFORMATION_SCHEMA_SQL}.{table}"
+        result = self.execute_sql(query)
+        mapping = defaultdict(list)
+        for row in result:
+            dict_row = row.asDict()
+            keys = ["catalog_name", "schema_name", "table_name"]
+            if entity == TagEntity.COLUMN:
+                keys.append("column_name")
+            path = build_path(dict_row, keys)
+            label = tag_label(dict_row)
+            mapping[path].append(label)
+
+        return mapping
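
A hypothetical usage sketch; `credentials` stands for an existing DatabricksCredentials whose host, token and http_path are set (the three attributes the client reads, http_path being what gates tag extraction):

    from castor_extractor.warehouse.databricks.sql_client import (
        DatabricksSQLClient,
        TagEntity,
    )

    # credentials: an existing DatabricksCredentials (host, token, http_path set)
    client = DatabricksSQLClient(
        credentials,
        has_table_tags=True,
        has_column_tags=True,
    )
    table_tags = client.get_tags_mapping(TagEntity.TABLE)
    # e.g. {"prd.sales.orders": ["pii", "owner:data-eng"], ...}
    column_tags = client.get_tags_mapping(TagEntity.COLUMN)
    # keys gain a trailing ".column_name"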

castor_extractor/warehouse/databricks/utils.py
@@ -1,4 +1,16 @@
-from typing import Dict, List
+from datetime import date
+from typing import Dict, Iterable, List, Optional
+
+from ...utils import at_midnight
+from ..abstract import TimeFilter
+
+_DEFAULT_HOUR_MIN = 0
+_DEFAULT_HOUR_MAX = 23
+_NUM_HOURS_IN_A_DAY = 24
+
+
+def _day_hour_to_epoch_ms(day: date, hour: int) -> int:
+    return int(at_midnight(day).timestamp() * 1000) + (hour * 3600 * 1000)


 def build_path(
@@ -25,3 +37,34 @@ def tag_label(row: Dict) -> str:
     if not tag_value:
         return tag_name
     return f"{tag_name}:{tag_value}"
+
+
+def _time_filter_payload(start_time_ms: int, end_time_ms: int) -> dict:
+    return {
+        "filter_by": {
+            "query_start_time_range": {
+                "end_time_ms": end_time_ms,
+                "start_time_ms": start_time_ms,
+            }
+        }
+    }
+
+
+def hourly_time_filters(time_filter: Optional[TimeFilter]) -> Iterable[dict]:
+    """time filters to retrieve Databricks' queries: 1h duration each"""
+    # define an explicit time window
+    if not time_filter:
+        time_filter = TimeFilter.default()
+
+    assert time_filter  # for mypy
+
+    hour_min = time_filter.hour_min
+    hour_max = time_filter.hour_max
+    day = time_filter.day
+    if hour_min is None or hour_max is None:  # fallback to an entire day
+        hour_min, hour_max = _DEFAULT_HOUR_MIN, _DEFAULT_HOUR_MAX
+
+    for index in range(hour_min, min(hour_max + 1, _NUM_HOURS_IN_A_DAY)):
+        start_time_ms = _day_hour_to_epoch_ms(day, index)
+        end_time_ms = _day_hour_to_epoch_ms(day, index + 1)
+        yield _time_filter_payload(start_time_ms, end_time_ms)
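
The epoch arithmetic is easy to verify by hand (assuming at_midnight returns midnight UTC): 2023-02-14 00:00 UTC is 1676332800000 ms, and 14 hours add 14 × 3 600 000 = 50 400 000 ms, giving 1676383200000, the value asserted in the rewritten test below. The constant in the removed test, 1644847200000, corresponds to 2022-02-14 14:00 UTC, and that test was also missing its assert.

    from datetime import datetime, timezone

    midnight_ms = int(datetime(2023, 2, 14, tzinfo=timezone.utc).timestamp() * 1000)
    assert midnight_ms == 1676332800000
    assert midnight_ms + 14 * 3600 * 1000 == 1676383200000  # 2023-02-14 14:00 UTC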

castor_extractor/warehouse/databricks/utils_test.py
@@ -1,4 +1,14 @@
-from .utils import build_path, tag_label
+from datetime import date
+
+from freezegun import freeze_time
+
+from ..abstract import TimeFilter
+from .utils import (
+    _day_hour_to_epoch_ms,
+    build_path,
+    hourly_time_filters,
+    tag_label,
+)


 def test_build_path():
@@ -23,3 +33,50 @@ def test_tag_label():
         "tag_value": "fou",
     }
     assert tag_label(row) == "fi:fou"
+
+
+def test__day_hour_to_epoch_ms():
+    assert _day_hour_to_epoch_ms(date(2023, 2, 14), 14) == 1676383200000
+
+
+@freeze_time("2023-7-4")
+def test_hourly_time_filters():
+    # default is yesterday
+    default_filters = [f for f in hourly_time_filters(None)]
+
+    assert len(default_filters) == 24  # number of hours in a day
+
+    first = default_filters[0]
+    start = first["filter_by"]["query_start_time_range"]["start_time_ms"]
+    last = default_filters[-1]
+    end = last["filter_by"]["query_start_time_range"]["end_time_ms"]
+    assert start == 1688342400000  # July 3, 2023 12:00:00 AM GMT
+    assert end == 1688428800000  # July 4, 2023 12:00:00 AM GMT
+
+    # custom time (from execution_date in DAG for example)
+    time_filter = TimeFilter(day=date(2020, 10, 15))
+    custom_filters = [f for f in hourly_time_filters(time_filter)]
+
+    assert len(custom_filters) == 24
+
+    first = custom_filters[0]
+    start = first["filter_by"]["query_start_time_range"]["start_time_ms"]
+    last = custom_filters[-1]
+    end = last["filter_by"]["query_start_time_range"]["end_time_ms"]
+    assert start == 1602720000000  # Oct 15, 2020 12:00:00 AM
+    assert end == 1602806400000  # Oct 16, 2020 12:00:00 AM
+
+    # hourly extraction: note that hour_min == hour_max
+    hourly = TimeFilter(day=date(2023, 4, 14), hour_min=4, hour_max=4)
+    hourly_filters = [f for f in hourly_time_filters(hourly)]
+    expected_hourly = [
+        {
+            "filter_by": {
+                "query_start_time_range": {
+                    "end_time_ms": 1681448400000,  # April 14, 2023 5:00:00 AM
+                    "start_time_ms": 1681444800000,  # April 14, 2023 4:00:00 AM
+                }
+            }
+        }
+    ]
+    assert hourly_filters == expected_hourly

castor_extractor/warehouse/mysql/client.py
@@ -1,6 +1,3 @@
-# this import is necessary so deptry does not mark the package as unused
-import pymysql  # type: ignore # noqa: F401
-
 from ...utils import SqlalchemyClient, uri_encode

 SERVER_URI = "{user}:{password}@{host}:{port}"

castor_extractor/warehouse/salesforce/client.py
@@ -1,14 +1,16 @@
 import logging
-from typing import Dict, Iterator, List, Optional, Tuple
+from functools import partial
+from typing import Dict, List, Optional, Tuple

 from tqdm import tqdm  # type: ignore

+from ...utils import fetch_all_pages
 from ...utils.salesforce import SalesforceBaseClient, SalesforceCredentials
 from .format import SalesforceFormatter
+from .pagination import SalesforceSQLPagination, format_sobject_query
 from .soql import (
     DESCRIPTION_QUERY_TPL,
     SOBJECT_FIELDS_QUERY_TPL,
-    SOBJECTS_QUERY_TPL,
 )

 logger = logging.getLogger(__name__)
@@ -19,9 +21,6 @@ class SalesforceClient(SalesforceBaseClient):
     Salesforce API client to extract sobjects
     """

-    # Implicit (hard-coded in Salesforce) limitation when using SOQL of 2,000 rows
-    LIMIT_RECORDS_PER_PAGE = 2000
-
     def __init__(self, credentials: SalesforceCredentials):
         super().__init__(credentials)
         self.formatter = SalesforceFormatter()
@@ -30,74 +29,28 @@ class SalesforceClient(SalesforceBaseClient):
     def name() -> str:
         return "Salesforce"

-    def _format_query(self, query_template: str, start_durable_id: str) -> str:
-        return query_template.format(
-            start_durable_id=start_durable_id,
-            limit=self.LIMIT_RECORDS_PER_PAGE,
-        )
-
-    def _next_records(
-        self, url: str, query_template: str, start_durable_id: str = "0000"
-    ) -> List[dict]:
-        query = self._format_query(
-            query_template, start_durable_id=start_durable_id
-        )
-        records, _ = self._call(
-            url, params={"q": query}, processor=self._query_processor
-        )
-        return records
-
-    def _is_last_page(self, records: List[dict]) -> bool:
-        return len(records) < self.LIMIT_RECORDS_PER_PAGE
-
-    def _should_query_next_page(
-        self, records: List[dict], page_number: int
-    ) -> bool:
-        return not (
-            self._is_last_page(records)
-            or self._has_reached_pagination_limit(page_number)
-        )
-
-    def _query_all(self, query_template: str) -> Iterator[dict]:
-        """
-        Run a SOQL query over salesforce API
-
-        Note, pagination is performed via a LIMIT in the SOQL query and requires
-        that ids are sorted. The SOQL query must support `limit` and
-        `start_durable_id` as parameters.
-        """
-        url = self.query_url
-        logger.info("querying page 0")
-        records = self._next_records(url, query_template)
-        yield from records
-
-        page_count = 1
-        while self._should_query_next_page(records, page_count):
-            logger.info(f"querying page {page_count}")
-            last_durable_id = records[-1]["DurableId"]
-            records = self._next_records(
-                url, query_template, start_durable_id=last_durable_id
-            )
-            yield from records
-            page_count += 1
-
     def fetch_sobjects(self) -> List[dict]:
         """Fetch all sobjects"""
         logger.info("Extracting sobjects")
-        return list(self._query_all(SOBJECTS_QUERY_TPL))
+        query = format_sobject_query()
+        request_ = partial(
+            self._get, endpoint=self.query_endpoint, params={"q": query}
+        )
+        results = fetch_all_pages(request_, SalesforceSQLPagination)
+        return list(results)

     def fetch_fields(self, sobject_name: str) -> List[dict]:
         """Fetches fields of a given sobject"""
         query = SOBJECT_FIELDS_QUERY_TPL.format(
             entity_definition_id=sobject_name
         )
-        response = self._call(self.tooling_url, params={"q": query})
+        response = self._get(self.tooling_endpoint, params={"q": query})
         return response["records"]

     def fetch_description(self, table_name: str) -> Optional[str]:
         """Retrieve description of a table"""
         query = DESCRIPTION_QUERY_TPL.format(table_name=table_name)
-        response = self._call(self.tooling_url, params={"q": query})
+        response = self._get(self.tooling_endpoint, params={"q": query})
         if not response["records"]:
             return None
         return response["records"][0]["Description"]
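
The DurableId-cursor pagination (a LIMIT plus start_durable_id baked into the SOQL template) is gone: fetch_sobjects now delegates paging to the shared fetch_all_pages helper, with SalesforceSQLPagination (castor_extractor/warehouse/salesforce/pagination.py, +34, not shown here) deciding when to stop. A hypothetical caller-side sketch, with the credentials instance assumed to exist:

    from castor_extractor.warehouse.salesforce.client import SalesforceClient

    # credentials: an existing SalesforceCredentials instance
    client = SalesforceClient(credentials)
    sobjects = client.fetch_sobjects()                  # paged transparently by fetch_all_pages
    description = client.fetch_description("Account")   # tooling API, single request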