dasl-client 1.0.14__py3-none-any.whl → 1.0.17__py3-none-any.whl

This diff shows the content of publicly available package versions as published to their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of dasl-client might be problematic.

@@ -0,0 +1,125 @@
+import inspect
+import json
+
+import pytest
+
+from datetime import datetime
+from hashlib import md5
+from typing import Optional, Type, Union
+
+from dasl_api.models import *
+from pydantic import BaseModel
+from pydantic.fields import FieldInfo
+
+
+checked_dasl_types = {
+    # Resources
+    WorkspaceV1AdminConfig: "admin_config.json",
+    CoreV1DataSource: "data_source.json",
+    CoreV1Rule: "rule.json",
+    WorkspaceV1WorkspaceConfig: "workspace_config.json",
+    ContentV1DatasourcePreset: "datasource_preset.json",
+    # Data
+    DbuiV1ObservableEventsList: "observable_events_list.json",
+}
+
+
+simple_types = [
+    bool,
+    int,
+    float,
+    str,
+    datetime,
+]
+
+
+def is_simple_type(tpe: Type) -> bool:
+    return tpe in simple_types
+
+
+def is_dasl_api_type(tpe: Type) -> bool:
+    if tpe.__name__ in globals():
+        return "dasl_api" in globals()[tpe.__name__].__module__
+    return False
+
+
+def dasl_model_to_dict(tpe: Type[BaseModel]) -> dict:
+    decorators = getattr(
+        getattr(tpe, "__pydantic_decorators__", None), "field_validators", {}
+    )
+    return {
+        "name": tpe.__name__,
+        "fields": [
+            field_to_dict(name, field, decorators)
+            for name, field in tpe.model_fields.items()
+        ],
+    }
+
+
+def field_to_dict(name: str, field: FieldInfo, validators: dict) -> dict:
+    d = {
+        "name": name,
+        "alias": field.alias,
+        "is_required": field.is_required(),
+        "is_nullable": is_nullable(field.annotation),
+        "is_sequence": is_sequence(field.annotation),
+        "validation_hash": field_validation_hash(name, validators),
+    }
+    field_type: Union[*simple_types, BaseModel] = inner_type(field.annotation)
+    if is_simple_type(field_type):
+        d["type"] = field_type.__name__
+    elif is_dasl_api_type(field_type):
+        d["type"] = dasl_model_to_dict(field_type)
+    else:
+        raise Exception(
+            f"unsupported field type {field_type} encountered while converting field - {name}: {field}"
+        )
+    return d
+
+
+def is_sequence(tpe: Type) -> bool:
+    seq_types = [list, set, frozenset, tuple]
+    if tpe in seq_types:
+        return True
+    if hasattr(tpe, "__origin__"):
+        if tpe.__origin__ in seq_types:
+            return True
+    if hasattr(tpe, "__args__"):
+        return is_sequence(tpe.__args__[0])
+    return False
+
+
+def is_nullable(tpe: Type) -> bool:
+    return hasattr(tpe, "__args__") and type(None) in tpe.__args__
+
+
+def field_validation_hash(field_name: str, validators: dict) -> Optional[str]:
+    for validator in validators.values():
+        if hasattr(validator, "info") and hasattr(validator.info, "fields"):
+            if field_name in validator.info.fields:
+                return md5(
+                    inspect.getsource(validator.func).encode("utf-8")
+                ).hexdigest()
+    return None
+
+
+def inner_type(tpe: Type) -> Type:
+    if hasattr(tpe, "__args__"):
+        return inner_type(tpe.__args__[0])
+    return tpe
+
+
+def dasl_model_to_string(tpe: Type[BaseModel]) -> str:
+    d = dasl_model_to_dict(tpe)
+    return json.dumps(d, indent=2, sort_keys=True)
+
+
+@pytest.mark.parametrize(
+    "tpe",
+    checked_dasl_types.keys(),
+    ids=[f"{tpe.__name__} model is unchanged" for tpe in checked_dasl_types.keys()],
+)
+def test_api_model_for_changes(tpe):
+    with open(f"test/expected_api_models/{checked_dasl_types[tpe]}", "r") as f:
+        expected_val = f.read()
+    assert dasl_model_to_string(tpe) == expected_val
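
The test above flattens each generated pydantic model into a canonical JSON description (field names, aliases, nullability, validator source hashes) and compares it against a fixture on disk, so any upstream change to the dasl_api models fails loudly. When a model change is intentional, the fixtures have to be rewritten; below is a minimal sketch of a regeneration helper reusing dasl_model_to_string and checked_dasl_types from the test module. The script and the module name test_api_models are assumptions for illustration, not part of the package.

# regenerate_fixtures.py -- hypothetical helper, not shipped with dasl-client.
import os

# assumes the test module shown above is importable as test_api_models
from test_api_models import checked_dasl_types, dasl_model_to_string

if __name__ == "__main__":
    os.makedirs("test/expected_api_models", exist_ok=True)
    for tpe, filename in checked_dasl_types.items():
        path = f"test/expected_api_models/{filename}"
        with open(path, "w") as f:
            # write exactly what test_api_model_for_changes will assert against
            f.write(dasl_model_to_string(tpe))
        print(f"wrote {path} for {tpe.__name__}")
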
@@ -0,0 +1,300 @@
+from dasl_client import *
+
+from .constants import *
+
+
+def test_admin_config(api_client):
+    base_admin_config = AdminConfig(
+        workspace_url=databricks_host,
+        app_client_id=app_client_id,
+        service_principal_id=databricks_client_id,
+        service_principal_secret="********",
+    )
+
+    ac = api_client.get_admin_config()
+    assert ac == base_admin_config
+
+    other = AdminConfig(
+        workspace_url=databricks_host,
+        app_client_id=alternate_app_client_id,
+        service_principal_id=databricks_client_id,
+        service_principal_secret=databricks_client_secret,
+    )
+    api_client.put_admin_config(other)
+
+    assert api_client.get_admin_config() == AdminConfig(
+        workspace_url=databricks_host,
+        app_client_id=alternate_app_client_id,
+        service_principal_id=databricks_client_id,
+        service_principal_secret="********",
+    )
+
+    ac.service_principal_secret = databricks_client_secret
+    api_client.put_admin_config(ac)
+    assert api_client.get_admin_config() == base_admin_config
+
+
+def test_workspace_config(api_client):
+    base_workspace_config = WorkspaceConfig(
+        metadata=Metadata(
+            name="config",
+            workspace=workspace,
+            client_of_origin=get_client_identifier(),
+        ),
+        dasl_storage_path="/Volumes/automated_test_cases/default/test",
+        system_tables_config=SystemTablesConfig(
+            catalog_name="automated_test_cases",
+            var_schema="default",
+        ),
+        default_config=DefaultConfig(
+            var_global=DefaultConfig.Config(
+                bronze_schema="bronze",
+                silver_schema="silver",
+                gold_schema="gold",
+                catalog_name="automated_test_cases",
+            ),
+        ),
+    )
+
+    api_client.put_config(base_workspace_config)
+    got = api_client.get_config()
+
+    # the server is going to populate created_timestamp, modified_timestamp,
+    # version, and resource_status, so copy those over before comparing.
+    base_workspace_config.metadata.created_timestamp = got.metadata.created_timestamp
+    base_workspace_config.metadata.modified_timestamp = got.metadata.modified_timestamp
+    base_workspace_config.metadata.version = got.metadata.version
+    base_workspace_config.metadata.resource_status = got.metadata.resource_status
+
+    assert api_client.get_config() == base_workspace_config
+
+    base_workspace_config.default_config.var_global.bronze_schema = "bronze_new"
+    api_client.put_config(base_workspace_config)
+    got = api_client.get_config()
+    base_workspace_config.metadata.modified_timestamp = got.metadata.modified_timestamp
+    base_workspace_config.metadata.version = got.metadata.version
+    base_workspace_config.metadata.resource_status = got.metadata.resource_status
+
+    assert api_client.get_config() == base_workspace_config
+
+
+def test_minimal_data_source(api_client):
+    base_data_source = DataSource(
+        source="test",
+        schedule=Schedule(
+            at_least_every="2h",
+            enabled=True,
+        ),
+        bronze=BronzeSpec(
+            bronze_table="test_bronze_table",
+            skip_bronze_loading=False,
+        ),
+        silver=SilverSpec(),
+        gold=GoldSpec(),
+    )
+
+    base_ds_1 = api_client.create_datasource("test_1", base_data_source)
+    assert base_ds_1.source == base_data_source.source
+    assert base_ds_1.schedule == base_data_source.schedule
+    assert base_ds_1.bronze == base_data_source.bronze
+    assert base_ds_1.silver == base_data_source.silver
+    assert base_ds_1.gold == base_data_source.gold
+
+    got = api_client.get_datasource("test_1")
+    listed = []
+    for ds in api_client.list_datasources():
+        listed.append(ds)
+    assert len(listed) == 1
+    assert listed[0] == got
+
+    # the server is going to populate created_timestamp, modified_timestamp,
+    # version, and resource_status, so copy those over before comparing.
+    base_ds_1.metadata.created_timestamp = got.metadata.created_timestamp
+    base_ds_1.metadata.created_by = got.metadata.created_by
+    base_ds_1.metadata.modified_timestamp = got.metadata.modified_timestamp
+    base_ds_1.metadata.version = got.metadata.version
+    base_ds_1.metadata.resource_status = got.metadata.resource_status
+    assert api_client.get_datasource("test_1") == base_ds_1
+
+    base_ds_2 = api_client.create_datasource("test_2", base_data_source)
+    assert base_ds_2.source == base_data_source.source
+    assert base_ds_2.schedule == base_data_source.schedule
+    assert base_ds_2.bronze == base_data_source.bronze
+    assert base_ds_2.silver == base_data_source.silver
+    assert base_ds_2.gold == base_data_source.gold
+
+    got_2 = api_client.get_datasource("test_2")
+    listed = []
+    for ds in api_client.list_datasources():
+        listed.append(ds)
+    assert len(listed) == 2
+    assert listed[0] == got
+    assert listed[1] == got_2
+
+    base_ds_2.metadata.created_timestamp = got_2.metadata.created_timestamp
+    base_ds_2.metadata.created_by = got_2.metadata.created_by
+    base_ds_2.metadata.modified_timestamp = got_2.metadata.modified_timestamp
+    base_ds_2.metadata.version = got_2.metadata.version
+    base_ds_2.metadata.resource_status = got_2.metadata.resource_status
+    assert api_client.get_datasource("test_2") == base_ds_2
+
+    base_ds_2.bronze.bronze_table = "test_2"
+    api_client.replace_datasource("test_2", base_ds_2)
+
+    got_2 = api_client.get_datasource("test_2")
+    base_ds_2.metadata.modified_timestamp = got_2.metadata.modified_timestamp
+    base_ds_2.metadata.version = got_2.metadata.version
+    base_ds_2.metadata.resource_status = got_2.metadata.resource_status
+
+    assert api_client.get_datasource("test_2") == base_ds_2
+
+    api_client.delete_datasource("test_1")
+    listed = [
+        item
+        for item in api_client.list_datasources()
+        if item.metadata.resource_status != "deletionPending"
+    ]
+    assert len(listed) == 1
+    assert listed[0] == base_ds_2
+
+
+def test_minimal_rule(api_client):
+    base_rule = Rule(
+        schedule=Schedule(
+            at_least_every="2h",
+            enabled=True,
+        ),
+        input=Rule.Input(
+            stream=Rule.Input.Stream(
+                tables=[
+                    Rule.Input.Stream.Table(
+                        name="test",
+                    ),
+                ],
+            ),
+        ),
+        output=Rule.Output(
+            summary="test",
+        ),
+    )
+
+    base_rule_1 = api_client.create_rule("test_0", base_rule)
+    assert base_rule_1.schedule == base_rule.schedule
+    assert base_rule_1.input == base_rule.input
+    assert base_rule_1.output == base_rule.output
+
+    got = api_client.get_rule("test_0")
+    listed = []
+    for rule in api_client.list_rules():
+        listed.append(rule)
+    assert len(listed) == 1
+    assert listed[0] == got
+
+    # the server is going to populate created_timestamp, modified_timestamp,
+    # version, and resource_status, so copy those over before comparing.
+    base_rule_1.metadata.created_timestamp = got.metadata.created_timestamp
+    base_rule_1.metadata.created_by = got.metadata.created_by
+    base_rule_1.metadata.modified_timestamp = got.metadata.modified_timestamp
+    base_rule_1.metadata.version = got.metadata.version
+    base_rule_1.metadata.resource_status = got.metadata.resource_status
+    assert api_client.get_rule("test_0") == base_rule_1
+
+    base_rule_2 = api_client.create_rule("test_1", base_rule)
+    assert base_rule_2.schedule == base_rule.schedule
+    assert base_rule_2.input == base_rule.input
+    assert base_rule_2.output == base_rule.output
+
+    got_2 = api_client.get_rule("test_1")
+    listed = []
+    for rule in api_client.list_rules():
+        listed.append(rule)
+    assert len(listed) == 2
+    assert listed[0] == got
+    assert listed[1] == got_2
+
+    base_rule_2.metadata.created_timestamp = got_2.metadata.created_timestamp
+    base_rule_2.metadata.created_by = got_2.metadata.created_by
+    base_rule_2.metadata.modified_timestamp = got_2.metadata.modified_timestamp
+    base_rule_2.metadata.version = got_2.metadata.version
+    base_rule_2.metadata.resource_status = got_2.metadata.resource_status
+    assert api_client.get_rule("test_1") == base_rule_2
+
+    base_rule_2.input.stream.tables[0].name = "test_1"
+    api_client.replace_rule("test_1", base_rule_2)
+
+    got_2 = api_client.get_rule("test_1")
+    base_rule_2.metadata.modified_timestamp = got_2.metadata.modified_timestamp
+    base_rule_2.metadata.version = got_2.metadata.version
+    base_rule_2.metadata.resource_status = got_2.metadata.resource_status
+
+    assert api_client.get_rule("test_1") == base_rule_2
+
+    api_client.delete_rule("test_0")
+    listed = [
+        item
+        for item in api_client.list_rules()
+        if item.metadata.resource_status != "deletionPending"
+    ]
+    assert len(listed) == 1
+    assert listed[0] == base_rule_2
+
+
+def test_list_pagination(api_client):
+    base_rule = Rule(
+        schedule=Schedule(
+            at_least_every="2h",
+            enabled=True,
+        ),
+        input=Rule.Input(
+            stream=Rule.Input.Stream(
+                tables=[
+                    Rule.Input.Stream.Table(
+                        name="test",
+                    ),
+                ],
+            ),
+        ),
+        output=Rule.Output(
+            summary="test",
+        ),
+    )
+
+    # create (remainder of) 10 rules for the test
+    for i in range(8):
+        api_client.create_rule(f"test_{i+2}", base_rule)
+
+    # ensure all rules are returned for a list call with no params
+    listed = []
+    for rule in api_client.list_rules():
+        listed.append(rule)
+    assert len(listed) == 10
+
+    for i in range(10):
+        assert listed[i] == api_client.get_rule(f"test_{i}")
+
+    # ensure the first 5 rules are returned when limit=5
+    listed = []
+    for rule in api_client.list_rules(limit=5):
+        listed.append(rule)
+    assert len(listed) == 5
+
+    for i in range(5):
+        assert listed[i] == api_client.get_rule(f"test_{i}")
+
+    # ensure the last 5 rules are returned when limit=5, cursor=test_4
+    listed = []
+    for rule in api_client.list_rules(cursor="test_4", limit=5):
+        listed.append(rule)
+    assert len(listed) == 5
+
+    for i in range(5):
+        assert listed[i] == api_client.get_rule(f"test_{i+5}")
+
+    # ensure the last 9 rules are returned when cursor=test_0
+    listed = []
+    for rule in api_client.list_rules(cursor="test_0"):
+        listed.append(rule)
+    assert len(listed) == 9
+
+    for i in range(9):
+        assert listed[i] == api_client.get_rule(f"test_{i+1}")
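
Every test in this file takes an api_client fixture and star-imports names such as databricks_host and workspace from a .constants module, neither of which appears in this diff. A minimal conftest.py sketch of what such a fixture could look like, assuming Client.for_workspace is the entry point (the secret-auth notebook in the next file uses it) and with environment-variable names invented for illustration:

# conftest.py -- hypothetical sketch; the real fixture is not shown in this diff.
import os

import pytest

from dasl_client.client import Client


@pytest.fixture(scope="session")
def api_client():
    # Client.for_workspace matches the call made in the secret-auth notebook
    # below; the environment variable names here are assumptions.
    return Client.for_workspace(
        workspace_url=os.environ["DATABRICKS_HOST"],
        dasl_host=os.environ["DASL_HOST"],
    )
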
@@ -0,0 +1,116 @@
+import base64
+import datetime
+import os
+import time
+
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service import jobs, workspace as dbworkspace
+
+from .constants import *
+
+pylib_volume_path = os.environ["PYLIB_VOLUME_PATH"]
+pylib_wheel_path = os.environ["PYLIB_WHEEL_PATH"]
+
+
+def test_secret_auth(api_client):
+    # making sure it's even possible to get a config
+    api_client.get_config()
+
+    # need to do an API operation using databricks secret auth.
+    notebook_data = f"""
+%pip install {pylib_wheel_path}
+dbutils.library.restartPython()
+# COMMAND ----------
+from dasl_client.client import Client
+
+Client.for_workspace(
+    workspace_url="{databricks_host}",
+    dasl_host="{dasl_host}",
+).get_config()
+# COMMAND ----------
+dbutils.notebook.exit("SUCCESS")
+"""
+    print(f"notebook_data={notebook_data}")
+
+    wsc = WorkspaceClient()
+    wsc.workspace.mkdirs(path=pylib_volume_path)
+
+    notebook_path = f"{pylib_volume_path}/test_secret_auth_notebook"
+    wsc.workspace.import_(
+        path=notebook_path,
+        format=dbworkspace.ImportFormat.SOURCE,
+        language=dbworkspace.Language.PYTHON,
+        content=base64.b64encode(notebook_data.encode("utf-8")).decode("utf-8"),
+        overwrite=True,
+    )
+
+    job_id = None
+    try:
+        job_id = wsc.jobs.create(
+            name="run test_secret_auth notebook",
+            tasks=[
+                jobs.Task(
+                    task_key="run_notebook",
+                    notebook_task=jobs.NotebookTask(notebook_path=notebook_path),
+                ),
+            ],
+        ).job_id
+
+        wsc.jobs.run_now(job_id=job_id)
+
+        logs = []
+        start = datetime.datetime.now()
+        complete = False
+        while not complete:
+            elapsed = datetime.datetime.now() - start
+            if elapsed > datetime.timedelta(seconds=300):
+                raise Exception("timed out waiting for job")
+
+            time.sleep(5)
+
+            status, logs = fetch_latest_run_status_and_logs(wsc, job_id)
+            print(f"logs={logs}")
+
+            if status == jobs.TerminationCodeCode.RUN_EXECUTION_ERROR:
+                raise Exception("job terminated with error")
+
+            complete = status == jobs.TerminationCodeCode.SUCCESS
+
+        print(logs)
+        assert len(logs) == 1
+        assert logs[0] == "SUCCESS"
+    finally:
+        wsc.workspace.delete(pylib_volume_path, recursive=True)
+        if job_id is not None:
+            wsc.jobs.delete(job_id=job_id)
+
+
+def fetch_latest_run_status_and_logs(
+    wsc: WorkspaceClient,
+    job_id: str,
+):
+    runs = list(wsc.jobs.list_runs(job_id=job_id, expand_tasks=True))
+    if not runs:
+        return "No runs found", None
+
+    # Find the latest run based on the start time
+    latest_run = max(runs, key=lambda r: r.start_time)
+    if latest_run.status.termination_details is None:
+        return "No runs found", None
+    status = latest_run.status.termination_details.code
+    logs = []
+    for task in latest_run.tasks:
+        output = wsc.jobs.get_run_output(task.run_id)
+        if output.error is not None:
+            logs.append(output.error)
+        elif output.logs is not None:
+            logs.append(output.logs)
+        elif output.notebook_output is not None:
+            logs.append(output.notebook_output.result)
+        elif output.run_job_output is not None:
+            raise Exception("Nested jobs are not supported")
+        elif output.sql_output is not None:
+            raise Exception("SQL jobs are unsupported")
+        else:
+            logs.append("")
+    return status, logs
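
test_secret_auth uploads a generated notebook, runs it as a single-task job, and polls the run's termination code every five seconds until SUCCESS or a 300-second timeout. Because this module reads PYLIB_VOLUME_PATH and PYLIB_WHEEL_PATH at import time, the suite cannot even be collected without them; a hedged launcher sketch follows (the script is not part of the package, and all paths are placeholders):

# run_integration_tests.py -- hypothetical launcher, not shipped with dasl-client.
import os
import sys

import pytest

# Read at import time by the secret-auth test module above; placeholder values.
os.environ.setdefault("PYLIB_VOLUME_PATH", "/Workspace/Shared/dasl-client-tests")
os.environ.setdefault(
    "PYLIB_WHEEL_PATH",
    "/Volumes/main/default/wheels/dasl_client-1.0.17-py3-none-any.whl",
)

if __name__ == "__main__":
    # WorkspaceClient() additionally needs standard Databricks auth settings,
    # e.g. DATABRICKS_HOST and DATABRICKS_TOKEN, in the environment.
    sys.exit(pytest.main(["-v", "test/"]))
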