dasl-client 1.0.6__py3-none-any.whl → 1.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dasl-client might be problematic.

@@ -0,0 +1,559 @@
+ from typing import Any, Dict, List, Tuple  # needed for FieldSpec and the annotations below
+ from pyspark.sql import DataFrame, SparkSession
+ from pyspark.sql.types import *
+ from pyspark.sql.dataframe import DataFrame
+ from pyspark.sql.functions import col, lit
+ from dasl_client.preset_development.errors import *
+
+ FieldSpec = Dict[str, Any]
+
+
+ class Node:
+     """
+     Represents a node in a tree structure. This is used to figure out STRUCT types
+     and ensure proper subSTRUCT reconciliation at data output.
+     """
+
+     pass
+
+
+ class Branch(Node):
+     """
+     Represents a branch in a tree structure.
+     """
+
+     def __init__(self, children: Dict[str, Node] = None):
+         self._children: Dict[str, Node] = children if children is not None else {}
+
+
+ class Leaf(Node):
+     """
+     Represents a leaf node in a tree structure.
+     """
+
+     def __init__(self, field: FieldSpec):
+         self._field = field
+
+
+ class Stage:
+     """
+     Stage represents a single stage's table and contains all logic required to correctly perform
+     operations as defined in the table's field specifications, utility functions, filters, and more.
+
+     To operate, it takes in a DataFrame and returns a new DataFrame with the defined operations
+     applied. The resulting DataFrame can be used in subsequent stages if required.
+     """
+
+     __op_list = ["assert", "literal", "from", "alias", "expr", "join"]
+
+     def __init__(self, spark: SparkSession, stage: str, table: Dict[str, Any]):
+         """
+         Initializes a Stage object that encapsulates all operations required for a single
+         table within a stage.
+
+         Instance Attributes:
+             stage (str): The medallion layer stage name.
+             name (str): The name of the table.
+             filter (str): The filter applied before operations.
+             postFilter (str): The filter applied after operations.
+             utils (Dict[str, Any]): Utility operations to perform on the DataFrame.
+             input (str): The name of the prior stage's table to use as input.
+             fields (List[Dict[str, str]]): Field specification operations to apply.
+             assertions (List[Dict[str, List[str]]]): Assertions to apply after operations.
+         """
+         self._spark = spark
+         self._stage = stage
+         self._name = table.get("name", "")
+         self._filter = table.get("filter", "")
+         self._postFilter = table.get("postFilter", "")
+         self._utils = table.get("utils", {})
+         self._input = table.get("input", None)
+
+         fields = (
+             [{"name": "dasl_id", "from": "dasl_id"}] + table.get("fields", [])
+             if self._stage != "temp_fields"
+             else table.get("fields", [])
+         )
+         self._fields = [
+             f for f in fields if (f.get("name", None) and not f.get("assert", None))
+         ]
+         self._assertions = [f for f in fields if f.get("assert", None)]
+
+         names = []
+         for field in self._fields:
+             if not (name := field.get("name", None)):
+                 if not field.get("assert", None):  # Can't walrus em all :/
+                     raise MissingFieldNameError(self._stage, self._name)
+             if name in names:
+                 raise DuplicateFieldNameError(self._stage, self._name, name)
+             names += [name]
+
+             missing_op_count = [
+                 spec for spec in [field.get(op, None) for op in self.__op_list]
+             ].count(None)
+             if (missing_op_count == len(self.__op_list)) or (
+                 len(self.__op_list) - missing_op_count > 1
+             ):
+                 raise MalformedFieldError(
+                     self._stage, self._name, field.get("name", None)
+                 )
+
+     def _referenced_columns(self) -> List[str]:
+         """
+         Get a list of columns referenced in the table's field specifications.
+
+         Returns:
+             A list of referenced columns.
+         """
+         return [field.get("from") for field in self._fields if field.get("from", None)]
+
+     def _column_names(self) -> List[str]:
+         """
+         Returns a list of column names referenced in the table's field specifications.
+
+         Returns:
+             A list of column names.
+         """
+         return [field.get("name") for field in self._fields]
+
+     def _omitted_columns(self) -> List[str]:
+         """
+         Get omitted columns from the preserve utility function object.
+
+         Returns:
+             A list of omitted columns.
+         """
+         preserve = self._utils.get("unreferencedColumns", None)
+         return [] if not preserve else preserve.get("omitColumns", [])
+
+     def _duplicate_prefix(self) -> str:
+         """
+         Get the prefix to use for duplicate fields. If not provided, returns a default.
+
+         Returns:
+             The duplicate prefix string.
+         """
+         preserve = self._utils.get("unreferencedColumns", None)
+         return "d_" if not preserve else preserve.get("duplicatePrefix", "d_")
+
+     def json_extract_boilerplate(
+         self, df: DataFrame, source: str, omit_fields: List[str], duplicate_prefix: str
+     ) -> Tuple[StructType, List[str]]:
+         """
+         Processes the common schema and column information needed from a target
+         JSON-containing column.
+
+         Returns:
+             The schema and columns extracted from the target JSON.
+         """
+         target_col = source
+         existing_columns = df.columns
+         if target_col not in existing_columns:
+             raise ReferencedColumnMissingError("jsonExtract", target_col)
+         schema = self._spark.sql(
+             f"SELECT schema_of_json_agg({target_col}) AS sc FROM {{df}}", df=df
+         ).collect()[0][0]
+         extract_df = self._spark.createDataFrame(data=[], schema=schema)
+         extracted_columns = extract_df.columns
+         columns = [
+             f"extract.{col} AS {col}"
+             for col in extracted_columns
+             if col not in omit_fields and col not in existing_columns
+         ]
+         columns += [
+             f"extract.{col} AS {duplicate_prefix}{col}"
+             for col in extracted_columns
+             if col not in omit_fields and col in existing_columns
+         ]
+         return (schema, columns)
+
+     def json_extract(
+         self, df: DataFrame, schema: str, columns: List[str], target_col: str
+     ) -> DataFrame:
+         """
+         Performs JSON extraction to new fields in the DataFrame.
+
+         Returns:
+             A DataFrame with the resultant operation's records.
+         """
+         return (
+             df.selectExpr("*", f"from_json({target_col}, '{schema}') AS extract")
+             .selectExpr("*", *columns)
+             .drop("extract")
+         )
+
+     def json_extract_embed_column(
+         self,
+         df: DataFrame,
+         schema: str,
+         omit_fields: List[str],
+         target_col: str,
+         name: str,
+     ) -> DataFrame:
+         """
+         Performs JSON extraction embedding new JSON fields into a single new field in
+         the DataFrame. (Serialized as JSON.)
+
+         Returns:
+             A DataFrame with the resultant operation's records.
+         """
+         extract_df = self._spark.createDataFrame(data=[], schema=schema)
+         schema = extract_df.drop(*omit_fields).schema.simpleString()
+         return df.selectExpr("*", f"from_json({target_col}, '{schema}') AS {name}")
+
+     def preserved_columns(
+         self, df: DataFrame
+     ) -> Tuple[List[str], List[str], List[str]]:
+         """
+         Performs unreferenced field preservation to new fields in the DataFrame.
+
+         Returns:
+             A tuple of preserved columns, duplicate renames, and column names.
+         """
+         # We do not want to preserve temporary fields.
+         temp_columns = []
+         if temp_fields := self._utils.get("temporaryFields", None):
+             temp_columns = [f.get("name") for f in temp_fields]
+
+         referenced_columns = self._referenced_columns()
+         omitted_columns = self._omitted_columns()
+         preserved_columns = [
+             col
+             for col in df.columns
+             if col not in referenced_columns and col not in omitted_columns
+         ]
+         duplicate_prefix = self._duplicate_prefix()
+         column_names = self._column_names()
+         duplicate_renames = [
+             f"{col} AS {duplicate_prefix}{col}"
+             for col in preserved_columns
+             if col in column_names
+         ]
+         preserved_columns = [
+             col
+             for col in preserved_columns
+             if (col not in column_names and col not in temp_columns)
+         ]
+         column_names = [col for col in column_names if col not in preserved_columns]
+
+         return (preserved_columns, duplicate_renames, column_names)
+
+     def preserved_columns_embed_column(self, df) -> List[str]:
+         """
+         Performs unreferenced field preservation to a single new field in the DataFrame.
+         (Serialized as JSON.)
+
+         Returns:
+             A list of preserved column names.
+         """
+         referenced_columns = self._referenced_columns()
+         omitted_columns = self._omitted_columns()
+         preserved_columns = [
+             col
+             for col in df.columns
+             if col not in referenced_columns and col not in omitted_columns
+         ]
+         return preserved_columns
+
+     def insert_path(
+         self, tree: Dict[str, Node], path: List[str], field: FieldSpec
+     ) -> None:
+         """
+         Inserts a field specification into a tree of nodes.
+         """
+         if not path:
+             return
+
+         head = path[0]
+         if len(path) == 1:
+             tree[head] = Leaf(field)
+         else:
+             if head in tree and isinstance(tree[head], Branch):
+                 sub_tree = tree[head]._children
+             else:
+                 sub_tree = {}
+             self.insert_path(sub_tree, path[1:], field)
+             tree[head] = Branch(sub_tree)
+
+     def parse_to_tree(self, fields: List[FieldSpec]) -> Dict[str, Node]:
+         """
+         Parses a list of field specifications into a tree of nodes.
+         """
+         tree: Dict[str, Node] = {}
+         for field in fields:
+             name = field.get("name")
+             if name is None:
+                 continue
+             path = name.split(".")
+             self.insert_path(tree, path, field)
+         return tree
+
+     def cast_to_expr(self, field: FieldSpec, name: str) -> str:
+         """
+         Casts a field specification into a SELECT expression.
+
+         Returns:
+             The SQL expression for the field.
+         """
+         if field.get("from", None):
+             # check that the from column exists in the df?
+             return f"{field['from']} AS {name}"
+         elif field.get("literal", None):
+             return f"'{field['literal']}' AS {name}"
+         elif field.get("expr", None):
+             return f"{field['expr']} AS {name}"
+         else:
+             return ""
+
+     def process_node(self, name: str, node: Node) -> str:
+         """
+         Processes a single node in a tree of nodes.
+
+         Returns:
+             The STRUCT or SELECT SQL expression for the node.
+         """
+         if isinstance(node, Leaf):
+             return self.cast_to_expr(node._field, name)
+         elif isinstance(node, Branch):
+             fields_list = []
+             for child_name, child_node in node._children.items():
+                 child_expr = self.process_node(child_name, child_node)
+                 fields_list.append(f"{child_expr}")
+             joined_fields = ",\n".join(fields_list)
+             return f"struct(\n{joined_fields}\n) AS {name}"
+         else:
+             return ""
+
+     def parse_to_string(self, nested_tree: Dict[str, Node]) -> List[str]:
+         """
+         Processes the nested tree representation to valid SELECT expressions.
+
+         Returns:
+             The SQL expressions, one per top-level field.
+         """
+         lines = []
+         for name, node in nested_tree.items():
+             processed = self.process_node(name, node)
+             wrapped = f"\n{processed}\n"
+             lines.append(wrapped)
+         return lines
+
+     def render_fields(self, fields: List[FieldSpec]) -> List[str]:
+         """
+         Renders a list of field specifications containing both simple and
+         STRUCT references into valid, STRUCT-cognizant, SELECT expressions.
+
+         Returns:
+             The SQL expressions.
+         """
+         simple_fields = [f for f in fields if "." not in f["name"]]
+         nested_fields = [f for f in fields if "." in f["name"]]
+
+         result_parts = []
+         for field in simple_fields:
+             expr_str = self.cast_to_expr(field, field["name"])
+             result_parts.append(f"{expr_str}")
+
+         if nested_fields:
+             tree = self.parse_to_tree(nested_fields)
+             nested_exprs = self.parse_to_string(tree)
+             result_parts.extend(nested_exprs)
+
+         return [p for p in result_parts if p]
+
+     def select_expr(self, df: DataFrame) -> List[str]:
+         """
+         Renders all field specification operations that result in a SELECT expression
+         after filtering, but before post-filtering and aliasing.
+
+         Returns:
+             The SQL expressions.
+         """
+         select_fields = self.render_fields(self._fields)
+
+         if preserve := self._utils.get("unreferencedColumns", None):
+             should_preserve = preserve.get("preserve", None)
+             if type(should_preserve) != bool:
+                 raise MissingUtilityConfigurationFieldError(
+                     "unreferencedColumns", "preserve"
+                 )
+             if should_preserve:
+                 if embed_col := preserve.get("embedColumn", None):
+                     preserved_columns = self.preserved_columns_embed_column(df)
+                     select_fields += [
+                         f"struct({', '.join(preserved_columns)}) AS {embed_col}"
+                     ]
+                 else:
+                     (
+                         preserved_columns,
+                         duplicate_renames,
+                         column_names,
+                     ) = self.preserved_columns(df)
+                     select_fields += preserved_columns
+                     select_fields += duplicate_renames
+
+         return ["*"] + select_fields if self._stage == "temp_fields" else select_fields
+
+     def run_filter(self, df: DataFrame) -> DataFrame:
+         """
+         Runs filter operations on the provided DataFrame.
+
+         Returns:
+             A DataFrame with the resultant operation's records.
+         """
+         if self._filter:
+             df = df.filter(self._filter)
+         return df
+
+     def run_json_extract(self, df: DataFrame) -> DataFrame:
+         """
+         Runs JSON extract utility operations on the provided DataFrame.
+
+         Returns:
+             A DataFrame with the resultant operation's records.
+         """
+         if json_extracts := self._utils.get("jsonExtract", None):
+             for json_extract in json_extracts:
+                 source = json_extract.get("source")
+                 if not source:
+                     raise MissingUtilityConfigurationFieldError("jsonExtract", "source")
+                 omit_fields = json_extract.get("omitFields", [])
+                 duplicate_prefix = json_extract.get("duplicatePrefix", "d_")
+                 schema, columns = self.json_extract_boilerplate(
+                     df, source, omit_fields, duplicate_prefix
+                 )
+                 if name := json_extract.get("embedColumn", None):
+                     df = self.json_extract_embed_column(
+                         df, schema, omit_fields, source, name
+                     )
+                 else:
+                     df = self.json_extract(df, schema, columns, source)
+         return df
+
+     def run_select_expr(self, df: DataFrame, select_fields: List[str]) -> DataFrame:
+         """
+         Runs select operations (preserve, from, literal, expr) on the provided DataFrame.
+
+         Returns:
+             A DataFrame with the resultant operation's records.
+         """
+         # Join columns, processed before this, need to be included in the
+         # dataframe too. So we append their output names to the fields
+         # selected.
+         joins_cols = []
+         for field in self._fields:
+             if field.get("join", None):
+                 joins_cols += [field["name"]]
+
+         return df.selectExpr(select_fields + joins_cols)
+
+     def run_joins(self, df: DataFrame) -> DataFrame:
+         """
+         Runs join operations on the provided DataFrame.
+
+         Returns:
+             A DataFrame with the resultant operation's records.
+         """
+         joins = []
+         for field in self._fields:
+             if field.get("join", None):
+                 joins += [field]
+
+         for field in joins:
+             join = field.get("join")
+             lhs = join.get("lhs")
+             if not lhs:
+                 raise MissingJoinFieldError("lhs")
+             rhs = join.get("rhs")
+             if not rhs:
+                 raise MissingJoinFieldError("rhs")
+             select = join.get("select")
+             if not select:
+                 raise MissingJoinFieldError("select")
+
+             if table := join.get("withTable", None):
+                 df_joined = self._spark.table(table)
+                 df = (
+                     df.alias("tmp")
+                     .join(df_joined, on=[df[lhs] == df_joined[rhs]], how="left")
+                     .selectExpr("tmp.*", f"{select} AS {field.get('name')}")
+                 )
+             elif csv := join.get("withCSV", None):
+                 if path := csv.get("path", None):
+                     df_joined = self._spark.read.csv(
+                         path, header=True, inferSchema=True
+                     )
+                     df = (
+                         df.alias("tmp")
+                         .join(df_joined, on=[df[lhs] == df_joined[rhs]], how="left")
+                         .selectExpr("tmp.*", f"{select} AS {field.get('name')}")
+                     )
+             else:
+                 raise MissingJoinFieldError("withTable or withCSV (please supply 1)")
+         return df
+
+     def run_aliases(self, df: DataFrame) -> DataFrame:
+         """
+         Runs alias operations on the provided DataFrame.
+
+         Returns:
+             A DataFrame with the resultant operation's records.
+         """
+         for field in self._fields:
+             if field.get("alias", None):
+                 df = df.selectExpr("*", f"{field.get('alias')} AS {field.get('name')}")
+         return df
+
+     def run_assertions(self, df: DataFrame) -> DataFrame:
+         """
+         Runs assert operations on the provided DataFrame.
+
+         Returns:
+             A DataFrame with the resultant operation's records.
+         """
+         for assertions in self._assertions:
+             for assertion in assertions.get("assert"):
+                 failing_rows = df.filter(f"NOT ({assertion['expr']})")
+                 if not failing_rows.isEmpty():
+                     raise AssertionFailedError(
+                         assertion["expr"], assertion.get("message", ""), failing_rows
+                     )
+         return df
+
+     def run_post_filter(self, df: DataFrame) -> DataFrame:
+         """
+         Runs postFilter operations on the provided DataFrame.
+
+         Returns:
+             A DataFrame with the resultant operation's records.
+         """
+         if self._postFilter:
+             df = df.filter(self._postFilter)
+         return df
+
+     def run_temp_fields(self, df: DataFrame) -> DataFrame:
+         """
+         Runs temporary field expressions on the provided DataFrame.
+
+         Returns:
+             A DataFrame with the resultant operation's records.
+         """
+         if temp_fields := self._utils.get("temporaryFields", None):
+             df = Stage(self._spark, "temp_fields", {"fields": temp_fields}).run(df)
+         return df
+
+     def run(self, df: DataFrame) -> DataFrame:
+         """
+         Runs all provided preset operations in the provided stage's table.
+
+         Returns:
+             A DataFrame with the resultant operation's records.
+         """
+         df = self.run_filter(df)
+         df = self.run_temp_fields(df)
+         df = self.run_json_extract(df)
+         df = self.run_joins(df)
+         df = self.run_select_expr(df, self.select_expr(df))
+         df = self.run_aliases(df)
+         df = self.run_assertions(df)
+         df = self.run_post_filter(df)
+         return df
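
For orientation, here is a minimal usage sketch of the new Stage class added in this release. It is not part of the package: the stage name "bronze", the table spec, and the sample rows are purely illustrative, and only table keys read by __init__ above (name, filter, fields) are used. The input DataFrame is given a dasl_id column because __init__ implicitly carries that field through every stage except the internal "temp_fields" stage.

# Hypothetical usage sketch (assumes a Spark session and the Stage class above).
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

raw_df = spark.createDataFrame(
    [("id-001", "alice", "login"), ("id-002", "bob", "logout")],
    ["dasl_id", "user", "action"],
)

table_spec = {
    "name": "authentication",      # illustrative table name
    "filter": "action = 'login'",  # applied before field operations
    "fields": [
        {"name": "actor", "from": "user"},                    # "from": select/rename a column
        {"name": "event_type", "literal": "authentication"},  # "literal": constant-valued column
    ],
}

# run() applies filter -> temporary fields -> jsonExtract -> joins ->
# select expressions -> aliases -> assertions -> postFilter, in that order.
stage = Stage(spark, "bronze", table_spec)
result_df = stage.run(raw_df)
result_df.show(truncate=False)
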
@@ -16,13 +16,16 @@ class AdminConfig(BaseModel):
      Databricks.
 
      Attributes:
-         workspace_url (str): The Databricks URL for the Databricks workspace.
-         app_client_id (str): The client ID used by this workspace to use in
-             three-legged OAuth.
-         service_principal_id (str): The Databricks client ID for an OAuth
-             secret associated with the service principal.
-         service_principal_secret (str): The Databricks client secret for an
-             OAuth secret associated with the service principal.
+         workspace_url (str):
+             The Databricks URL for the Databricks workspace.
+         app_client_id (str):
+             The client ID used by this workspace to use in three-legged OAuth.
+         service_principal_id (str):
+             The Databricks client ID for an OAuth secret associated with the
+             service principal.
+         service_principal_secret (str):
+             The Databricks client secret for an OAuth secret associated with
+             the service principal.
      """
 
      workspace_url: str