datacontract-cli 0.10.10__py3-none-any.whl → 0.10.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datacontract-cli might be problematic. Click here for more details.
- datacontract/cli.py +19 -3
- datacontract/data_contract.py +17 -17
- datacontract/engines/fastjsonschema/check_jsonschema.py +15 -1
- datacontract/engines/fastjsonschema/s3/s3_read_files.py +2 -0
- datacontract/engines/soda/check_soda_execute.py +2 -8
- datacontract/engines/soda/connections/duckdb.py +23 -20
- datacontract/engines/soda/connections/kafka.py +81 -23
- datacontract/engines/soda/connections/snowflake.py +8 -5
- datacontract/export/avro_converter.py +12 -2
- datacontract/export/dbml_converter.py +42 -19
- datacontract/export/exporter.py +2 -1
- datacontract/export/exporter_factory.py +6 -0
- datacontract/export/jsonschema_converter.py +1 -4
- datacontract/export/spark_converter.py +4 -0
- datacontract/export/sql_type_converter.py +64 -29
- datacontract/export/sqlalchemy_converter.py +169 -0
- datacontract/imports/avro_importer.py +1 -0
- datacontract/imports/bigquery_importer.py +2 -2
- datacontract/imports/dbml_importer.py +112 -0
- datacontract/imports/dbt_importer.py +67 -91
- datacontract/imports/glue_importer.py +64 -54
- datacontract/imports/importer.py +3 -2
- datacontract/imports/importer_factory.py +5 -0
- datacontract/imports/jsonschema_importer.py +106 -120
- datacontract/imports/odcs_importer.py +1 -1
- datacontract/imports/spark_importer.py +29 -10
- datacontract/imports/sql_importer.py +5 -1
- datacontract/imports/unity_importer.py +1 -1
- datacontract/integration/{publish_datamesh_manager.py → datamesh_manager.py} +33 -5
- datacontract/integration/{publish_opentelemetry.py → opentelemetry.py} +1 -1
- datacontract/model/data_contract_specification.py +6 -2
- datacontract/templates/partials/model_field.html +10 -2
- {datacontract_cli-0.10.10.dist-info → datacontract_cli-0.10.12.dist-info}/METADATA +283 -113
- {datacontract_cli-0.10.10.dist-info → datacontract_cli-0.10.12.dist-info}/RECORD +38 -37
- {datacontract_cli-0.10.10.dist-info → datacontract_cli-0.10.12.dist-info}/WHEEL +1 -1
- datacontract/publish/publish.py +0 -32
- {datacontract_cli-0.10.10.dist-info → datacontract_cli-0.10.12.dist-info}/LICENSE +0 -0
- {datacontract_cli-0.10.10.dist-info → datacontract_cli-0.10.12.dist-info}/entry_points.txt +0 -0
- {datacontract_cli-0.10.10.dist-info → datacontract_cli-0.10.12.dist-info}/top_level.txt +0 -0
|
@@ -149,37 +149,72 @@ def convert_to_databricks(field: Field) -> None | str:
|
|
|
149
149
|
|
|
150
150
|
|
|
151
151
|
def convert_to_duckdb(field: Field) -> None | str:
    """
    Convert a data contract field to the corresponding DuckDB SQL type.

    Parameters:
        field (Field): The data contract field to convert.

    Returns:
        str: The corresponding DuckDB SQL type, or None if the field or its type is
        missing or the type is not supported.
    """
    # Check
    if field is None or field.type is None:
        return None

    # Get
    type_lower = field.type.lower()

    # Prepare
    type_mapping = {
        "varchar": "VARCHAR",
        "string": "VARCHAR",
        "text": "VARCHAR",
        "binary": "BLOB",
        "bytes": "BLOB",
        "blob": "BLOB",
        "boolean": "BOOLEAN",
        "float": "FLOAT",
        "double": "DOUBLE",
        "int": "INTEGER",
        "int32": "INTEGER",
        "integer": "INTEGER",
        "int64": "BIGINT",
        "long": "BIGINT",
        "bigint": "BIGINT",
        "date": "DATE",
        "time": "TIME",
        "timestamp": "TIMESTAMP WITH TIME ZONE",
        "timestamp_tz": "TIMESTAMP WITH TIME ZONE",
        "timestamp_ntz": "DATETIME",
    }

    # Convert simple mappings
    if type_lower in type_mapping:
        return type_mapping[type_lower]

    # Convert decimal numbers with precision and scale. Only emit the parameters that
    # are actually set: interpolating None would yield "DECIMAL(None,None)", which is
    # not valid SQL. A bare DECIMAL falls back to DuckDB's default width/scale.
    if type_lower in ("decimal", "number", "numeric"):
        if field.precision is not None and field.scale is not None:
            return f"DECIMAL({field.precision},{field.scale})"
        if field.precision is not None:
            return f"DECIMAL({field.precision})"
        return "DECIMAL"

    # Check list and map
    if type_lower in ("list", "array"):
        item_type = convert_to_duckdb(field.items)
        return f"{item_type}[]"
    if type_lower == "map":
        key_type = convert_to_duckdb(field.keys)
        value_type = convert_to_duckdb(field.values)
        return f"MAP({key_type}, {value_type})"
    if type_lower in ("struct", "object", "record"):
        # Each nested field becomes "<name> <duckdb type>" inside STRUCT(...).
        members = ", ".join(
            f"{nested_name} {convert_to_duckdb(nested_field)}" for nested_name, nested_field in field.fields.items()
        )
        return f"STRUCT({members})"

    # Return none
    return None
