datacontract-cli 0.10.11__py3-none-any.whl → 0.10.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of datacontract-cli might be problematic.
- datacontract/cli.py +19 -3
- datacontract/data_contract.py +5 -10
- datacontract/engines/fastjsonschema/check_jsonschema.py +11 -0
- datacontract/engines/fastjsonschema/s3/s3_read_files.py +2 -0
- datacontract/engines/soda/check_soda_execute.py +2 -8
- datacontract/engines/soda/connections/duckdb.py +23 -24
- datacontract/engines/soda/connections/kafka.py +84 -25
- datacontract/export/avro_converter.py +12 -2
- datacontract/export/bigquery_converter.py +30 -23
- datacontract/export/data_caterer_converter.py +148 -0
- datacontract/export/dbml_converter.py +3 -2
- datacontract/export/exporter.py +2 -0
- datacontract/export/exporter_factory.py +12 -0
- datacontract/export/jsonschema_converter.py +13 -2
- datacontract/export/spark_converter.py +5 -1
- datacontract/export/sql_type_converter.py +65 -39
- datacontract/export/sqlalchemy_converter.py +169 -0
- datacontract/imports/avro_importer.py +1 -0
- datacontract/imports/bigquery_importer.py +2 -2
- datacontract/imports/dbml_importer.py +112 -0
- datacontract/imports/dbt_importer.py +67 -91
- datacontract/imports/glue_importer.py +62 -58
- datacontract/imports/importer.py +2 -1
- datacontract/imports/importer_factory.py +5 -0
- datacontract/imports/odcs_importer.py +1 -1
- datacontract/imports/spark_importer.py +34 -11
- datacontract/imports/sql_importer.py +1 -1
- datacontract/imports/unity_importer.py +106 -85
- datacontract/integration/{publish_datamesh_manager.py → datamesh_manager.py} +33 -5
- datacontract/integration/{publish_opentelemetry.py → opentelemetry.py} +1 -1
- datacontract/lint/resolve.py +10 -1
- datacontract/lint/urls.py +27 -13
- datacontract/model/data_contract_specification.py +6 -2
- {datacontract_cli-0.10.11.dist-info → datacontract_cli-0.10.13.dist-info}/METADATA +123 -32
- {datacontract_cli-0.10.11.dist-info → datacontract_cli-0.10.13.dist-info}/RECORD +39 -37
- {datacontract_cli-0.10.11.dist-info → datacontract_cli-0.10.13.dist-info}/WHEEL +1 -1
- datacontract/publish/publish.py +0 -32
- {datacontract_cli-0.10.11.dist-info → datacontract_cli-0.10.13.dist-info}/LICENSE +0 -0
- {datacontract_cli-0.10.11.dist-info → datacontract_cli-0.10.13.dist-info}/entry_points.txt +0 -0
- {datacontract_cli-0.10.11.dist-info → datacontract_cli-0.10.13.dist-info}/top_level.txt +0 -0
datacontract/export/exporter_factory.py

@@ -62,6 +62,12 @@ exporter_factory.register_lazy_exporter(
     class_name="BigQueryExporter",
 )
 
+exporter_factory.register_lazy_exporter(
+    name=ExportFormat.data_caterer,
+    module_path="datacontract.export.data_caterer_converter",
+    class_name="DataCatererExporter",
+)
+
 exporter_factory.register_lazy_exporter(
     name=ExportFormat.dbml, module_path="datacontract.export.dbml_converter", class_name="DbmlExporter"
 )
@@ -143,3 +149,9 @@ exporter_factory.register_lazy_exporter(
 exporter_factory.register_lazy_exporter(
     name=ExportFormat.spark, module_path="datacontract.export.spark_converter", class_name="SparkExporter"
 )
+
+exporter_factory.register_lazy_exporter(
+    name=ExportFormat.sqlalchemy,
+    module_path="datacontract.export.sqlalchemy_converter",
+    class_name="SQLAlchemyExporter",
+)
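For context, a lazy exporter registration stores only the dotted module path and class name; the converter module is imported the first time that format is requested, which keeps optional dependencies (such as sqlalchemy) out of CLI startup. A minimal sketch of the mechanism, not the package's exact code:

    import importlib

    _registry: dict[str, tuple[str, str]] = {}

    def register_lazy_exporter(name: str, module_path: str, class_name: str) -> None:
        # Registration is cheap: nothing is imported yet.
        _registry[name] = (module_path, class_name)

    def load_exporter(name: str):
        # Import is deferred to the first use of the export format.
        module_path, class_name = _registry[name]
        return getattr(importlib.import_module(module_path), class_name)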
datacontract/export/jsonschema_converter.py

@@ -36,7 +36,19 @@ def to_property(field: Field) -> dict:
     property = {}
     json_type, json_format = convert_type_format(field.type, field.format)
     if json_type is not None:
-        property["type"] = json_type
+        if not field.required:
+            """
+            From: https://json-schema.org/understanding-json-schema/reference/type
+            The type keyword may either be a string or an array:
+
+            If it's a string, it is the name of one of the basic types above.
+            If it is an array, it must be an array of strings, where each string
+            is the name of one of the basic types, and each element is unique.
+            In this case, the JSON snippet is valid if it matches any of the given types.
+            """
+            property["type"] = [json_type, "null"]
+        else:
+            property["type"] = json_type
     if json_format is not None:
         property["format"] = json_format
     if field.unique:
@@ -50,7 +62,6 @@ def to_property(field: Field) -> dict:
         property["required"] = to_required(field.fields)
     if json_type == "array":
         property["items"] = to_property(field.items)
-
     if field.pattern:
         property["pattern"] = field.pattern
     if field.enum:
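The practical effect of the type-union change above: optional fields now validate JSON null values. A hedged illustration, assuming the Field defaults of the spec model:

    from datacontract.export.jsonschema_converter import to_property
    from datacontract.model.data_contract_specification import Field

    to_property(Field(type="string", required=False))  # {'type': ['string', 'null']}
    to_property(Field(type="string", required=True))   # {'type': 'string'}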
datacontract/export/spark_converter.py

@@ -123,10 +123,12 @@ def to_data_type(field: Field) -> types.DataType:
         return types.ArrayType(to_data_type(field.items))
     if field_type in ["object", "record", "struct"]:
         return types.StructType(to_struct_type(field.fields))
+    if field_type == "map":
+        return types.MapType(to_data_type(field.keys), to_data_type(field.values))
     if field_type in ["string", "varchar", "text"]:
         return types.StringType()
     if field_type in ["number", "decimal", "numeric"]:
-        return types.DecimalType()
+        return types.DecimalType(precision=field.precision, scale=field.scale)
     if field_type in ["integer", "int"]:
         return types.IntegerType()
     if field_type == "long":
@@ -204,6 +206,8 @@ def print_schema(dtype: types.DataType) -> str:
         return format_struct_type(dtype)
     elif isinstance(dtype, types.ArrayType):
         return f"ArrayType({print_schema(dtype.elementType)})"
+    elif isinstance(dtype, types.MapType):
+        return f"MapType(\n{indent(print_schema(dtype.keyType), 1)}, {print_schema(dtype.valueType)})"
     elif isinstance(dtype, types.DecimalType):
         return f"DecimalType({dtype.precision}, {dtype.scale})"
     else:
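A hedged illustration of the new map branch, assuming the Field model exposes the keys/values attributes the converter reads:

    from datacontract.export.spark_converter import to_data_type
    from datacontract.model.data_contract_specification import Field

    field = Field(type="map", keys=Field(type="string"), values=Field(type="integer"))
    to_data_type(field)  # a Spark MapType(StringType(), IntegerType())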
datacontract/export/sql_type_converter.py

@@ -149,37 +149,72 @@ def convert_to_duckdb(field: Field) -> None | str:
 
 
 def convert_to_duckdb(field: Field) -> None | str:
- [... 2 deleted lines not rendered in this diff view ...]
+    """
+    Convert a data contract field to the corresponding DuckDB SQL type.
+
+    Parameters:
+        field (Field): The data contract field to convert.
+
+    Returns:
+        str: The corresponding DuckDB SQL type.
+    """
+    # Check
+    if field is None or field.type is None:
         return None
- [... 11 deleted lines not rendered in this diff view ...]
+
+    # Get
+    type_lower = field.type.lower()
+
+    # Prepare
+    type_mapping = {
+        "varchar": "VARCHAR",
+        "string": "VARCHAR",
+        "text": "VARCHAR",
+        "binary": "BLOB",
+        "bytes": "BLOB",
+        "blob": "BLOB",
+        "boolean": "BOOLEAN",
+        "float": "FLOAT",
+        "double": "DOUBLE",
+        "int": "INTEGER",
+        "int32": "INTEGER",
+        "integer": "INTEGER",
+        "int64": "BIGINT",
+        "long": "BIGINT",
+        "bigint": "BIGINT",
+        "date": "DATE",
+        "time": "TIME",
+        "timestamp": "TIMESTAMP WITH TIME ZONE",
+        "timestamp_tz": "TIMESTAMP WITH TIME ZONE",
+        "timestamp_ntz": "DATETIME",
+    }
+
+    # Convert simple mappings
+    if type_lower in type_mapping:
+        return type_mapping[type_lower]
+
+    # convert decimal numbers with precision and scale
+    if type_lower == "decimal" or type_lower == "number" or type_lower == "numeric":
         return f"DECIMAL({field.precision},{field.scale})"
- [... 16 deleted lines not rendered in this diff view ...]
+
+    # Check list and map
+    if type_lower == "list" or type_lower == "array":
+        item_type = convert_to_duckdb(field.items)
+        return f"{item_type}[]"
+    if type_lower == "map":
+        key_type = convert_to_duckdb(field.keys)
+        value_type = convert_to_duckdb(field.values)
+        return f"MAP({key_type}, {value_type})"
+    if type_lower == "struct" or type_lower == "object" or type_lower == "record":
+        structure_field = "STRUCT("
+        field_strings = []
+        for fieldKey, fieldValue in field.fields.items():
+            field_strings.append(f"{fieldKey} {convert_to_duckdb(fieldValue)}")
+        structure_field += ", ".join(field_strings)
+        structure_field += ")"
+        return structure_field
+
+    # Return none
+    return None
 
 
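Worked outputs for the new nested-type branches, read straight off the mapping above (Field construction assumed):

    convert_to_duckdb(Field(type="array", items=Field(type="string")))
    # -> "VARCHAR[]"
    convert_to_duckdb(Field(type="map", keys=Field(type="string"), values=Field(type="int")))
    # -> "MAP(VARCHAR, INTEGER)"
    convert_to_duckdb(Field(type="struct", fields={"id": Field(type="long")}))
    # -> "STRUCT(id BIGINT)"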
@@ -234,16 +269,7 @@ def convert_type_to_sqlserver(field: Field) -> None | str:
 
 def convert_type_to_bigquery(field: Field) -> None | str:
     """Convert from supported datacontract types to equivalent bigquery types"""
-    field_type = field.type
-    if not field_type:
-        return None
-
-    # If provided sql-server config type, prefer it over default mapping
-    if bigquery_type := get_type_config(field, "bigqueryType"):
-        return bigquery_type
-
-    field_type = field_type.lower()
-    return map_type_to_bigquery(field_type, field.title)
+    return map_type_to_bigquery(field)
 
 
 def get_type_config(field: Field, config_attr: str) -> dict[str, str] | None:
datacontract/export/sqlalchemy_converter.py

@@ -0,0 +1,169 @@
+import ast
+import typing
+
+import datacontract.model.data_contract_specification as spec
+from datacontract.export.exporter import Exporter
+from datacontract.export.exporter import _determine_sql_server_type
+
+
+class SQLAlchemyExporter(Exporter):
+    def export(
+        self, data_contract: spec.DataContractSpecification, model, server, sql_server_type, export_args
+    ) -> dict:
+        sql_server_type = _determine_sql_server_type(data_contract, sql_server_type, server)
+        return to_sqlalchemy_model_str(data_contract, sql_server_type, server)
+
+
+DECLARATIVE_BASE = "Base"
+
+
+def to_sqlalchemy_model_str(contract: spec.DataContractSpecification, sql_server_type: str = "", server=None) -> str:
+    server_obj = contract.servers.get(server)
+    classdefs = [
+        generate_model_class(model_name, model, server_obj, sql_server_type)
+        for (model_name, model) in contract.models.items()
+    ]
+    documentation = (
+        [ast.Expr(ast.Constant(contract.info.description))] if (contract.info and contract.info.description) else []
+    )
+
+    declarative_base = ast.ClassDef(
+        name=DECLARATIVE_BASE,
+        bases=[ast.Name(id="DeclarativeBase", ctx=ast.Load())],
+        body=[ast.Pass()],
+        keywords=[],
+        decorator_list=[],
+    )
+
+    databricks_timestamp = ast.ImportFrom(
+        module="databricks.sqlalchemy", names=[ast.alias("TIMESTAMP"), ast.alias("TIMESTAMP_NTZ")]
+    )
+    timestamp = ast.ImportFrom(module="sqlalchemy", names=[ast.alias(name="TIMESTAMP")])
+    result = ast.Module(
+        body=[
+            ast.ImportFrom(module="sqlalchemy.orm", names=[ast.alias(name="DeclarativeBase")]),
+            ast.ImportFrom(
+                module="sqlalchemy",
+                names=[
+                    ast.alias("Column"),
+                    ast.alias("Date"),
+                    ast.alias("Integer"),
+                    ast.alias("Numeric"),
+                    ast.alias("String"),
+                    ast.alias("Text"),
+                    ast.alias("VARCHAR"),
+                    ast.alias("BigInteger"),
+                    ast.alias("Float"),
+                    ast.alias("Double"),
+                    ast.alias("Boolean"),
+                    ast.alias("Date"),
+                    ast.alias("ARRAY"),
+                    ast.alias("LargeBinary"),
+                ],
+            ),
+            databricks_timestamp if sql_server_type == "databricks" else timestamp,
+            *documentation,
+            declarative_base,
+            *classdefs,
+        ],
+        type_ignores=[],
+    )
+    return ast.unparse(result)
+
+
+def Call(name, *args, **kwargs) -> ast.Call:
+    return ast.Call(
+        ast.Name(name),
+        args=[v for v in args],
+        keywords=[ast.keyword(arg=f"{k}", value=ast.Constant(v)) for (k, v) in kwargs.items()],
+    )
+
+
+def Column(predicate, **kwargs) -> ast.Call:
+    return Call("Column", predicate, **kwargs)
+
+
+def sqlalchemy_primitive(field: spec.Field):
+    sqlalchemy_name = {
+        "string": Call("String", ast.Constant(field.maxLength)),
+        "text": Call("Text", ast.Constant(field.maxLength)),
+        "varchar": Call("VARCHAR", ast.Constant(field.maxLength)),
+        "number": Call("Numeric", ast.Constant(field.precision), ast.Constant(field.scale)),
+        "decimal": Call("Numeric", ast.Constant(field.precision), ast.Constant(field.scale)),
+        "numeric": Call("Numeric", ast.Constant(field.precision), ast.Constant(field.scale)),
+        "int": ast.Name("Integer"),
+        "integer": ast.Name("Integer"),
+        "long": ast.Name("BigInteger"),
+        "bigint": ast.Name("BigInteger"),
+        "float": ast.Name("Float"),
+        "double": ast.Name("Double"),
+        "boolean": ast.Name("Boolean"),
+        "timestamp": ast.Name("TIMESTAMP"),
+        "timestamp_tz": Call("TIMESTAMP", ast.Constant(True)),
+        "timestamp_ntz": ast.Name("TIMESTAMP_NTZ"),
+        "date": ast.Name("Date"),
+        "bytes": Call("LargeBinary", ast.Constant(field.maxLength)),
+    }
+    return sqlalchemy_name.get(field.type)
+
+
+def constant_field_value(field_name: str, field: spec.Field) -> tuple[ast.Call, typing.Optional[ast.ClassDef]]:
+    new_type = sqlalchemy_primitive(field)
+    match field.type:
+        case "array":
+            new_type = Call("ARRAY", sqlalchemy_primitive(field.items))
+    if new_type is None:
+        raise RuntimeError(f"Unsupported field type {field.type}.")
+
+    return Column(new_type, nullable=not field.required, comment=field.description, primary_key=field.primary), None
+
+
+def column_assignment(field_name: str, field: spec.Field) -> tuple[ast.Call, typing.Optional[ast.ClassDef]]:
+    return constant_field_value(field_name, field)
+
+
+def is_simple_field(field: spec.Field) -> bool:
+    return field.type not in set(["object", "record", "struct"])
+
+
+def field_definitions(fields: dict[str, spec.Field]) -> tuple[list[ast.Expr], list[ast.ClassDef]]:
+    annotations: list[ast.Expr] = []
+    classes: list[typing.Any] = []
+    for field_name, field in fields.items():
+        (ann, new_class) = column_assignment(field_name, field)
+        annotations.append(ast.Assign(targets=[ast.Name(id=field_name, ctx=ast.Store())], value=ann, lineno=0))
+    return (annotations, classes)
+
+
+def generate_model_class(
+    name: str, model_definition: spec.Model, server=None, sql_server_type: str = ""
+) -> ast.ClassDef:
+    (field_assignments, nested_classes) = field_definitions(model_definition.fields)
+    documentation = [ast.Expr(ast.Constant(model_definition.description))] if model_definition.description else []
+
+    schema = None if server is None else server.schema_
+    table_name = ast.Constant(name)
+    if sql_server_type == "databricks":
+        table_name = ast.Constant(name.lower())
+
+    result = ast.ClassDef(
+        name=name.capitalize(),
+        bases=[ast.Name(id=DECLARATIVE_BASE, ctx=ast.Load())],
+        body=[
+            *documentation,
+            ast.Assign(targets=[ast.Name("__tablename__")], value=table_name, lineno=0),
+            ast.Assign(
+                targets=[ast.Name("__table_args__")],
+                value=ast.Dict(
+                    keys=[ast.Constant("comment"), ast.Constant("schema")],
+                    values=[ast.Constant(model_definition.description), ast.Constant(schema)],
+                ),
+                lineno=0,
+            ),
+            *nested_classes,
+            *field_assignments,
+        ],
+        keywords=[],
+        decorator_list=[],
+    )
+    return result
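Since the exporter assembles an ast.Module and unparses it, the Call/Column helpers map one-to-one onto generated source. An illustrative round trip:

    import ast

    ast.unparse(Column(Call("String", ast.Constant(100)), nullable=True))
    # -> "Column(String(100), nullable=True)"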
datacontract/imports/avro_importer.py

@@ -133,6 +133,7 @@ def import_record_fields(record_fields: List[avro.schema.Field]) -> Dict[str, Field]:
     elif field.type.type == "enum":
         imported_field.type = "string"
         imported_field.enum = field.type.symbols
+        imported_field.title = field.type.name
         if not imported_field.config:
             imported_field.config = {}
         imported_field.config["avroType"] = "enum"
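Illustratively, for a hypothetical Avro enum {"type": "enum", "name": "Suit", "symbols": ["HEARTS", "SPADES"]}, the imported field would now look roughly like:

    # Field(type="string", enum=["HEARTS", "SPADES"], title="Suit",
    #       config={"avroType": "enum"})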
datacontract/imports/bigquery_importer.py

@@ -10,13 +10,13 @@ from datacontract.model.exceptions import DataContractException
 class BigQueryImporter(Importer):
     def import_source(
         self, data_contract_specification: DataContractSpecification, source: str, import_args: dict
-    ) ->
+    ) -> DataContractSpecification:
         if source is not None:
             data_contract_specification = import_bigquery_from_json(data_contract_specification, source)
         else:
             data_contract_specification = import_bigquery_from_api(
                 data_contract_specification,
-                import_args.get("
+                import_args.get("bigquery_table"),
                 import_args.get("bigquery_project"),
                 import_args.get("bigquery_dataset"),
             )
datacontract/imports/dbml_importer.py

@@ -0,0 +1,112 @@
+from pydbml import PyDBML, Database
+from typing import List
+
+from pyparsing import ParseException
+
+from datacontract.imports.importer import Importer
+from datacontract.imports.sql_importer import map_type_from_sql
+from datacontract.model.data_contract_specification import DataContractSpecification, Model, Field
+from datacontract.model.exceptions import DataContractException
+
+
+class DBMLImporter(Importer):
+    def import_source(
+        self, data_contract_specification: DataContractSpecification, source: str, import_args: dict
+    ) -> DataContractSpecification:
+        data_contract_specification = import_dbml_from_source(
+            data_contract_specification,
+            source,
+            import_args.get("dbml_schema"),
+            import_args.get("dbml_table"),
+        )
+        return data_contract_specification
+
+
+def import_dbml_from_source(
+    data_contract_specification: DataContractSpecification,
+    source: str,
+    import_schemas: List[str],
+    import_tables: List[str],
+) -> DataContractSpecification:
+    try:
+        with open(source, "r") as file:
+            dbml_schema = PyDBML(file)
+    except ParseException as e:
+        raise DataContractException(
+            type="schema",
+            name="Parse DBML schema",
+            reason=f"Failed to parse DBML schema from {source}",
+            engine="datacontract",
+            original_exception=e,
+        )
+
+    return convert_dbml(data_contract_specification, dbml_schema, import_schemas, import_tables)
+
+
+def convert_dbml(
+    data_contract_specification: DataContractSpecification,
+    dbml_schema: Database,
+    import_schemas: List[str],
+    import_tables: List[str],
+) -> DataContractSpecification:
+    if dbml_schema.project is not None:
+        data_contract_specification.info.title = dbml_schema.project.name
+
+    if data_contract_specification.models is None:
+        data_contract_specification.models = {}
+
+    for table in dbml_schema.tables:
+        schema_name = table.schema
+        table_name = table.name
+
+        # Skip if import schemas or table names are defined
+        # and the current table doesn't match
+        # if empty no filtering is done
+        if import_schemas and schema_name not in import_schemas:
+            continue
+
+        if import_tables and table_name not in import_tables:
+            continue
+
+        fields = import_table_fields(table, dbml_schema.refs)
+
+        data_contract_specification.models[table_name] = Model(
+            fields=fields, namespace=schema_name, description=table.note.text
+        )
+
+    return data_contract_specification
+
+
+def import_table_fields(table, references) -> dict[str, Field]:
+    imported_fields = {}
+    for field in table.columns:
+        field_name = field.name
+        imported_fields[field_name] = Field()
+        imported_fields[field_name].required = field.not_null
+        imported_fields[field_name].description = field.note.text
+        imported_fields[field_name].primary = field.pk
+        imported_fields[field_name].unique = field.unique
+        # This is an assumption, that these might be valid SQL Types, since
+        # DBML doesn't really enforce anything other than 'no spaces' in column types
+        imported_fields[field_name].type = map_type_from_sql(field.type)
+
+        ref = get_reference(field, references)
+        if ref is not None:
+            imported_fields[field_name].references = ref
+
+    return imported_fields
+
+
+def get_reference(field, references):
+    result = None
+    for ref in references:
+        ref_table_name = ref.col1[0].table.name
+        ref_col_name = ref.col1[0].name
+        field_table_name = field.table.name
+        field_name = field.name
+
+        if ref_table_name == field_table_name and ref_col_name == field_name:
+            result = f"{ref.col2[0].table.name}.{ref.col2[0].name}"
+            return result
+
+    return result