datacontract-cli 0.10.33__py3-none-any.whl → 0.10.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datacontract-cli was flagged for review in the registry diff view.
- datacontract/api.py +9 -2
- datacontract/cli.py +4 -2
- datacontract/engines/data_contract_checks.py +102 -59
- datacontract/engines/data_contract_test.py +37 -0
- datacontract/engines/fastjsonschema/check_jsonschema.py +37 -19
- datacontract/engines/soda/check_soda_execute.py +6 -0
- datacontract/engines/soda/connections/athena.py +79 -0
- datacontract/engines/soda/connections/duckdb_connection.py +3 -0
- datacontract/export/avro_converter.py +12 -2
- datacontract/export/dqx_converter.py +121 -0
- datacontract/export/exporter.py +1 -0
- datacontract/export/exporter_factory.py +6 -0
- datacontract/export/markdown_converter.py +115 -5
- datacontract/export/mermaid_exporter.py +24 -11
- datacontract/export/spark_converter.py +28 -3
- datacontract/export/sql_type_converter.py +4 -0
- datacontract/imports/avro_importer.py +33 -7
- datacontract/imports/odcs_v3_importer.py +30 -1
- datacontract/imports/spark_importer.py +12 -1
- {datacontract_cli-0.10.33.dist-info → datacontract_cli-0.10.35.dist-info}/METADATA +126 -42
- {datacontract_cli-0.10.33.dist-info → datacontract_cli-0.10.35.dist-info}/RECORD +25 -23
- {datacontract_cli-0.10.33.dist-info → datacontract_cli-0.10.35.dist-info}/licenses/LICENSE +1 -1
- {datacontract_cli-0.10.33.dist-info → datacontract_cli-0.10.35.dist-info}/WHEEL +0 -0
- {datacontract_cli-0.10.33.dist-info → datacontract_cli-0.10.35.dist-info}/entry_points.txt +0 -0
- {datacontract_cli-0.10.33.dist-info → datacontract_cli-0.10.35.dist-info}/top_level.txt +0 -0

datacontract/engines/soda/connections/athena.py
ADDED

@@ -0,0 +1,79 @@
+import os
+
+import yaml
+
+from datacontract.model.exceptions import DataContractException
+
+
+def to_athena_soda_configuration(server):
+    s3_region = os.getenv("DATACONTRACT_S3_REGION")
+    s3_access_key_id = os.getenv("DATACONTRACT_S3_ACCESS_KEY_ID")
+    s3_secret_access_key = os.getenv("DATACONTRACT_S3_SECRET_ACCESS_KEY")
+    s3_session_token = os.getenv("DATACONTRACT_S3_SESSION_TOKEN")
+
+    # Validate required parameters
+    if not s3_access_key_id:
+        raise DataContractException(
+            type="athena-connection",
+            name="missing_access_key_id",
+            reason="AWS access key ID is required. Set the DATACONTRACT_S3_ACCESS_KEY_ID environment variable.",
+            engine="datacontract",
+        )
+
+    if not s3_secret_access_key:
+        raise DataContractException(
+            type="athena-connection",
+            name="missing_secret_access_key",
+            reason="AWS secret access key is required. Set the DATACONTRACT_S3_SECRET_ACCESS_KEY environment variable.",
+            engine="datacontract",
+        )
+
+    if not hasattr(server, "schema_") or not server.schema_:
+        raise DataContractException(
+            type="athena-connection",
+            name="missing_schema",
+            reason="Schema is required for Athena connection. Specify the schema where your tables exist in the server configuration.",
+            engine="datacontract",
+        )
+
+    if not hasattr(server, "stagingDir") or not server.stagingDir:
+        raise DataContractException(
+            type="athena-connection",
+            name="missing_s3_staging_dir",
+            reason="S3 staging directory is required for Athena connection. This should be the Amazon S3 Query Result Location (e.g., 's3://my-bucket/athena-results/').",
+            engine="datacontract",
+        )
+
+    # Validate S3 staging directory format
+    if not server.stagingDir.startswith("s3://"):
+        raise DataContractException(
+            type="athena-connection",
+            name="invalid_s3_staging_dir",
+            reason=f"S3 staging directory must start with 's3://'. Got: {server.s3_staging_dir}. Example: 's3://my-bucket/athena-results/'",
+            engine="datacontract",
+        )
+
+    data_source = {
+        "type": "athena",
+        "access_key_id": s3_access_key_id,
+        "secret_access_key": s3_secret_access_key,
+        "schema": server.schema_,
+        "staging_dir": server.stagingDir,
+    }
+
+    if s3_region:
+        data_source["region_name"] = s3_region
+    elif server.region_name:
+        data_source["region_name"] = server.region_name
+
+    if server.catalog:
+        # Optional, Identify the name of the Data Source, also referred to as a Catalog. The default value is `awsdatacatalog`.
+        data_source["catalog"] = server.catalog
+
+    if s3_session_token:
+        data_source["aws_session_token"] = s3_session_token
+
+    soda_configuration = {f"data_source {server.type}": data_source}
+
+    soda_configuration_str = yaml.dump(soda_configuration)
+    return soda_configuration_str

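A minimal sketch of how this new helper might be exercised. It is not part of the release: the credentials are placeholders and the server object is a SimpleNamespace stand-in for the contract's Server model, using the attribute names (schema_, stagingDir, region_name, catalog) seen in the code above.

import os
from types import SimpleNamespace

from datacontract.engines.soda.connections.athena import to_athena_soda_configuration

# Placeholder credentials for illustration only.
os.environ["DATACONTRACT_S3_ACCESS_KEY_ID"] = "AKIA_EXAMPLE"
os.environ["DATACONTRACT_S3_SECRET_ACCESS_KEY"] = "example-secret"

# Stand-in for the data contract's Server object.
server = SimpleNamespace(
    type="athena",
    schema_="my_schema",
    stagingDir="s3://my-bucket/athena-results/",
    region_name="eu-central-1",
    catalog=None,
)

# Prints a Soda data source configuration, roughly:
# data_source athena:
#   access_key_id: AKIA_EXAMPLE
#   region_name: eu-central-1
#   schema: my_schema
#   secret_access_key: example-secret
#   staging_dir: s3://my-bucket/athena-results/
#   type: athena
print(to_athena_soda_configuration(server))
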
datacontract/engines/soda/connections/duckdb_connection.py
CHANGED

@@ -71,6 +71,9 @@ def get_duckdb_connection(
     elif server.format == "delta":
         con.sql("update extensions;")  # Make sure we have the latest delta extension
         con.sql(f"""CREATE VIEW "{model_name}" AS SELECT * FROM delta_scan('{model_path}');""")
+        table_info = con.sql(f"PRAGMA table_info('{model_name}');").fetchdf()
+        if table_info is not None and not table_info.empty:
+            run.log_info(f"DuckDB Table Info: {table_info.to_string(index=False)}")
     return con

datacontract/export/avro_converter.py
CHANGED

@@ -44,12 +44,18 @@ def to_avro_field(field, field_name):
     avro_type = to_avro_type(field, field_name)
     avro_field["type"] = avro_type if is_required_avro else ["null", avro_type]

-
-
+    # Handle enum types - both required and optional
+    if avro_type == "enum" or (isinstance(avro_field["type"], list) and "enum" in avro_field["type"]):
+        enum_def = {
             "type": "enum",
             "name": field.title,
             "symbols": field.enum,
         }
+        if is_required_avro:
+            avro_field["type"] = enum_def
+        else:
+            # Replace "enum" with the full enum definition in the union
+            avro_field["type"] = ["null", enum_def]

     if field.config:
         if "avroDefault" in field.config:

@@ -77,6 +83,10 @@ def to_avro_type(field: Field, field_name: str) -> str | dict:
         if "avroType" in field.config:
             return field.config["avroType"]

+    # Check for enum fields based on presence of enum list and avroType config
+    if field.enum and field.config and field.config.get("avroType") == "enum":
+        return "enum"
+
     if field.type is None:
         return "null"
     if field.type in ["string", "varchar", "text"]:

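A hedged sketch of the enum handling above, constructing a Field directly from the specification model; the field itself is invented for illustration.

from datacontract.export.avro_converter import to_avro_field
from datacontract.model.data_contract_specification import Field

# Invented example: a required enum field marked with config.avroType = "enum".
status = Field(
    type="string",
    required=True,
    title="OrderStatus",
    enum=["PLACED", "SHIPPED", "DELIVERED"],
    config={"avroType": "enum"},
)

avro_field = to_avro_field(status, "order_status")
# With the change above, the type should be expanded into a full Avro enum definition:
# {"type": "enum", "name": "OrderStatus", "symbols": ["PLACED", "SHIPPED", "DELIVERED"]}
print(avro_field["type"])
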
datacontract/export/dqx_converter.py
ADDED

@@ -0,0 +1,121 @@
+from typing import Any, Dict, List, Union
+
+import yaml
+
+from datacontract.export.exporter import Exporter, _check_models_for_export
+from datacontract.model.data_contract_specification import DataContractSpecification, Field, Model, Quality
+
+
+class DqxKeys:
+    CHECK = "check"
+    ARGUMENTS = "arguments"
+    SPECIFICATION = "specification"
+    COL_NAME = "column"
+    COL_NAMES = "for_each_column"
+    COLUMNS = "columns"
+    FUNCTION = "function"
+
+
+class DqxExporter(Exporter):
+    """Exporter implementation for converting data contracts to DQX YAML file."""
+
+    def export(
+        self,
+        data_contract: DataContractSpecification,
+        model: Model,
+        server: str,
+        sql_server_type: str,
+        export_args: Dict[str, Any],
+    ) -> str:
+        """Exports a data contract to DQX format."""
+        model_name, model_value = _check_models_for_export(data_contract, model, self.export_format)
+        return to_dqx_yaml(model_value)
+
+
+def to_dqx_yaml(model_value: Model) -> str:
+    """
+    Converts the data contract's quality checks to DQX YAML format.
+
+    Args:
+        model_value (Model): The data contract to convert.
+
+    Returns:
+        str: YAML representation of the data contract's quality checks.
+    """
+    extracted_rules = extract_quality_rules(model_value)
+    return yaml.dump(extracted_rules, sort_keys=False, allow_unicode=True, default_flow_style=False)
+
+
+def process_quality_rule(rule: Quality, column_name: str) -> Dict[str, Any]:
+    """
+    Processes a single quality rule by injecting the column path into its arguments if absent.
+
+    Args:
+        rule (Quality): The quality rule to process.
+        column_name (str): The full path to the current column.
+
+    Returns:
+        dict: The processed quality rule specification.
+    """
+    rule_data = rule.model_extra
+    specification = rule_data[DqxKeys.SPECIFICATION]
+    check = specification[DqxKeys.CHECK]
+
+    arguments = check.setdefault(DqxKeys.ARGUMENTS, {})
+
+    if DqxKeys.COL_NAME not in arguments and DqxKeys.COL_NAMES not in arguments and DqxKeys.COLUMNS not in arguments:
+        if check[DqxKeys.FUNCTION] not in ("is_unique", "foreign_key"):
+            arguments[DqxKeys.COL_NAME] = column_name
+        else:
+            arguments[DqxKeys.COLUMNS] = [column_name]
+
+    return specification
+
+
+def extract_quality_rules(data: Union[Model, Field, Quality], column_path: str = "") -> List[Dict[str, Any]]:
+    """
+    Recursively extracts all quality rules from a data contract structure.
+
+    Args:
+        data (Union[Model, Field, Quality]): The data contract model, field, or quality rule.
+        column_path (str, optional): The current path in the schema hierarchy. Defaults to "".
+
+    Returns:
+        List[Dict[str, Any]]: A list of quality rule specifications.
+    """
+    quality_rules = []
+
+    if isinstance(data, Quality):
+        return [process_quality_rule(data, column_path)]
+
+    if isinstance(data, (Model, Field)):
+        for key, field in data.fields.items():
+            current_path = build_column_path(column_path, key)
+
+            if field.fields:
+                # Field is a struct-like object, recurse deeper
+                quality_rules.extend(extract_quality_rules(field, current_path))
+            else:
+                # Process quality rules at leaf fields
+                for rule in field.quality:
+                    quality_rules.append(process_quality_rule(rule, current_path))
+
+    # Process any quality rules attached directly to this level
+    for rule in data.quality:
+        quality_rules.append(process_quality_rule(rule, column_path))

+    return quality_rules
+
+
+def build_column_path(current_path: str, key: str) -> str:
+    """
+    Builds the full column path by concatenating parent path with current key.
+
+    Args:
+        current_path (str): The current path prefix.
+        key (str): The current field's key.
+
+    Returns:
+        str: The full path.
+    """
+    return f"{current_path}.{key}" if current_path else key

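A standalone illustration of the column-injection logic in process_quality_rule, written with plain dicts instead of the Quality model so the snippet stays self-contained (that substitution is an assumption, not the exporter's API).

# Mirrors the logic above: if a DQX check has no column argument, the field's
# path is injected; is_unique / foreign_key checks get a "columns" list instead.
specification = {"check": {"function": "is_not_null"}}
column_name = "customer.email"

check = specification["check"]
arguments = check.setdefault("arguments", {})
if not {"column", "for_each_column", "columns"} & arguments.keys():
    if check["function"] not in ("is_unique", "foreign_key"):
        arguments["column"] = column_name
    else:
        arguments["columns"] = [column_name]

print(specification)
# {'check': {'function': 'is_not_null', 'arguments': {'column': 'customer.email'}}}
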
datacontract/export/exporter_factory.py
CHANGED

@@ -197,6 +197,12 @@ exporter_factory.register_lazy_exporter(
     class_name="MarkdownExporter",
 )

+exporter_factory.register_lazy_exporter(
+    name=ExportFormat.dqx,
+    module_path="datacontract.export.dqx_converter",
+    class_name="DqxExporter",
+)
+
 exporter_factory.register_lazy_exporter(
     name=ExportFormat.iceberg, module_path="datacontract.export.iceberg_converter", class_name="IcebergExporter"
 )

datacontract/export/markdown_converter.py
CHANGED

@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Dict, List

 from pydantic import BaseModel

@@ -12,6 +12,9 @@ from datacontract.model.data_contract_specification import (
     ServiceLevel,
 )

+TAB = " "
+ARROW = "↳"
+

 class MarkdownExporter(Exporter):
     """Exporter implementation for converting data contracts to Markdown."""

@@ -70,7 +73,8 @@ def obj_attributes_to_markdown(obj: BaseModel, excluded_fields: set = set(), is_
     else:
         bullet_char = "-"
         newline_char = "\n"
-
+    model_attributes_to_include = set(obj.__class__.model_fields.keys())
+    obj_model = obj.model_dump(exclude_unset=True, include=model_attributes_to_include, exclude=excluded_fields)
     description_value = obj_model.pop("description", None)
     attributes = [
         (f"{bullet_char} `{attr}`" if value is True else f"{bullet_char} **{attr}:** {value}")

@@ -78,7 +82,8 @@ def obj_attributes_to_markdown(obj: BaseModel, excluded_fields: set = set(), is_
         if value
     ]
     description = f"*{description_to_markdown(description_value)}*"
-
+    extra = [extra_to_markdown(obj)] if obj.model_extra else []
+    return newline_char.join([description] + attributes + extra)


 def servers_to_markdown(servers: Dict[str, Server]) -> str:

@@ -153,8 +158,8 @@ def field_to_markdown(field_name: str, field: Field, level: int = 0) -> str:
     Returns:
         str: A Markdown table rows for the field.
     """
-    tabs =
-    arrow =
+    tabs = TAB * level
+    arrow = ARROW if level > 0 else ""
     column_name = f"{tabs}{arrow} {field_name}"

     attributes = obj_attributes_to_markdown(field, {"type", "fields", "items", "keys", "values"}, True)

@@ -206,3 +211,108 @@ def service_level_to_markdown(service_level: ServiceLevel | None) -> str:

 def description_to_markdown(description: str | None) -> str:
     return (description or "No description.").replace("\n", "<br>")
+
+
+def array_of_dict_to_markdown(array: List[Dict[str, str]]) -> str:
+    """
+    Convert a list of dictionaries to a Markdown table.
+
+    Args:
+        array (List[Dict[str, str]]): A list of dictionaries where each dictionary represents a row in the table.
+
+    Returns:
+        str: A Markdown formatted table.
+    """
+    if not array:
+        return ""
+
+    headers = []
+
+    for item in array:
+        headers += item.keys()
+    headers = list(dict.fromkeys(headers))  # Preserve order and remove duplicates
+
+    markdown_parts = [
+        "| " + " | ".join(headers) + " |",
+        "| " + " | ".join(["---"] * len(headers)) + " |",
+    ]
+
+    for row in array:
+        element = row
+        markdown_parts.append(
+            "| "
+            + " | ".join(
+                f"{str(element.get(header, ''))}".replace("\n", "<br>").replace("\t", TAB) for header in headers
+            )
+            + " |"
+        )
+
+    return "\n".join(markdown_parts) + "\n"
+
+
+def array_to_markdown(array: List[str]) -> str:
+    """
+    Convert a list of strings to a Markdown formatted list.
+
+    Args:
+        array (List[str]): A list of strings to convert.
+
+    Returns:
+        str: A Markdown formatted list.
+    """
+    if not array:
+        return ""
+    return "\n".join(f"- {item}" for item in array) + "\n"
+
+
+def dict_to_markdown(dictionary: Dict[str, str]) -> str:
+    """
+    Convert a dictionary to a Markdown formatted list.
+
+    Args:
+        dictionary (Dict[str, str]): A dictionary where keys are item names and values are item descriptions.
+
+    Returns:
+        str: A Markdown formatted list of items.
+    """
+    if not dictionary:
+        return ""
+
+    markdown_parts = []
+    for key, value in dictionary.items():
+        if isinstance(value, dict):
+            markdown_parts.append(f"- {key}")
+            nested_markdown = dict_to_markdown(value)
+            if nested_markdown:
+                nested_lines = nested_markdown.split("\n")
+                for line in nested_lines:
+                    if line.strip():
+                        markdown_parts.append(f"  {line}")
+        else:
+            markdown_parts.append(f"- {key}: {value}")
+    return "\n".join(markdown_parts) + "\n"
+
+
+def extra_to_markdown(obj: BaseModel) -> str:
+    """
+    Convert the extra attributes of a data contract to Markdown format.
+    Args:
+        obj (BaseModel): The data contract object containing extra attributes.
+    Returns:
+        str: A Markdown formatted string representing the extra attributes of the data contract.
+    """
+    markdown_part = ""
+    extra = obj.model_extra
+    if extra:
+        for key_extra, value_extra in extra.items():
+            markdown_part += f"\n### {key_extra.capitalize()}\n"
+            if isinstance(value_extra, list) and len(value_extra):
+                if isinstance(value_extra[0], dict):
+                    markdown_part += array_of_dict_to_markdown(value_extra)
+                elif isinstance(value_extra[0], str):
+                    markdown_part += array_to_markdown(value_extra)
+            elif isinstance(value_extra, dict):
+                markdown_part += dict_to_markdown(value_extra)
+            else:
+                markdown_part += f"{str(value_extra)}\n"
+    return markdown_part

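A hedged sketch of the new extra-attribute rendering, using an invented pydantic model that allows extra fields; the Demo class and its values are not part of the package.

from pydantic import BaseModel, ConfigDict

from datacontract.export.markdown_converter import extra_to_markdown

class Demo(BaseModel):
    model_config = ConfigDict(extra="allow")
    description: str = "A model with extra attributes."

obj = Demo(owners=["data-team", "analytics"], links={"docs": "https://example.com"})
print(extra_to_markdown(obj))
# Output (each heading preceded by a blank line), roughly:
# ### Owners
# - data-team
# - analytics
#
# ### Links
# - docs: https://example.com
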
datacontract/export/mermaid_exporter.py
CHANGED

@@ -27,31 +27,33 @@ def dcs_to_mermaid(data_contract_spec: DataContractSpecification) -> str | None:
         mmd_references = []

         for model_name, model in data_contract_spec.models.items():
+            clean_model = _sanitize_name(model_name)
             entity_block = ""

             for field_name, field in model.fields.items():
                 clean_name = _sanitize_name(field_name)
-
+                field_type = field.type or "unknown"

-
-
-                if field.references:
-                    indicators += "⌘"
+                is_pk = bool(field.primaryKey or (field.unique and field.required))
+                is_fk = bool(field.references)

-                field_type = field.
-                entity_block += f"\t{clean_name}{indicators} {field_type}\n"
+                entity_block += _field_line(clean_name, field_type, pk=is_pk, uk=bool(field.unique), fk=is_fk)

                 if field.references:
-
+                    references = field.references.replace(".", "·")
+                    parts = references.split("·")
+                    referenced_model = _sanitize_name(parts[0]) if len(parts) > 0 else ""
+                    referenced_field = _sanitize_name(parts[1]) if len(parts) > 1 else ""
                     if referenced_model:
-
+                        label = referenced_field or clean_name
+                        mmd_references.append(f'"**{referenced_model}**" ||--o{{ "**{clean_model}**" : {label}')

-            mmd_entity += f'\t"**{
+            mmd_entity += f'\t"**{clean_model}**" {{\n{entity_block}}}\n'

         if mmd_references:
             mmd_entity += "\n" + "\n".join(mmd_references)

-        return
+        return mmd_entity + "\n"

     except Exception as e:
         print(f"Error generating DCS mermaid diagram: {e}")

@@ -95,3 +97,14 @@ def odcs_to_mermaid(data_contract_spec: OpenDataContractStandard) -> str | None:

 def _sanitize_name(name: str) -> str:
     return name.replace("#", "Nb").replace(" ", "_").replace("/", "by")
+
+
+def _field_line(name: str, field_type: str, pk: bool = False, uk: bool = False, fk: bool = False) -> str:
+    indicators = ""
+    if pk:
+        indicators += "🔑"
+    if uk:
+        indicators += "🔒"
+    if fk:
+        indicators += "⌘"
+    return f"\t{name}{indicators} {field_type}\n"

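A hedged sketch of the indicator rendering, importing the new private helper directly for demonstration only; the field names are invented.

from datacontract.export.mermaid_exporter import _field_line

# Unique + required fields are treated as primary keys (🔑), unique fields get 🔒,
# and references get ⌘, mirroring the flags computed in dcs_to_mermaid above.
print(repr(_field_line("order_id", "string", pk=True, uk=True)))
# '\torder_id🔑🔒 string\n'
print(repr(_field_line("customer_id", "string", fk=True)))
# '\tcustomer_id⌘ string\n'
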
datacontract/export/spark_converter.py
CHANGED

@@ -1,3 +1,5 @@
+import json
+
 from pyspark.sql import types

 from datacontract.export.exporter import Exporter

@@ -104,7 +106,8 @@ def to_struct_field(field: Field, field_name: str) -> types.StructField:
         types.StructField: The corresponding Spark StructField.
     """
     data_type = to_spark_data_type(field)
-
+    metadata = to_spark_metadata(field)
+    return types.StructField(name=field_name, dataType=data_type, nullable=not field.required, metadata=metadata)


 def to_spark_data_type(field: Field) -> types.DataType:

@@ -152,7 +155,25 @@ def to_spark_data_type(field: Field) -> types.DataType:
         return types.DateType()
     if field_type == "bytes":
         return types.BinaryType()
-    return types.StringType()
+    return types.StringType()  # default if no condition is met
+
+
+def to_spark_metadata(field: Field) -> dict[str, str]:
+    """
+    Convert a field to a Spark metadata dictonary.
+
+    Args:
+        field (Field): The field to convert.
+
+    Returns:
+        dict: dictionary that can be supplied to Spark as metadata for a StructField
+    """
+
+    metadata = {}
+    if field.description:
+        metadata["comment"] = field.description
+
+    return metadata


 def print_schema(dtype: types.DataType) -> str:

@@ -192,7 +213,11 @@ def print_schema(dtype: types.DataType) -> str:
     name = f'"{column.name}"'
     data_type = indent(print_schema(column.dataType), 1)
     nullable = indent(f"{column.nullable}", 1)
-
+    if column.metadata:
+        metadata = indent(f"{json.dumps(column.metadata)}", 1)
+        return f"StructField({name},\n{data_type},\n{nullable},\n{metadata}\n)"
+    else:
+        return f"StructField({name},\n{data_type},\n{nullable}\n)"

 def format_struct_type(struct_type: types.StructType) -> str:
     """

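A hedged sketch of the new description-to-metadata mapping, assuming pyspark is installed; the field is invented.

from datacontract.export.spark_converter import to_struct_field
from datacontract.model.data_contract_specification import Field

email = Field(type="string", required=True, description="Primary email address")
struct_field = to_struct_field(email, "email")

print(struct_field.metadata)   # expected: {'comment': 'Primary email address'}
print(struct_field.nullable)   # False, because the field is required
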
datacontract/export/sql_type_converter.py
CHANGED

@@ -3,6 +3,9 @@ from datacontract.model.data_contract_specification import Field


 def convert_to_sql_type(field: Field, server_type: str) -> str:
+    if field.config and "physicalType" in field.config:
+        return field.config["physicalType"]
+
     if server_type == "snowflake":
         return convert_to_snowflake(field)
     elif server_type == "postgres":

@@ -19,6 +22,7 @@ def convert_to_sql_type(field: Field, server_type: str) -> str:
         return convert_type_to_bigquery(field)
     elif server_type == "trino":
         return convert_type_to_trino(field)
+
     return field.type

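A hedged sketch of the new physicalType override; the field and the physical type are invented.

from datacontract.export.sql_type_converter import convert_to_sql_type
from datacontract.model.data_contract_specification import Field

email = Field(type="string", config={"physicalType": "VARCHAR(320)"})

# config.physicalType now short-circuits the per-server mapping, so the same
# physical type should come back regardless of the server type.
print(convert_to_sql_type(email, "snowflake"))  # VARCHAR(320)
print(convert_to_sql_type(email, "postgres"))   # VARCHAR(320)
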
datacontract/imports/avro_importer.py
CHANGED

@@ -130,13 +130,23 @@ def import_record_fields(record_fields: List[avro.schema.Field]) -> Dict[str, Field]:
             imported_field.fields = import_record_fields(field.type.fields)
         elif field.type.type == "union":
             imported_field.required = False
-
-
-            if
-            imported_field.
-
-            imported_field.
-            imported_field.
+            # Check for enum in union first, since it needs special handling
+            enum_schema = get_enum_from_union_field(field)
+            if enum_schema:
+                imported_field.type = "string"
+                imported_field.enum = enum_schema.symbols
+                imported_field.title = enum_schema.name
+                if not imported_field.config:
+                    imported_field.config = {}
+                imported_field.config["avroType"] = "enum"
+            else:
+                type = import_type_of_optional_field(field)
+                imported_field.type = type
+                if type == "record":
+                    imported_field.fields = import_record_fields(get_record_from_union_field(field).fields)
+                elif type == "array":
+                    imported_field.type = "array"
+                    imported_field.items = import_avro_array_items(get_array_from_union_field(field))
         elif field.type.type == "array":
             imported_field.type = "array"
             imported_field.items = import_avro_array_items(field.type)

@@ -277,6 +287,22 @@ def get_array_from_union_field(field: avro.schema.Field) -> avro.schema.ArraySchema | None:
     return None


+def get_enum_from_union_field(field: avro.schema.Field) -> avro.schema.EnumSchema | None:
+    """
+    Get the enum schema from a union field.
+
+    Args:
+        field: The Avro field with a union type.
+
+    Returns:
+        The enum schema if found, None otherwise.
+    """
+    for field_type in field.type.schemas:
+        if field_type.type == "enum":
+            return field_type
+    return None
+
+
 def map_type_from_avro(avro_type_str: str) -> str:
     """
     Map Avro type strings to data contract type strings.

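A hedged sketch of importing an optional Avro enum (a union of null and an enum), assuming the avro package's schema parser; the schema itself is invented.

import avro.schema

from datacontract.imports.avro_importer import import_record_fields

schema = avro.schema.parse("""
{
  "type": "record",
  "name": "Order",
  "fields": [
    {"name": "status",
     "type": ["null", {"type": "enum", "name": "OrderStatus", "symbols": ["PLACED", "SHIPPED"]}],
     "default": null}
  ]
}
""")

fields = import_record_fields(schema.fields)
status = fields["status"]
print(status.type)      # string
print(status.enum)      # ['PLACED', 'SHIPPED']
print(status.title)     # OrderStatus
print(status.config)    # {'avroType': 'enum'}
print(status.required)  # False (union with null)
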
datacontract/imports/odcs_v3_importer.py
CHANGED

@@ -131,14 +131,18 @@ def import_servers(odcs: OpenDataContractStandard) -> Dict[str, Server] | None:
         server.host = odcs_server.host
         server.port = odcs_server.port
         server.catalog = odcs_server.catalog
+        server.stagingDir = odcs_server.stagingDir
         server.topic = getattr(odcs_server, "topic", None)
         server.http_path = getattr(odcs_server, "http_path", None)
         server.token = getattr(odcs_server, "token", None)
         server.driver = getattr(odcs_server, "driver", None)
         server.roles = import_server_roles(odcs_server.roles)
         server.storageAccount = (
-
+            to_azure_storage_account(odcs_server.location)
+            if server.type == "azure" and "://" in server.location
+            else None
         )
+
         servers[server_name] = server
     return servers

@@ -413,3 +417,28 @@ def import_tags(odcs: OpenDataContractStandard) -> List[str] | None:
     if odcs.tags is None:
         return None
     return odcs.tags
+
+
+def to_azure_storage_account(location: str) -> str | None:
+    """
+    Converts a storage location string to extract the storage account name.
+    ODCS v3.0 has no explicit field for the storage account. It uses the location field, which is a URI.
+
+    This function parses a storage location string to identify and return the
+    storage account name. It handles two primary patterns:
+    1. Protocol://containerName@storageAccountName
+    2. Protocol://storageAccountName
+
+    :param location: The storage location string to parse, typically following
+        the format protocol://containerName@storageAccountName. or
+        protocol://storageAccountName.
+    :return: The extracted storage account name if found, otherwise None
+    """
+    # to catch protocol://containerName@storageAccountName. pattern from location
+    match = re.search(r"(?<=@)([^.]*)", location, re.IGNORECASE)
+    if match:
+        return match.group()
+    else:
+        # to catch protocol://storageAccountName. pattern from location
+        match = re.search(r"(?<=//)(?!@)([^.]*)", location, re.IGNORECASE)
+        return match.group() if match else None

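A hedged sketch of the new location parsing; the URIs are invented examples of the two supported patterns.

from datacontract.imports.odcs_v3_importer import to_azure_storage_account

# Pattern 1: protocol://containerName@storageAccountName.
print(to_azure_storage_account("abfss://container@mystorageaccount.dfs.core.windows.net/path"))
# mystorageaccount

# Pattern 2: protocol://storageAccountName.
print(to_azure_storage_account("az://mystorageaccount.blob.core.windows.net"))
# mystorageaccount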