datacontract-cli 0.10.23__py3-none-any.whl → 0.10.37__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.
- datacontract/__init__.py +13 -0
- datacontract/api.py +12 -5
- datacontract/catalog/catalog.py +5 -3
- datacontract/cli.py +116 -10
- datacontract/data_contract.py +143 -65
- datacontract/engines/data_contract_checks.py +366 -60
- datacontract/engines/data_contract_test.py +50 -4
- datacontract/engines/fastjsonschema/check_jsonschema.py +37 -19
- datacontract/engines/fastjsonschema/s3/s3_read_files.py +3 -2
- datacontract/engines/soda/check_soda_execute.py +22 -3
- datacontract/engines/soda/connections/athena.py +79 -0
- datacontract/engines/soda/connections/duckdb_connection.py +65 -6
- datacontract/engines/soda/connections/kafka.py +4 -2
- datacontract/export/avro_converter.py +20 -3
- datacontract/export/bigquery_converter.py +1 -1
- datacontract/export/dbt_converter.py +36 -7
- datacontract/export/dqx_converter.py +126 -0
- datacontract/export/duckdb_type_converter.py +57 -0
- datacontract/export/excel_exporter.py +923 -0
- datacontract/export/exporter.py +3 -0
- datacontract/export/exporter_factory.py +17 -1
- datacontract/export/great_expectations_converter.py +55 -5
- datacontract/export/{html_export.py → html_exporter.py} +31 -20
- datacontract/export/markdown_converter.py +134 -5
- datacontract/export/mermaid_exporter.py +110 -0
- datacontract/export/odcs_v3_exporter.py +187 -145
- datacontract/export/protobuf_converter.py +163 -69
- datacontract/export/rdf_converter.py +2 -2
- datacontract/export/sodacl_converter.py +9 -1
- datacontract/export/spark_converter.py +31 -4
- datacontract/export/sql_converter.py +6 -2
- datacontract/export/sql_type_converter.py +20 -8
- datacontract/imports/avro_importer.py +63 -12
- datacontract/imports/csv_importer.py +111 -57
- datacontract/imports/excel_importer.py +1111 -0
- datacontract/imports/importer.py +16 -3
- datacontract/imports/importer_factory.py +17 -0
- datacontract/imports/json_importer.py +325 -0
- datacontract/imports/odcs_importer.py +2 -2
- datacontract/imports/odcs_v3_importer.py +351 -151
- datacontract/imports/protobuf_importer.py +264 -0
- datacontract/imports/spark_importer.py +117 -13
- datacontract/imports/sql_importer.py +32 -16
- datacontract/imports/unity_importer.py +84 -38
- datacontract/init/init_template.py +1 -1
- datacontract/integration/datamesh_manager.py +16 -2
- datacontract/lint/resolve.py +112 -23
- datacontract/lint/schema.py +24 -15
- datacontract/model/data_contract_specification/__init__.py +1 -0
- datacontract/model/odcs.py +13 -0
- datacontract/model/run.py +3 -0
- datacontract/output/junit_test_results.py +3 -3
- datacontract/schemas/datacontract-1.1.0.init.yaml +1 -1
- datacontract/schemas/datacontract-1.2.0.init.yaml +91 -0
- datacontract/schemas/datacontract-1.2.0.schema.json +2029 -0
- datacontract/schemas/datacontract-1.2.1.init.yaml +91 -0
- datacontract/schemas/datacontract-1.2.1.schema.json +2058 -0
- datacontract/schemas/odcs-3.0.2.schema.json +2382 -0
- datacontract/templates/datacontract.html +54 -3
- datacontract/templates/datacontract_odcs.html +685 -0
- datacontract/templates/index.html +5 -2
- datacontract/templates/partials/server.html +2 -0
- datacontract/templates/style/output.css +319 -145
- {datacontract_cli-0.10.23.dist-info → datacontract_cli-0.10.37.dist-info}/METADATA +656 -431
- datacontract_cli-0.10.37.dist-info/RECORD +119 -0
- {datacontract_cli-0.10.23.dist-info → datacontract_cli-0.10.37.dist-info}/WHEEL +1 -1
- {datacontract_cli-0.10.23.dist-info → datacontract_cli-0.10.37.dist-info/licenses}/LICENSE +1 -1
- datacontract/export/csv_type_converter.py +0 -36
- datacontract/lint/lint.py +0 -142
- datacontract/lint/linters/description_linter.py +0 -35
- datacontract/lint/linters/field_pattern_linter.py +0 -34
- datacontract/lint/linters/field_reference_linter.py +0 -48
- datacontract/lint/linters/notice_period_linter.py +0 -55
- datacontract/lint/linters/quality_schema_linter.py +0 -52
- datacontract/lint/linters/valid_constraints_linter.py +0 -100
- datacontract/model/data_contract_specification.py +0 -327
- datacontract_cli-0.10.23.dist-info/RECORD +0 -113
- /datacontract/{lint/linters → output}/__init__.py +0 -0
- {datacontract_cli-0.10.23.dist-info → datacontract_cli-0.10.37.dist-info}/entry_points.txt +0 -0
- {datacontract_cli-0.10.23.dist-info → datacontract_cli-0.10.37.dist-info}/top_level.txt +0 -0
datacontract/imports/protobuf_importer.py (new file)

```diff
@@ -0,0 +1,264 @@
+import os
+import re
+import tempfile
+
+from google.protobuf import descriptor_pb2
+from grpc_tools import protoc
+
+from datacontract.imports.importer import Importer
+from datacontract.model.data_contract_specification import DataContractSpecification
+from datacontract.model.exceptions import DataContractException
+
+
+def map_type_from_protobuf(field_type: int):
+    protobuf_type_mapping = {
+        1: "double",
+        2: "float",
+        3: "long",
+        4: "long",  # uint64 mapped to long
+        5: "integer",  # int32 mapped to integer
+        6: "string",  # fixed64 mapped to string
+        7: "string",  # fixed32 mapped to string
+        8: "boolean",
+        9: "string",
+        12: "bytes",
+        13: "integer",  # uint32 mapped to integer
+        15: "integer",  # sfixed32 mapped to integer
+        16: "long",  # sfixed64 mapped to long
+        17: "integer",  # sint32 mapped to integer
+        18: "long",  # sint64 mapped to long
+    }
+    return protobuf_type_mapping.get(field_type, "string")
+
+
+def parse_imports(proto_file: str) -> list:
+    """
+    Parse import statements from a .proto file and return a list of imported file paths.
+    """
+    try:
+        with open(proto_file, "r") as f:
+            content = f.read()
+    except Exception as e:
+        raise DataContractException(
+            type="file",
+            name="Parse proto imports",
+            reason=f"Failed to read proto file: {proto_file}",
+            engine="datacontract",
+            original_exception=e,
+        )
+    imported_files = re.findall(r'import\s+"(.+?)";', content)
+    proto_dir = os.path.dirname(proto_file)
+    return [os.path.join(proto_dir, imp) for imp in imported_files]
+
+
+def compile_proto_to_binary(proto_files: list, output_file: str):
+    """
+    Compile the provided proto files into a single descriptor set using grpc_tools.protoc.
+    """
+    proto_dirs = set(os.path.dirname(proto) for proto in proto_files)
+    proto_paths = [f"--proto_path={d}" for d in proto_dirs]
+
+    args = [""] + proto_paths + [f"--descriptor_set_out={output_file}"] + proto_files
+    ret = protoc.main(args)
+    if ret != 0:
+        raise DataContractException(
+            type="schema",
+            name="Compile proto files",
+            reason=f"grpc_tools.protoc failed with exit code {ret}",
+            engine="datacontract",
+            original_exception=None,
+        )
+
+
+def extract_enum_values_from_fds(fds: descriptor_pb2.FileDescriptorSet, enum_name: str) -> dict:
+    """
+    Search the FileDescriptorSet for an enum definition with the given name
+    and return a dictionary of its values (name to number).
+    """
+    for file_descriptor in fds.file:
+        # Check top-level enums.
+        for enum in file_descriptor.enum_type:
+            if enum.name == enum_name:
+                return {value.name: value.number for value in enum.value}
+        # Check enums defined inside messages.
+        for message in file_descriptor.message_type:
+            for enum in message.enum_type:
+                if enum.name == enum_name:
+                    return {value.name: value.number for value in enum.value}
+    return {}
+
+
+def extract_message_fields_from_fds(fds: descriptor_pb2.FileDescriptorSet, message_name: str) -> dict:
+    """
+    Given a FileDescriptorSet and a message name, return a dict with its field definitions.
+    This function recurses for nested messages and handles enums.
+    """
+    for file_descriptor in fds.file:
+        for msg in file_descriptor.message_type:
+            if msg.name == message_name:
+                fields = {}
+                for field in msg.field:
+                    if field.type == 11:  # TYPE_MESSAGE
+                        nested_msg_name = field.type_name.split(".")[-1]
+                        nested_fields = extract_message_fields_from_fds(fds, nested_msg_name)
+                        if field.label == 3:  # repeated field
+                            field_info = {
+                                "description": f"List of {nested_msg_name}",
+                                "type": "array",
+                                "items": {"type": "object", "fields": nested_fields},
+                            }
+                        else:
+                            field_info = {
+                                "description": f"Nested object of {nested_msg_name}",
+                                "type": "object",
+                                "fields": nested_fields,
+                            }
+                    elif field.type == 14:  # TYPE_ENUM
+                        enum_name = field.type_name.split(".")[-1]
+                        enum_values = extract_enum_values_from_fds(fds, enum_name)
+                        field_info = {
+                            "description": f"Enum field {field.name}",
+                            "type": "string",
+                            "values": enum_values,
+                            "required": (field.label == 2),
+                        }
+                    else:
+                        field_info = {
+                            "description": f"Field {field.name}",
+                            "type": map_type_from_protobuf(field.type),
+                            "required": (field.label == 2),
+                        }
+                    fields[field.name] = field_info
+                return fields
+    return {}
+
+
+def import_protobuf(
+    data_contract_specification: DataContractSpecification, sources: list, import_args: dict = None
+) -> DataContractSpecification:
+    """
+    Gather all proto files (including those imported), compile them into one descriptor,
+    then generate models with nested fields and enums resolved.
+
+    The generated data contract uses generic defaults instead of specific hardcoded ones.
+    """
+
+    # --- Step 1: Gather all proto files (main and imported)
+    proto_files_set = set()
+    queue = list(sources)
+    while queue:
+        proto = queue.pop(0)
+        if proto not in proto_files_set:
+            proto_files_set.add(proto)
+            for imp in parse_imports(proto):
+                if os.path.exists(imp) and imp not in proto_files_set:
+                    queue.append(imp)
+    all_proto_files = list(proto_files_set)
+
+    # --- Step 2: Compile all proto files into a single descriptor set.
+    temp_descriptor = tempfile.NamedTemporaryFile(suffix=".pb", delete=False)
+    descriptor_file = temp_descriptor.name
+    temp_descriptor.close()  # Allow protoc to write to the file
+    try:
+        compile_proto_to_binary(all_proto_files, descriptor_file)
+
+        with open(descriptor_file, "rb") as f:
+            proto_data = f.read()
+        fds = descriptor_pb2.FileDescriptorSet()
+        try:
+            fds.ParseFromString(proto_data)
+        except Exception as e:
+            raise DataContractException(
+                type="schema",
+                name="Parse descriptor set",
+                reason="Failed to parse descriptor set from compiled proto files",
+                engine="datacontract",
+                original_exception=e,
+            )
+
+        # --- Step 3: Build models from the descriptor set.
+        all_models = {}
+        # Create a set of the main proto file basenames.
+        source_proto_basenames = {os.path.basename(proto) for proto in sources}
+
+        for file_descriptor in fds.file:
+            # Only process file descriptors that correspond to your main proto files.
+            if os.path.basename(file_descriptor.name) not in source_proto_basenames:
+                continue
+
+            for message in file_descriptor.message_type:
+                fields = {}
+                for field in message.field:
+                    if field.type == 11:  # TYPE_MESSAGE
+                        nested_msg_name = field.type_name.split(".")[-1]
+                        nested_fields = extract_message_fields_from_fds(fds, nested_msg_name)
+                        if field.label == 3:
+                            field_info = {
+                                "description": f"List of {nested_msg_name}",
+                                "type": "array",
+                                "items": {"type": "object", "fields": nested_fields},
+                            }
+                        else:
+                            field_info = {
+                                "description": f"Nested object of {nested_msg_name}",
+                                "type": "object",
+                                "fields": nested_fields,
+                            }
+                        fields[field.name] = field_info
+                    elif field.type == 14:  # TYPE_ENUM
+                        enum_name = field.type_name.split(".")[-1]
+                        enum_values = extract_enum_values_from_fds(fds, enum_name)
+                        field_info = {
+                            "description": f"Enum field {field.name}",
+                            "type": "string",
+                            "values": enum_values,
+                            "required": (field.label == 2),
+                        }
+                        fields[field.name] = field_info
+                    else:
+                        field_info = {
+                            "description": f"Field {field.name}",
+                            "type": map_type_from_protobuf(field.type),
+                            "required": (field.label == 2),
+                        }
+                        fields[field.name] = field_info
+
+                all_models[message.name] = {
+                    "description": f"Details of {message.name}.",
+                    "type": "table",
+                    "fields": fields,
+                }
+
+        data_contract_specification.models = all_models
+
+        return data_contract_specification
+    finally:
+        # Clean up the temporary descriptor file.
+        if os.path.exists(descriptor_file):
+            os.remove(descriptor_file)
+
+
+class ProtoBufImporter(Importer):
+    def __init__(self, name):
+        # 'name' is passed by the importer factory.
+        self.name = name
+
+    def import_source(
+        self,
+        data_contract_specification: DataContractSpecification,
+        source: str,
+        import_args: dict = None,
+    ) -> DataContractSpecification:
+        """
+        Import a protobuf file (and its imports) into the given DataContractSpecification.
+
+        Parameters:
+        - data_contract_specification: the initial specification to update.
+        - source: the protobuf file path.
+        - import_args: optional dictionary with additional arguments (e.g. 'output_dir').
+
+        Returns:
+        The updated DataContractSpecification.
+        """
+        # Wrap the source in a list because import_protobuf expects a list of sources.
+        return import_protobuf(data_contract_specification, [source], import_args)
```
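The new `ProtoBufImporter` compiles a .proto file (plus any files it imports) into a single descriptor set and turns every top-level message into a model. A minimal sketch of driving it directly from Python, assuming `grpcio-tools` and `protobuf` are installed and using a hypothetical `orders.proto` with one or more message definitions:

```python
from datacontract.imports.protobuf_importer import ProtoBufImporter
from datacontract.model.data_contract_specification import DataContractSpecification

# "orders.proto" is a hypothetical schema file; any .proto with messages works.
importer = ProtoBufImporter(name="protobuf")
spec = importer.import_source(DataContractSpecification(), "orders.proto", import_args={})

# Each top-level message becomes a model entry keyed by its message name.
for model_name in spec.models:
    print(model_name)
```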
datacontract/imports/spark_importer.py

```diff
@@ -1,3 +1,8 @@
+import atexit
+import logging
+import tempfile
+
+from databricks.sdk import WorkspaceClient
 from pyspark.sql import DataFrame, SparkSession, types
 
 from datacontract.imports.importer import Importer
@@ -8,6 +13,8 @@ from datacontract.model.data_contract_specification import (
     Server,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class SparkImporter(Importer):
     def import_source(
@@ -23,39 +30,69 @@ class SparkImporter(Importer):
             data_contract_specification: The data contract specification object.
             source: The source string indicating the Spark tables to read.
             import_args: Additional arguments for the import process.
-
         Returns:
             dict: The updated data contract specification.
         """
-
+        dataframe = import_args.get("dataframe", None)
+        description = import_args.get("description", None)
+        return import_spark(data_contract_specification, source, dataframe, description)
 
 
-def import_spark(
+def import_spark(
+    data_contract_specification: DataContractSpecification,
+    source: str,
+    dataframe: DataFrame | None = None,
+    description: str | None = None,
+) -> DataContractSpecification:
     """
-
+    Imports schema(s) from Spark into a Data Contract Specification.
 
     Args:
-        data_contract_specification: The
-        source:
+        data_contract_specification (DataContractSpecification): The contract spec to update.
+        source (str): Comma-separated Spark table/view names.
+        dataframe (DataFrame | None): Optional Spark DataFrame to import.
+        description (str | None): Optional table-level description.
 
     Returns:
-        DataContractSpecification: The updated
+        DataContractSpecification: The updated contract spec with imported models.
     """
-
+
+    tmp_dir = tempfile.TemporaryDirectory(prefix="datacontract-cli-spark")
+    atexit.register(tmp_dir.cleanup)
+
+    spark = (
+        SparkSession.builder.config("spark.sql.warehouse.dir", f"{tmp_dir}/spark-warehouse")
+        .config("spark.streaming.stopGracefullyOnShutdown", "true")
+        .config("spark.ui.enabled", "false")
+        .getOrCreate()
+    )
     data_contract_specification.servers["local"] = Server(type="dataframe")
-
-
-
-
+
+    if dataframe is not None:
+        if not isinstance(dataframe, DataFrame):
+            raise TypeError("Expected 'dataframe' to be a pyspark.sql.DataFrame")
+        data_contract_specification.models[source] = import_from_spark_df(spark, source, dataframe, description)
+        return data_contract_specification
+
+    if not source:
+        raise ValueError("Either 'dataframe' or a valid 'source' must be provided")
+
+    for table_name in map(str.strip, source.split(",")):
+        df = spark.read.table(table_name)
+        data_contract_specification.models[table_name] = import_from_spark_df(spark, table_name, df, description)
+
     return data_contract_specification
 
 
-def import_from_spark_df(df: DataFrame) -> Model:
+def import_from_spark_df(spark: SparkSession, source: str, df: DataFrame, description: str) -> Model:
     """
     Converts a Spark DataFrame into a Model.
 
     Args:
+        spark: SparkSession
+        source: A comma-separated string of Spark temporary views to read.
         df: The Spark DataFrame to convert.
+        description: Table level comment
 
     Returns:
         Model: The generated data contract model.
@@ -63,6 +100,11 @@ def import_from_spark_df(df: DataFrame) -> Model:
     model = Model()
     schema = df.schema
 
+    if description is None:
+        model.description = _table_comment_from_spark(spark, source)
+    else:
+        model.description = description
+
     for field in schema:
         model.fields[field.name] = _field_from_struct_type(field)
 
@@ -154,5 +196,67 @@ def _data_type_from_spark(spark_type: types.DataType) -> str:
         return "null"
     elif isinstance(spark_type, types.VarcharType):
         return "varchar"
+    elif isinstance(spark_type, types.VariantType):
+        return "variant"
     else:
         raise ValueError(f"Unsupported Spark type: {spark_type}")
+
+
+def _table_comment_from_spark(spark: SparkSession, source: str):
+    """
+    Attempts to retrieve the table-level comment from a Spark table using multiple fallback methods.
+
+    Args:
+        spark (SparkSession): The active Spark session.
+        source (str): The name of the table (without catalog or schema).
+
+    Returns:
+        str or None: The table-level comment, if found.
+    """
+
+    # Get Current Catalog and Schema from Spark Session
+    try:
+        current_catalog = spark.sql("SELECT current_catalog()").collect()[0][0]
+    except Exception:
+        current_catalog = "hive_metastore"  # Fallback for non-Unity Catalog clusters
+    try:
+        current_schema = spark.catalog.currentDatabase()
+    except Exception:
+        current_schema = spark.sql("SELECT current_database()").collect()[0][0]
+
+    # Get table comment if it exists
+    table_comment = ""
+    source = f"{current_catalog}.{current_schema}.{source}"
+    try:
+        # Initialize WorkspaceClient for Unity Catalog API calls
+        workspace_client = WorkspaceClient()
+        created_table = workspace_client.tables.get(full_name=f"{source}")
+        table_comment = created_table.comment
+        logger.info(f"'{source}' table comment retrieved using 'WorkspaceClient.tables.get({source})'")
+        return table_comment
+    except Exception:
+        pass
+
+    # Fallback to Spark Catalog API for Hive Metastore or Non-UC Tables
+    try:
+        table_comment = spark.catalog.getTable(f"{source}").description
+        logger.info(f"'{source}' table comment retrieved using 'spark.catalog.getTable({source}).description'")
+        return table_comment
+    except Exception:
+        pass
+
+    # Final Fallback Using DESCRIBE TABLE EXTENDED
+    try:
+        rows = spark.sql(f"DESCRIBE TABLE EXTENDED {source}").collect()
+        for row in rows:
+            if row.col_name.strip().lower() == "comment":
+                table_comment = row.data_type
+                break
+        logger.info(f"'{source}' table comment retrieved using 'DESCRIBE TABLE EXTENDED {source}'")
+        return table_comment
+    except Exception:
+        pass
+
+    logger.info(f"{source} table comment could not be retrieved")
+
+    return None
```
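With these changes, `import_spark` can be fed an in-memory DataFrame together with an explicit description, which bypasses the `_table_comment_from_spark` lookup against Unity Catalog or the Hive metastore. A rough sketch, assuming `pyspark` and `databricks-sdk` are installed locally and using a made-up `users` model name:

```python
from pyspark.sql import SparkSession
from datacontract.imports.spark_importer import import_spark
from datacontract.model.data_contract_specification import DataContractSpecification

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1, "Alice")], "id INT, name STRING")

# Passing a DataFrame plus a description skips the table-comment lookup entirely.
spec = import_spark(
    DataContractSpecification(),
    source="users",            # used as the model name when a DataFrame is given
    dataframe=df,
    description="Example users table",
)
print(spec.models["users"].description)  # -> "Example users table"
```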
datacontract/imports/sql_importer.py

```diff
@@ -105,7 +105,7 @@ def to_dialect(import_args: dict) -> Dialects | None:
     return None
 
 
-def to_physical_type_key(dialect: Dialects | None) -> str:
+def to_physical_type_key(dialect: Dialects | str | None) -> str:
     dialect_map = {
         Dialects.TSQL: "sqlserverType",
         Dialects.POSTGRES: "postgresType",
@@ -116,6 +116,8 @@ def to_physical_type_key(dialect: Dialects | None) -> str:
         Dialects.MYSQL: "mysqlType",
         Dialects.DATABRICKS: "databricksType",
     }
+    if isinstance(dialect, str):
+        dialect = Dialects[dialect.upper()] if dialect.upper() in Dialects.__members__ else None
     return dialect_map.get(dialect, "physicalType")
 
 
```
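`to_physical_type_key` now accepts the dialect as a plain string as well as a `Dialects` member; unrecognized names fall back to the generic `physicalType` key. A small sketch of the expected behaviour, assuming the function is imported directly from `datacontract.imports.sql_importer`:

```python
from datacontract.imports.sql_importer import to_physical_type_key

# String dialects are matched case-insensitively against the Dialects enum members.
print(to_physical_type_key("postgres"))       # -> "postgresType"
print(to_physical_type_key("TSQL"))           # -> "sqlserverType"

# Unknown strings and None both fall back to the generic key.
print(to_physical_type_key("not-a-dialect"))  # -> "physicalType"
print(to_physical_type_key(None))             # -> "physicalType"
```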
```diff
@@ -198,7 +200,7 @@ def get_precision_scale(column):
     return None, None
 
 
-def map_type_from_sql(sql_type: str):
+def map_type_from_sql(sql_type: str) -> str | None:
     if sql_type is None:
         return None
 
@@ -218,7 +220,7 @@ def map_type_from_sql(sql_type: str):
         return "string"
     elif sql_type_normed.startswith("ntext"):
         return "string"
-    elif sql_type_normed.startswith("int"):
+    elif sql_type_normed.startswith("int") and not sql_type_normed.startswith("interval"):
         return "int"
     elif sql_type_normed.startswith("bigint"):
         return "long"
@@ -228,6 +230,8 @@ def map_type_from_sql(sql_type: str):
         return "int"
     elif sql_type_normed.startswith("float"):
         return "float"
+    elif sql_type_normed.startswith("double"):
+        return "double"
     elif sql_type_normed.startswith("decimal"):
         return "decimal"
     elif sql_type_normed.startswith("numeric"):
@@ -240,26 +244,20 @@ def map_type_from_sql(sql_type: str):
         return "bytes"
     elif sql_type_normed.startswith("varbinary"):
         return "bytes"
+    elif sql_type_normed.startswith("raw"):
+        return "bytes"
+    elif sql_type_normed == "blob" or sql_type_normed == "bfile":
+        return "bytes"
     elif sql_type_normed == "date":
         return "date"
     elif sql_type_normed == "time":
         return "string"
-    elif sql_type_normed
-        return
-    elif (
-        sql_type_normed == "timestamptz"
-        or sql_type_normed == "timestamp_tz"
-        or sql_type_normed == "timestamp with time zone"
-    ):
-        return "timestamp_tz"
-    elif sql_type_normed == "timestampntz" or sql_type_normed == "timestamp_ntz":
+    elif sql_type_normed.startswith("timestamp"):
+        return map_timestamp(sql_type_normed)
+    elif sql_type_normed == "datetime" or sql_type_normed == "datetime2":
         return "timestamp_ntz"
     elif sql_type_normed == "smalldatetime":
         return "timestamp_ntz"
-    elif sql_type_normed == "datetime":
-        return "timestamp_ntz"
-    elif sql_type_normed == "datetime2":
-        return "timestamp_ntz"
     elif sql_type_normed == "datetimeoffset":
         return "timestamp_tz"
     elif sql_type_normed == "uniqueidentifier":  # tsql
@@ -268,10 +266,28 @@ def map_type_from_sql(sql_type: str):
         return "string"
     elif sql_type_normed == "xml":  # tsql
         return "string"
+    elif sql_type_normed.startswith("number"):
+        return "number"
+    elif sql_type_normed == "clob" or sql_type_normed == "nclob":
+        return "text"
     else:
         return "variant"
 
 
+def map_timestamp(timestamp_type: str) -> str:
+    match timestamp_type:
+        case "timestamp" | "timestampntz" | "timestamp_ntz":
+            return "timestamp_ntz"
+        case "timestamptz" | "timestamp_tz" | "timestamp with time zone":
+            return "timestamp_tz"
+        case localTimezone if localTimezone.startswith("timestampltz"):
+            return "timestamp_tz"
+        case timezoneWrittenOut if timezoneWrittenOut.endswith("time zone"):
+            return "timestamp_tz"
+        case _:
+            return "timestamp"
+
+
 def read_file(path):
     if not os.path.exists(path):
         raise DataContractException(
```