|
|
184
219
|
|
|
185
220
|
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import datacontract.model.data_contract_specification as spec
|
|
5
|
+
from datacontract.export.exporter import Exporter
|
|
6
|
+
from datacontract.export.exporter import _determine_sql_server_type
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SQLAlchemyExporter(Exporter):
    """Exporter that renders a data contract as the source code of SQLAlchemy declarative models."""

    def export(
        self, data_contract: spec.DataContractSpecification, model, server, sql_server_type, export_args
    ) -> str:
        """
        Render every model of ``data_contract`` as a SQLAlchemy model class.

        ``model`` and ``export_args`` are part of the Exporter interface but are not used here;
        all models of the contract are exported.

        Parameters:
            data_contract: The contract to export.
            server: Key of the server whose schema is used for the tables, or None.
            sql_server_type: Requested SQL dialect; resolved via ``_determine_sql_server_type``.

        Returns:
            The generated Python module source as a string.
        """
        sql_server_type = _determine_sql_server_type(data_contract, sql_server_type, server)
        return to_sqlalchemy_model_str(data_contract, sql_server_type, server)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Name of the generated declarative base class (a subclass of SQLAlchemy's DeclarativeBase);
# every generated model class uses it as its base class.
DECLARATIVE_BASE = "Base"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def to_sqlalchemy_model_str(contract: spec.DataContractSpecification, sql_server_type: str = "", server=None) -> str:
    """
    Generate the source code of a Python module with SQLAlchemy declarative models.

    The module contains the SQLAlchemy imports, the contract description (if any) as module
    docstring, the declarative base class, and one class per contract model.

    Parameters:
        contract: The data contract to render.
        sql_server_type: Target SQL dialect. For ``"databricks"`` the TIMESTAMP types
            (including TIMESTAMP_NTZ) are imported from ``databricks.sqlalchemy``.
        server: Key of the server in ``contract.servers`` passed on to the model classes, or None.

    Returns:
        The module source produced by ``ast.unparse``.
    """
    # A contract without servers or models should still export instead of crashing.
    server_obj = (contract.servers or {}).get(server)
    classdefs = [
        generate_model_class(model_name, model, server_obj, sql_server_type)
        for model_name, model in (contract.models or {}).items()
    ]
    documentation = (
        [ast.Expr(ast.Constant(contract.info.description))] if (contract.info and contract.info.description) else []
    )

    declarative_base = ast.ClassDef(
        name=DECLARATIVE_BASE,
        bases=[ast.Name(id="DeclarativeBase", ctx=ast.Load())],
        body=[ast.Pass()],
        keywords=[],
        decorator_list=[],
    )

    databricks_timestamp = ast.ImportFrom(
        module="databricks.sqlalchemy", names=[ast.alias("TIMESTAMP"), ast.alias("TIMESTAMP_NTZ")]
    )
    timestamp = ast.ImportFrom(module="sqlalchemy", names=[ast.alias(name="TIMESTAMP")])
    # Each SQLAlchemy type is imported exactly once.
    sqlalchemy_type_names = [
        "Column",
        "Date",
        "Integer",
        "Numeric",
        "String",
        "Text",
        "VARCHAR",
        "BigInteger",
        "Float",
        "Double",
        "Boolean",
        "ARRAY",
        "LargeBinary",
    ]
    result = ast.Module(
        body=[
            ast.ImportFrom(module="sqlalchemy.orm", names=[ast.alias(name="DeclarativeBase")]),
            ast.ImportFrom(
                module="sqlalchemy",
                names=[ast.alias(type_name) for type_name in sqlalchemy_type_names],
            ),
            databricks_timestamp if sql_server_type == "databricks" else timestamp,
            *documentation,
            declarative_base,
            *classdefs,
        ],
        type_ignores=[],
    )
    return ast.unparse(result)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def Call(name, *args, **kwargs) -> ast.Call:
    """
    Build an ``ast.Call`` node representing ``name(*args, **kwargs)``.

    Parameters:
        name: Name of the callable.
        *args: Positional arguments, already given as AST expression nodes.
        **kwargs: Keyword arguments given as plain Python values; each is wrapped in ``ast.Constant``.

    Returns:
        The call expression node.
    """
    return ast.Call(
        # An explicit Load context keeps the tree well-formed on every Python version,
        # not just printable by ast.unparse.
        func=ast.Name(id=name, ctx=ast.Load()),
        args=list(args),
        keywords=[ast.keyword(arg=str(key), value=ast.Constant(value)) for key, value in kwargs.items()],
    )
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def Column(predicate, **kwargs) -> ast.Call:
    """Return the AST for ``Column(predicate, **kwargs)``; keyword values are embedded as constants."""
    positional = (predicate,)
    column_call = Call("Column", *positional, **kwargs)
    return column_call
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def sqlalchemy_primitive(field: spec.Field):
    """
    Map a primitive data contract field to the AST expression of its SQLAlchemy type.

    Type names are matched case-insensitively.

    Parameters:
        field: The field to map; may be None (for example the missing ``items`` of an array).

    Returns:
        An AST expression such as ``String(255)`` or ``Integer``, or None if ``field`` is None,
        has no type, or its type has no primitive mapping (for example ``array`` or ``object``).
    """
    # Nested lookups (e.g. ``field.items``) may hand us None; report it as unsupported
    # instead of failing with an AttributeError on ``field.maxLength``.
    if field is None or field.type is None:
        return None

    sqlalchemy_name = {
        "string": Call("String", ast.Constant(field.maxLength)),
        "text": Call("Text", ast.Constant(field.maxLength)),
        "varchar": Call("VARCHAR", ast.Constant(field.maxLength)),
        "number": Call("Numeric", ast.Constant(field.precision), ast.Constant(field.scale)),
        "decimal": Call("Numeric", ast.Constant(field.precision), ast.Constant(field.scale)),
        "numeric": Call("Numeric", ast.Constant(field.precision), ast.Constant(field.scale)),
        "int": ast.Name("Integer"),
        "integer": ast.Name("Integer"),
        "long": ast.Name("BigInteger"),
        "bigint": ast.Name("BigInteger"),
        "float": ast.Name("Float"),
        "double": ast.Name("Double"),
        "boolean": ast.Name("Boolean"),
        "timestamp": ast.Name("TIMESTAMP"),
        # TIMESTAMP(True) -> timezone-aware timestamp.
        "timestamp_tz": Call("TIMESTAMP", ast.Constant(True)),
        # NOTE(review): the generated module only imports TIMESTAMP_NTZ for Databricks targets.
        "timestamp_ntz": ast.Name("TIMESTAMP_NTZ"),
        "date": ast.Name("Date"),
        "bytes": Call("LargeBinary", ast.Constant(field.maxLength)),
    }
    return sqlalchemy_name.get(field.type.lower())
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def constant_field_value(field_name: str, field: spec.Field) -> tuple[ast.Call, typing.Optional[ast.ClassDef]]:
    """
    Build the ``Column(...)`` expression for a single field.

    Parameters:
        field_name: Name of the field; used in error messages.
        field: The field to convert. ``array`` fields become ``ARRAY(<item type>)``.

    Returns:
        A tuple of the ``Column(...)`` call and an optional nested class definition
        (always None in this implementation).

    Raises:
        RuntimeError: If the field type, or an array's item type, is not supported.
    """
    if field.type == "array":
        item_type = None if field.items is None else sqlalchemy_primitive(field.items)
        if item_type is None:
            # Without this check ARRAY(None) would be emitted, leaving an invalid AST
            # that only fails later when the module is unparsed.
            raise RuntimeError(f"Unsupported item type for array field {field_name}.")
        new_type = Call("ARRAY", item_type)
    else:
        new_type = sqlalchemy_primitive(field)

    if new_type is None:
        raise RuntimeError(f"Unsupported field type {field.type}.")

    return Column(new_type, nullable=not field.required, comment=field.description, primary_key=field.primary), None
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def column_assignment(field_name: str, field: spec.Field) -> tuple[ast.Call, typing.Optional[ast.ClassDef]]:
    """Return the column expression and optional nested class for ``field_name`` (delegates to constant_field_value)."""
    column_expr, nested_class = constant_field_value(field_name, field)
    return column_expr, nested_class
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def is_simple_field(field: spec.Field) -> bool:
    """Return True unless the field is a nested structure (type ``object``, ``record`` or ``struct``)."""
    # A constant set literal in a membership test is compiled to a frozenset constant,
    # so no list/set is rebuilt on every call (unlike ``set([...])``).
    return field.type not in {"object", "record", "struct"}
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def field_definitions(fields: dict[str, spec.Field]) -> tuple[list[ast.stmt], list[ast.ClassDef]]:
    """
    Build one ``<field name> = Column(...)`` assignment per field.

    Parameters:
        fields: The model's fields keyed by field name.

    Returns:
        A tuple of (assignment statements, nested class definitions). Any nested class
        returned by ``column_assignment`` is collected in the second list.
    """
    assignments: list[ast.stmt] = []
    classes: list[ast.ClassDef] = []
    for field_name, field in fields.items():
        column_expr, nested_class = column_assignment(field_name, field)
        assignments.append(
            ast.Assign(targets=[ast.Name(id=field_name, ctx=ast.Store())], value=column_expr, lineno=0)
        )
        # Keep nested classes instead of silently discarding them.
        if nested_class is not None:
            classes.append(nested_class)
    return assignments, classes
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def generate_model_class(
    name: str, model_definition: spec.Model, server=None, sql_server_type: str = ""
) -> ast.ClassDef:
    """
    Build the SQLAlchemy declarative class for one data contract model.

    The class inherits from the declarative base and contains the model description as
    docstring, ``__tablename__``, ``__table_args__`` (table comment and schema) and one
    ``Column`` assignment per field.

    Parameters:
        name: Model name; used for the class name and the table name.
        model_definition: The model to convert.
        server: Server object whose ``schema_`` becomes the table schema, or None.
        sql_server_type: Target SQL dialect; for ``"databricks"`` the table name is lower-cased.

    Returns:
        The ``ast.ClassDef`` of the generated model class.
    """
    field_assignments, nested_classes = field_definitions(model_definition.fields)
    documentation = [ast.Expr(ast.Constant(model_definition.description))] if model_definition.description else []

    schema = None if server is None else server.schema_
    table_name = ast.Constant(name.lower() if sql_server_type == "databricks" else name)

    result = ast.ClassDef(
        # NOTE: str.capitalize() also lower-cases the remaining characters ("OrderLine" -> "Orderline").
        name=name.capitalize(),
        bases=[ast.Name(id=DECLARATIVE_BASE, ctx=ast.Load())],
        body=[
            *documentation,
            # Assignment targets carry an explicit Store context so the tree is a well-formed
            # AST (e.g. compilable after ast.fix_missing_locations), not only unparse-able.
            ast.Assign(targets=[ast.Name(id="__tablename__", ctx=ast.Store())], value=table_name, lineno=0),
            ast.Assign(
                targets=[ast.Name(id="__table_args__", ctx=ast.Store())],
                value=ast.Dict(
                    keys=[ast.Constant("comment"), ast.Constant("schema")],
                    values=[ast.Constant(model_definition.description), ast.Constant(schema)],
                ),
                lineno=0,
            ),
            *nested_classes,
            *field_assignments,
        ],
        keywords=[],
        decorator_list=[],
    )
    return result
|
|
@@ -133,6 +133,7 @@ def import_record_fields(record_fields: List[avro.schema.Field]) -> Dict[str, Fi
|
|
|
133
133
|
elif field.type.type == "enum":
|
|
134
134
|
imported_field.type = "string"
|
|
135
135
|
imported_field.enum = field.type.symbols
|
|
136
|
+
imported_field.title = field.type.name
|
|
136
137
|
if not imported_field.config:
|
|
137
138
|
imported_field.config = {}
|
|
138
139
|
imported_field.config["avroType"] = "enum"
|
|
@@ -10,13 +10,13 @@ from datacontract.model.exceptions import DataContractException
|
|
|
10
10
|
class BigQueryImporter(Importer):
|
|
11
11
|
def import_source(
|
|
12
12
|
self, data_contract_specification: DataContractSpecification, source: str, import_args: dict
|
|
13
|
-
) ->
|
|
13
|
+
) -> DataContractSpecification:
|
|
14
14
|
if source is not None:
|
|
15
15
|
data_contract_specification = import_bigquery_from_json(data_contract_specification, source)
|
|
16
16
|
else:
|
|
17
17
|
data_contract_specification = import_bigquery_from_api(
|
|
18
18
|
data_contract_specification,
|
|
19
|
-
import_args.get("
|
|
19
|
+
import_args.get("bigquery_table"),
|
|
20
20
|
import_args.get("bigquery_project"),
|
|
21
21
|
import_args.get("bigquery_dataset"),
|
|
22
22
|
)
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
from pydbml import PyDBML, Database
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
from pyparsing import ParseException
|
|
5
|
+
|
|
6
|
+
from datacontract.imports.importer import Importer
|
|
7
|
+
from datacontract.imports.sql_importer import map_type_from_sql
|
|
8
|
+
from datacontract.model.data_contract_specification import DataContractSpecification, Model, Field
|
|
9
|
+
from datacontract.model.exceptions import DataContractException
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DBMLImporter(Importer):
    """Importer that turns the tables of a DBML file into data contract models."""

    def import_source(
        self, data_contract_specification: DataContractSpecification, source: str, import_args: dict
    ) -> DataContractSpecification:
        """Import the DBML file at ``source``, optionally filtered by the ``dbml_schema`` / ``dbml_table`` args."""
        schema_filter = import_args.get("dbml_schema")
        table_filter = import_args.get("dbml_table")
        return import_dbml_from_source(data_contract_specification, source, schema_filter, table_filter)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def import_dbml_from_source(
    data_contract_specification: DataContractSpecification,
    source: str,
    import_schemas: List[str],
    import_tables: List[str],
) -> DataContractSpecification:
    """
    Parse the DBML file at ``source`` and add its tables to the specification.

    Parameters:
        data_contract_specification: The specification to extend.
        source: Path of the DBML file.
        import_schemas: Schemas to import; empty or None imports all schemas.
        import_tables: Tables to import; empty or None imports all tables.

    Returns:
        The specification returned by ``convert_dbml``.

    Raises:
        DataContractException: If the file cannot be parsed as DBML.
    """
    try:
        # Decode explicitly as UTF-8 so parsing does not depend on the platform's default encoding.
        with open(source, "r", encoding="utf-8") as file:
            dbml_schema = PyDBML(file)
    except ParseException as e:
        raise DataContractException(
            type="schema",
            name="Parse DBML schema",
            reason=f"Failed to parse DBML schema from {source}",
            engine="datacontract",
            original_exception=e,
        ) from e

    return convert_dbml(data_contract_specification, dbml_schema, import_schemas, import_tables)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def convert_dbml(
    data_contract_specification: DataContractSpecification,
    dbml_schema: Database,
    import_schemas: List[str],
    import_tables: List[str],
) -> DataContractSpecification:
    """
    Add every selected DBML table as a model to the specification.

    If the DBML defines a project, its name becomes the contract title. A table is
    selected when its schema is in ``import_schemas`` and its name is in ``import_tables``;
    an empty or None filter selects everything. Models are keyed by table name.
    """
    project = dbml_schema.project
    if project is not None:
        data_contract_specification.info.title = project.name

    if data_contract_specification.models is None:
        data_contract_specification.models = {}

    for table in dbml_schema.tables:
        schema_selected = not import_schemas or table.schema in import_schemas
        table_selected = not import_tables or table.name in import_tables
        if not (schema_selected and table_selected):
            continue

        data_contract_specification.models[table.name] = Model(
            fields=import_table_fields(table, dbml_schema.refs),
            namespace=table.schema,
            description=table.note.text,
        )

    return data_contract_specification
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def import_table_fields(table, references) -> dict[str, Field]:
    """
    Convert the columns of a DBML table into data contract fields keyed by column name.

    Column types are mapped with ``map_type_from_sql``; ``references`` is set on a field
    whenever ``get_reference`` finds a target for its column.
    """
    converted_fields: dict[str, Field] = {}
    for column in table.columns:
        converted = Field()
        converted.required = column.not_null
        converted.description = column.note.text
        converted.primary = column.pk
        converted.unique = column.unique
        # DBML only forbids spaces in column types, so treating them as SQL types is an
        # assumption rather than a guarantee.
        converted.type = map_type_from_sql(column.type)

        target = get_reference(column, references)
        if target is not None:
            converted.references = target

        converted_fields[column.name] = converted
    return converted_fields
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def get_reference(field, references):
    """
    Return the ``"<table>.<column>"`` that ``field`` references, or None.

    Only the first column of each (possibly composite) reference is considered, and
    tables are matched by name only. For one-to-many references (``type == "<"``,
    e.g. ``users.id < posts.user_id``) the referencing column is the right-hand side
    (``col2``), so the roles of ``col1`` and ``col2`` are swapped. For every other
    reference type the left-hand column (``col1``) is treated as the referencing one.

    Parameters:
        field: A DBML column (exposing ``name`` and ``table.name``).
        references: The DBML references to search.

    Returns:
        The referenced ``"<table>.<column>"`` string, or None if no reference matches.
    """
    for ref in references:
        source, target = ref.col1[0], ref.col2[0]
        if getattr(ref, "type", None) == "<":
            source, target = target, source
        if source.table.name == field.table.name and source.name == field.name:
            return f"{target.table.name}.{target.name}"
    return None
|
|
@@ -1,117 +1,93 @@
|
|
|
1
1
|
import json
|
|
2
|
-
|
|
3
|
-
from typing import (
|
|
4
|
-
List,
|
|
5
|
-
)
|
|
2
|
+
from typing import TypedDict
|
|
6
3
|
|
|
7
4
|
from datacontract.imports.importer import Importer
|
|
8
5
|
from datacontract.model.data_contract_specification import DataContractSpecification, Field, Model
|
|
6
|
+
from dbt.artifacts.resources.v1.components import ColumnInfo
|
|
7
|
+
from dbt.contracts.graph.manifest import Manifest
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DBTImportArgs(TypedDict, total=False):
    """
    A dictionary representing arguments for importing DBT models.

    Makes the DBT Importer more customizable by allowing for flexible filtering
    of models and their properties, through wrapping or extending.

    Attributes:
        dbt_nodes: Names of the nodes to import into the contract. All nodes are
            imported when the list is empty or the key is absent.
        resource_types: Only nodes whose resource type is listed are kept while
            importing. Defaults to ``["model"]``.
    """

    dbt_nodes: list[str]
    resource_types: list[str]
|
|
9
23
|
|
|
10
24
|
|
|
11
25
|
class DbtManifestImporter(Importer):
    """Importer that reads a dbt manifest file and converts its nodes into data contract models."""

    def import_source(
        self,
        data_contract_specification: DataContractSpecification,
        source: str,
        import_args: DBTImportArgs,
    ) -> DataContractSpecification:
        """Import the dbt manifest at ``source``; ``import_args`` may filter by node names and resource types."""
        parsed_manifest = read_dbt_manifest(manifest_path=source)
        node_filter = import_args.get("dbt_nodes", [])
        type_filter = import_args.get("resource_types", ["model"])
        return import_dbt_manifest(
            data_contract_specification=data_contract_specification,
            manifest=parsed_manifest,
            dbt_nodes=node_filter,
            resource_types=type_filter,
        )
|
|
19
39
|
|
|
20
40
|
|
|
21
|
-
def
|
|
22
|
-
|
|
23
|
-
):
|
|
24
|
-
|
|
25
|
-
|
|
41
|
+
def read_dbt_manifest(manifest_path: str) -> Manifest:
    """Load the JSON file at ``manifest_path`` and parse it into a dbt ``Manifest``."""
    with open(file=manifest_path, mode="r", encoding="utf-8") as manifest_file:
        raw_manifest: dict = json.load(manifest_file)
    return Manifest.from_dict(raw_manifest)
|
|
26
46
|
|
|
27
|
-
if data_contract_specification.models is None:
|
|
28
|
-
data_contract_specification.models = {}
|
|
29
47
|
|
|
30
|
-
|
|
31
|
-
|
|
48
|
+
def import_dbt_manifest(
    data_contract_specification: DataContractSpecification,
    manifest: Manifest,
    dbt_nodes: list[str],
    resource_types: list[str],
) -> DataContractSpecification:
    """
    Extract all relevant information from the manifest and put it in a data contract specification.

    The manifest's project name and dbt version become ``info.title`` and ``info.dbt_version``.
    Every node whose resource type is in ``resource_types`` (and, when ``dbt_nodes`` is
    non-empty, whose name is in ``dbt_nodes``) is added as a model keyed by the node name,
    replacing any existing model with the same name.

    Parameters:
        data_contract_specification: The specification to fill.
        manifest: The parsed dbt manifest.
        dbt_nodes: Node names to import; empty means all nodes of the selected types.
        resource_types: Resource types of the nodes to import.

    Returns:
        The same specification object, updated in place.
    """
    data_contract_specification.info.title = manifest.metadata.project_name
    data_contract_specification.info.dbt_version = manifest.metadata.dbt_version

    data_contract_specification.models = data_contract_specification.models or {}
    for model_contents in manifest.nodes.values():
        # Only process nodes of the requested resource types.
        if model_contents.resource_type not in resource_types:
            continue

        # dbt_nodes restricts the import to specific node names.
        # If dbt_nodes is empty, all nodes of the requested types are used.
        if dbt_nodes and model_contents.name not in dbt_nodes:
            continue

        dc_model = Model(
            description=model_contents.description,
            tags=model_contents.tags,
            fields=create_fields(columns=model_contents.columns),
        )

        data_contract_specification.models[model_contents.name] = dc_model

    return data_contract_specification
|
|
43
81
|
|
|
44
82
|
|
|
45
|
-
def create_fields(columns:
|
|
46
|
-
fields = {
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
83
|
+
def create_fields(columns: dict[str, ColumnInfo]) -> dict[str, Field]:
    """Convert dbt column infos into data contract fields keyed by column name; a missing data type becomes ""."""
    fields: dict[str, Field] = {}
    for column in columns.values():
        fields[column.name] = Field(
            description=column.description,
            type=column.data_type or "",
            tags=column.tags,
        )
    return fields
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
def read_dbt_manifest(manifest_path: str):
|
|
57
|
-
with open(manifest_path, "r", encoding="utf-8") as f:
|
|
58
|
-
manifest = json.load(f)
|
|
59
|
-
return {"info": manifest.get("metadata"), "models": create_manifest_models(manifest)}
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
def create_manifest_models(manifest: dict) -> List:
|
|
63
|
-
models = []
|
|
64
|
-
nodes = manifest.get("nodes")
|
|
65
|
-
|
|
66
|
-
for node in nodes.values():
|
|
67
|
-
if node["resource_type"] != "model":
|
|
68
|
-
continue
|
|
69
|
-
|
|
70
|
-
models.append(DbtModel(node))
|
|
71
|
-
return models
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
class DbtColumn:
|
|
75
|
-
name: str
|
|
76
|
-
description: str
|
|
77
|
-
data_type: str
|
|
78
|
-
meta: dict
|
|
79
|
-
tags: List
|
|
80
|
-
|
|
81
|
-
def __init__(self, node_column: dict):
|
|
82
|
-
self.name = node_column.get("name")
|
|
83
|
-
self.description = node_column.get("description")
|
|
84
|
-
self.data_type = node_column.get("data_type")
|
|
85
|
-
self.meta = node_column.get("meta", {})
|
|
86
|
-
self.tags = node_column.get("tags", [])
|
|
87
|
-
|
|
88
|
-
def __repr__(self) -> str:
|
|
89
|
-
return self.name
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
class DbtModel:
|
|
93
|
-
name: str
|
|
94
|
-
database: str
|
|
95
|
-
schema: str
|
|
96
|
-
description: str
|
|
97
|
-
unique_id: str
|
|
98
|
-
tags: List
|
|
99
|
-
|
|
100
|
-
def __init__(self, node: dict):
|
|
101
|
-
self.name = node.get("name")
|
|
102
|
-
self.database = node.get("database")
|
|
103
|
-
self.schema = node.get("schema")
|
|
104
|
-
self.description = node.get("description")
|
|
105
|
-
self.display_name = node.get("display_name")
|
|
106
|
-
self.unique_id = node.get("unique_id")
|
|
107
|
-
self.columns = []
|
|
108
|
-
self.tags = node.get("tags")
|
|
109
|
-
if node.get("columns"):
|
|
110
|
-
self.add_columns(node.get("columns").values())
|
|
111
|
-
|
|
112
|
-
def add_columns(self, model_columns: List):
|
|
113
|
-
for column in model_columns:
|
|
114
|
-
self.columns.append(DbtColumn(column))
|
|
115
|
-
|
|
116
|
-
def __repr__(self) -> str:
|
|
117
|
-
return self.name
|