datacontract-cli 0.10.9__py3-none-any.whl → 0.10.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datacontract-cli might be problematic.

Files changed (32)
  1. datacontract/cli.py +7 -0
  2. datacontract/data_contract.py +16 -9
  3. datacontract/engines/fastjsonschema/check_jsonschema.py +4 -1
  4. datacontract/engines/soda/check_soda_execute.py +5 -2
  5. datacontract/engines/soda/connections/duckdb.py +20 -12
  6. datacontract/engines/soda/connections/snowflake.py +8 -5
  7. datacontract/export/avro_converter.py +1 -1
  8. datacontract/export/dbml_converter.py +41 -19
  9. datacontract/export/exporter.py +1 -1
  10. datacontract/export/jsonschema_converter.py +1 -4
  11. datacontract/export/sodacl_converter.py +1 -1
  12. datacontract/imports/avro_importer.py +142 -8
  13. datacontract/imports/dbt_importer.py +117 -0
  14. datacontract/imports/glue_importer.py +9 -3
  15. datacontract/imports/importer.py +7 -2
  16. datacontract/imports/importer_factory.py +24 -6
  17. datacontract/imports/jsonschema_importer.py +106 -117
  18. datacontract/imports/spark_importer.py +134 -0
  19. datacontract/imports/sql_importer.py +4 -0
  20. datacontract/integration/publish_datamesh_manager.py +10 -5
  21. datacontract/lint/resolve.py +72 -27
  22. datacontract/lint/schema.py +24 -4
  23. datacontract/model/data_contract_specification.py +3 -0
  24. datacontract/templates/datacontract.html +1 -1
  25. datacontract/templates/index.html +1 -1
  26. datacontract/templates/partials/model_field.html +10 -2
  27. {datacontract_cli-0.10.9.dist-info → datacontract_cli-0.10.11.dist-info}/METADATA +300 -192
  28. {datacontract_cli-0.10.9.dist-info → datacontract_cli-0.10.11.dist-info}/RECORD +32 -30
  29. {datacontract_cli-0.10.9.dist-info → datacontract_cli-0.10.11.dist-info}/WHEEL +1 -1
  30. {datacontract_cli-0.10.9.dist-info → datacontract_cli-0.10.11.dist-info}/LICENSE +0 -0
  31. {datacontract_cli-0.10.9.dist-info → datacontract_cli-0.10.11.dist-info}/entry_points.txt +0 -0
  32. {datacontract_cli-0.10.9.dist-info → datacontract_cli-0.10.11.dist-info}/top_level.txt +0 -0
datacontract/cli.py CHANGED
@@ -226,6 +226,12 @@ def import_(
     unity_table_full_name: Annotated[
         Optional[str], typer.Option(help="Full name of a table in the unity catalog")
     ] = None,
+    dbt_model: Annotated[
+        Optional[List[str]],
+        typer.Option(
+            help="List of models names to import from the dbt manifest file (repeat for multiple models names, leave empty for all models in the dataset)."
+        ),
+    ] = None,
 ):
     """
     Create a data contract from the given source location. Prints to stdout.
@@ -238,6 +244,7 @@ def import_(
         bigquery_project=bigquery_project,
         bigquery_dataset=bigquery_dataset,
         unity_table_full_name=unity_table_full_name,
+        dbt_model=dbt_model,
     )
     console.print(result.to_yaml())

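The dbt_model option is typed Optional[List[str]], which Typer exposes as a repeatable --dbt-model flag. A minimal, standalone sketch of that pattern (hypothetical command name; Python 3.9+ and a recent Typer assumed, not the project's actual CLI wiring):

from typing import Annotated, List, Optional

import typer

app = typer.Typer()


@app.command()
def import_models(
    dbt_model: Annotated[
        Optional[List[str]],
        typer.Option(help="Repeat for multiple model names; omit to import all models."),
    ] = None,
):
    # Typer collects every occurrence of --dbt-model into one list, e.g.
    #   import-models --dbt-model orders --dbt-model customers  ->  ["orders", "customers"]
    typer.echo(f"models: {dbt_model}")


if __name__ == "__main__":
    app()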

datacontract/data_contract.py CHANGED
@@ -4,7 +4,9 @@ import tempfile
 import typing

 import yaml
-from pyspark.sql import SparkSession
+
+if typing.TYPE_CHECKING:
+    from pyspark.sql import SparkSession

 from datacontract.breaking.breaking import models_breaking_changes, quality_breaking_changes
 from datacontract.engines.datacontract.check_that_datacontract_contains_valid_servers_configuration import (
@@ -43,9 +45,9 @@ class DataContract:
         examples: bool = False,
         publish_url: str = None,
         publish_to_opentelemetry: bool = False,
-        spark: SparkSession = None,
-        inline_definitions: bool = False,
-        inline_quality: bool = False,
+        spark: "SparkSession" = None,
+        inline_definitions: bool = True,
+        inline_quality: bool = True,
     ):
         self._data_contract_file = data_contract_file
         self._data_contract_str = data_contract_str
@@ -85,8 +87,8 @@ class DataContract:
                 self._data_contract_str,
                 self._data_contract,
                 self._schema_location,
-                inline_definitions=True,
-                inline_quality=True,
+                inline_definitions=self._inline_definitions,
+                inline_quality=self._inline_quality,
             )
             run.checks.append(
                 Check(type="lint", result="passed", name="Data contract is syntactically valid", engine="datacontract")
@@ -138,7 +140,12 @@ class DataContract:
         try:
             run.log_info("Testing data contract")
             data_contract = resolve.resolve_data_contract(
-                self._data_contract_file, self._data_contract_str, self._data_contract, self._schema_location
+                self._data_contract_file,
+                self._data_contract_str,
+                self._data_contract,
+                self._schema_location,
+                inline_definitions=self._inline_definitions,
+                inline_quality=self._inline_quality,
             )

             if data_contract.models is None or len(data_contract.models) == 0:
@@ -302,8 +309,8 @@ class DataContract:
            self._data_contract_str,
            self._data_contract,
            schema_location=self._schema_location,
-           inline_definitions=True,
-           inline_quality=True,
+           inline_definitions=self._inline_definitions,
+           inline_quality=self._inline_quality,
        )

        return exporter_factory.create(export_format).export(
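Moving the pyspark import behind typing.TYPE_CHECKING keeps the SparkSession annotation available to type checkers without making pyspark a hard runtime dependency; the annotation then has to be written as a string. A minimal sketch of the pattern (the class here is hypothetical, not the project's code):

import typing

if typing.TYPE_CHECKING:
    # Only evaluated by static type checkers; never imported at runtime.
    from pyspark.sql import SparkSession


class DataQualityRunner:
    # The annotation is a string because SparkSession is not defined at runtime
    # (alternatively, `from __future__ import annotations` would have the same effect).
    def __init__(self, spark: "SparkSession" = None):
        self.spark = spark

    def run(self) -> None:
        if self.spark is None:
            print("no spark session configured, skipping spark checks")
            return
        # pyspark is only touched when a session was actually passed in
        print(self.spark.version)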

datacontract/engines/fastjsonschema/check_jsonschema.py CHANGED
@@ -148,7 +148,10 @@ def check_jsonschema(run: Run, data_contract: DataContractSpecification, server:
         schema = to_jsonschema(model_name, model)
         run.log_info(f"jsonschema: {schema}")

-        validate = fastjsonschema.compile(schema)
+        validate = fastjsonschema.compile(
+            schema,
+            formats={"uuid": r"^[0-9a-fA-F]{8}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{12}$"},
+        )

         # Process files based on server type
         if server.type == "local":
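Passing a formats mapping to fastjsonschema.compile supplies a regex for the uuid format, so fields declared with format: uuid are actually validated instead of being ignored. A small sketch of that behavior (schema and values are made up):

import fastjsonschema

schema = {
    "type": "object",
    "properties": {"id": {"type": "string", "format": "uuid"}},
}

# Values in `formats` may be regex strings; fastjsonschema applies them
# whenever the schema declares the corresponding format.
validate = fastjsonschema.compile(
    schema,
    formats={"uuid": r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"},
)

validate({"id": "123e4567-e89b-12d3-a456-426614174000"})  # passes
# validate({"id": "not-a-uuid"})  # would raise fastjsonschema.JsonSchemaValueException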

datacontract/engines/soda/check_soda_execute.py CHANGED
@@ -1,6 +1,9 @@
 import logging
+import typing
+
+if typing.TYPE_CHECKING:
+    from pyspark.sql import SparkSession

-from pyspark.sql import SparkSession
 from soda.scan import Scan

 from datacontract.engines.soda.connections.bigquery import to_bigquery_soda_configuration
@@ -17,7 +20,7 @@ from datacontract.model.run import Run, Check, Log


 def check_soda_execute(
-    run: Run, data_contract: DataContractSpecification, server: Server, spark: SparkSession, tmp_dir
+    run: Run, data_contract: DataContractSpecification, server: Server, spark: "SparkSession", tmp_dir
 ):
     if data_contract is None:
         run.log_warn("Cannot run engine soda-core, as data contract is invalid")

datacontract/engines/soda/connections/duckdb.py CHANGED
@@ -49,20 +49,28 @@ def get_duckdb_connection(data_contract, server, run: Run):
                f"""CREATE VIEW "{model_name}" AS SELECT * FROM read_csv('{model_path}', hive_partitioning=1, columns={columns});"""
            )
        elif server.format == "delta":
+            if server.type == "local":
+                delta_table_arrow = DeltaTable(model_path).to_pyarrow_dataset()
+                con.register(model_name, delta_table_arrow)
+
            if server.type == "azure":
+                # After switching to native delta table support
+                # in https://github.com/datacontract/datacontract-cli/issues/258,
+                # azure storage should also work
+                # https://github.com/duckdb/duckdb_delta/issues/21
                raise NotImplementedError("Support for Delta Tables on Azure Storage is not implemented yet")
-
-            storage_options = {
-                "AWS_ENDPOINT_URL": server.endpointUrl,
-                "AWS_ACCESS_KEY_ID": os.getenv("DATACONTRACT_S3_ACCESS_KEY_ID"),
-                "AWS_SECRET_ACCESS_KEY": os.getenv("DATACONTRACT_S3_SECRET_ACCESS_KEY"),
-                "AWS_REGION": os.getenv("DATACONTRACT_S3_REGION", "us-east-1"),
-                "AWS_ALLOW_HTTP": "True" if server.endpointUrl.startswith("http://") else "False",
-            }
-
-            delta_table_arrow = DeltaTable(model_path, storage_options=storage_options).to_pyarrow_dataset()
-
-            con.register(model_name, delta_table_arrow)
+            if server.type == "s3":
+                storage_options = {
+                    "AWS_ENDPOINT_URL": server.endpointUrl,
+                    "AWS_ACCESS_KEY_ID": os.getenv("DATACONTRACT_S3_ACCESS_KEY_ID"),
+                    "AWS_SECRET_ACCESS_KEY": os.getenv("DATACONTRACT_S3_SECRET_ACCESS_KEY"),
+                    "AWS_REGION": os.getenv("DATACONTRACT_S3_REGION", "us-east-1"),
+                    "AWS_ALLOW_HTTP": "True" if server.endpointUrl.startswith("http://") else "False",
+                }
+
+                delta_table_arrow = DeltaTable(model_path, storage_options=storage_options).to_pyarrow_dataset()
+
+                con.register(model_name, delta_table_arrow)
    return con


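Local Delta tables are now read through the deltalake package and handed to DuckDB as Arrow datasets. A hedged sketch of that mechanism (path and names are placeholders):

import duckdb
from deltalake import DeltaTable

con = duckdb.connect(database=":memory:")

# Read the Delta table as a PyArrow dataset and expose it to DuckDB under the
# model name, so the generated SodaCL checks can query it like a regular table.
delta_table_arrow = DeltaTable("/tmp/my_delta_table").to_pyarrow_dataset()
con.register("my_model", delta_table_arrow)

print(con.execute("SELECT COUNT(*) FROM my_model").fetchone())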

datacontract/engines/soda/connections/snowflake.py CHANGED
@@ -4,17 +4,20 @@ import yaml


 def to_snowflake_soda_configuration(server):
+    prefix = "DATACONTRACT_SNOWFLAKE_"
+    snowflake_soda_params = {k.replace(prefix, "").lower(): v for k, v in os.environ.items() if k.startswith(prefix)}
+
+    # backward compatibility
+    if "connection_timeout" not in snowflake_soda_params:
+        snowflake_soda_params["connection_timeout"] = "5"  # minutes
+
     soda_configuration = {
         f"data_source {server.type}": {
             "type": "snowflake",
-            "username": os.getenv("DATACONTRACT_SNOWFLAKE_USERNAME"),
-            "password": os.getenv("DATACONTRACT_SNOWFLAKE_PASSWORD"),
-            "role": os.getenv("DATACONTRACT_SNOWFLAKE_ROLE"),
             "account": server.account,
             "database": server.database,
             "schema": server.schema_,
-            "warehouse": os.getenv("DATACONTRACT_SNOWFLAKE_WAREHOUSE"),
-            "connection_timeout": 5,  # minutes
+            **snowflake_soda_params,
         }
     }
     soda_configuration_str = yaml.dump(soda_configuration)
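The Snowflake connection now forwards every DATACONTRACT_SNOWFLAKE_* environment variable (prefix stripped, key lower-cased) to the Soda data source, so additional connection parameters can be supplied without code changes. A small sketch of the mapping (environment values are invented):

import os

os.environ["DATACONTRACT_SNOWFLAKE_USERNAME"] = "svc_datacontract"
os.environ["DATACONTRACT_SNOWFLAKE_WAREHOUSE"] = "COMPUTE_WH"
os.environ["DATACONTRACT_SNOWFLAKE_AUTHENTICATOR"] = "externalbrowser"  # any extra key is forwarded as-is

prefix = "DATACONTRACT_SNOWFLAKE_"
params = {k.replace(prefix, "").lower(): v for k, v in os.environ.items() if k.startswith(prefix)}
params.setdefault("connection_timeout", "5")  # minutes, kept for backward compatibility

print(params)
# e.g. {'username': 'svc_datacontract', 'warehouse': 'COMPUTE_WH',
#       'authenticator': 'externalbrowser', 'connection_timeout': '5'}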

datacontract/export/avro_converter.py CHANGED
@@ -65,7 +65,7 @@ def to_avro_type(field: Field, field_name: str) -> str | dict:
         if field.config["avroLogicalType"] in ["time-millis", "date"]:
             return {"type": "int", "logicalType": field.config["avroLogicalType"]}
         if "avroType" in field.config:
-            return field.config["avroLogicalType"]
+            return field.config["avroType"]

     if field.type is None:
         return "null"

datacontract/export/dbml_converter.py CHANGED
@@ -3,6 +3,7 @@ from importlib.metadata import version
 from typing import Tuple

 import pytz
+from datacontract.model.exceptions import DataContractException

 import datacontract.model.data_contract_specification as spec
 from datacontract.export.sql_type_converter import convert_to_sql_type
@@ -48,17 +49,7 @@ Using {5} Types for the field types
 {0}
 */
 """.format(generated_info)
-
-    note = """Note project_info {{
-'''
-{0}
-'''
-}}
-""".format(generated_info)
-
-    return """{0}
-{1}
-""".format(comment, note)
+    return comment


 def get_version() -> str:
@@ -70,19 +61,18 @@ def get_version() -> str:

 def generate_project_info(contract: spec.DataContractSpecification) -> str:
     return """Project "{0}" {{
-Note: "{1}"
+Note: '''{1}'''
 }}\n
-""".format(contract.info.title, " ".join(contract.info.description.splitlines()))
+""".format(contract.info.title, contract.info.description)


 def generate_table(model_name: str, model: spec.Model, server: spec.Server) -> str:
     result = """Table "{0}" {{
-Note: "{1}"
-""".format(model_name, " ".join(model.description.splitlines()))
+Note: {1}
+""".format(model_name, formatDescription(model.description))

     references = []

-    # Add all the fields
     for field_name, field in model.fields.items():
         ref, field_string = generate_field(field_name, field, model_name, server)
         if ref is not None:
@@ -102,6 +92,30 @@ Note: "{1}"


 def generate_field(field_name: str, field: spec.Field, model_name: str, server: spec.Server) -> Tuple[str, str]:
+    if field.primary:
+        if field.required is not None:
+            if not field.required:
+                raise DataContractException(
+                    type="lint",
+                    name="Primary key fields cannot have required == False.",
+                    result="error",
+                    reason="Primary key fields cannot have required == False.",
+                    engine="datacontract",
+                )
+        else:
+            field.required = True
+        if field.unique is not None:
+            if not field.unique:
+                raise DataContractException(
+                    type="lint",
+                    name="Primary key fields cannot have unique == False",
+                    result="error",
+                    reason="Primary key fields cannot have unique == False.",
+                    engine="datacontract",
+                )
+        else:
+            field.unique = True
+
     field_attrs = []
     if field.primary:
         field_attrs.append("pk")
@@ -115,13 +129,21 @@ def generate_field(field_name: str, field: spec.Field, model_name: str, server:
         field_attrs.append("null")

     if field.description:
-        field_attrs.append('Note: "{0}"'.format(" ".join(field.description.splitlines())))
+        field_attrs.append("""Note: {0}""".format(formatDescription(field.description)))

     field_type = field.type if server is None else convert_to_sql_type(field, server.type)

     field_str = '"{0}" "{1}" [{2}]'.format(field_name, field_type, ",".join(field_attrs))
     ref_str = None
     if (field.references) is not None:
-        # we always assume many to one, as datacontract doesn't really give us more info
-        ref_str = "{0}.{1} > {2}".format(model_name, field_name, field.references)
+        if field.unique:
+            ref_str = "{0}.{1} - {2}".format(model_name, field_name, field.references)
+        else:
+            ref_str = "{0}.{1} > {2}".format(model_name, field_name, field.references)
     return (ref_str, field_str)
+
+def formatDescription(input: str) -> str:
+    if '\n' in input or '\r' in input or '"' in input:
+        return "'''{0}'''".format(input)
+    else:
+        return '"{0}"'.format(input)
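The new formatDescription helper switches between DBML's double-quoted and triple-quoted note syntax, so multi-line or quoted descriptions survive the export instead of being collapsed onto one line. A standalone copy of the helper with example output:

def formatDescription(input: str) -> str:
    # Multi-line text, carriage returns, or embedded double quotes need DBML's
    # triple-quoted note syntax; plain one-liners can stay in double quotes.
    if "\n" in input or "\r" in input or '"' in input:
        return "'''{0}'''".format(input)
    return '"{0}"'.format(input)


print(formatDescription("One order line per row"))
# "One order line per row"
print(formatDescription("Orders table.\nContains PII."))
# '''Orders table.
# Contains PII.'''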

datacontract/export/exporter.py CHANGED
@@ -37,7 +37,7 @@ class ExportFormat(str, Enum):
     spark = "spark"

     @classmethod
-    def get_suported_formats(cls):
+    def get_supported_formats(cls):
         return list(map(lambda c: c.value, cls))



datacontract/export/jsonschema_converter.py CHANGED
@@ -36,10 +36,7 @@ def to_property(field: Field) -> dict:
     property = {}
     json_type, json_format = convert_type_format(field.type, field.format)
     if json_type is not None:
-        if field.required:
-            property["type"] = json_type
-        else:
-            property["type"] = [json_type, "null"]
+        property["type"] = json_type
     if json_format is not None:
         property["format"] = json_format
     if field.unique:

datacontract/export/sodacl_converter.py CHANGED
@@ -131,7 +131,7 @@ def check_field_minimum(field_name, minimum, quote_field_name: bool = False):
         field_name = f'"{field_name}"'
     return {
         f"invalid_count({field_name}) = 0": {
-            "name": f"Check that field {field_name} has a minimum of {min}",
+            "name": f"Check that field {field_name} has a minimum of {minimum}",
             "valid min": minimum,
         }
     }

datacontract/imports/avro_importer.py CHANGED
@@ -1,3 +1,5 @@
+from typing import Dict, List
+
 import avro.schema

 from datacontract.imports.importer import Importer
@@ -6,13 +8,39 @@ from datacontract.model.exceptions import DataContractException


 class AvroImporter(Importer):
+    """Class to import Avro Schema file"""
+
     def import_source(
         self, data_contract_specification: DataContractSpecification, source: str, import_args: dict
-    ) -> dict:
+    ) -> DataContractSpecification:
+        """
+        Import Avro schema from a source file.
+
+        Args:
+            data_contract_specification: The data contract specification to update.
+            source: The path to the Avro schema file.
+            import_args: Additional import arguments.
+
+        Returns:
+            The updated data contract specification.
+        """
         return import_avro(data_contract_specification, source)


 def import_avro(data_contract_specification: DataContractSpecification, source: str) -> DataContractSpecification:
+    """
+    Import an Avro schema from a file and update the data contract specification.
+
+    Args:
+        data_contract_specification: The data contract specification to update.
+        source: The path to the Avro schema file.
+
+    Returns:
+        DataContractSpecification: The updated data contract specification.
+
+    Raises:
+        DataContractException: If there's an error parsing the Avro schema.
+    """
     if data_contract_specification.models is None:
         data_contract_specification.models = {}

@@ -45,7 +73,14 @@ def import_avro(data_contract_specification: DataContractSpecification, source:
     return data_contract_specification


-def handle_config_avro_custom_properties(field, imported_field):
+def handle_config_avro_custom_properties(field: avro.schema.Field, imported_field: Field) -> None:
+    """
+    Handle custom Avro properties and add them to the imported field's config.
+
+    Args:
+        field: The Avro field.
+        imported_field: The imported field to update.
+    """
     if field.get_prop("logicalType") is not None:
         if imported_field.config is None:
             imported_field.config = {}
@@ -57,7 +92,16 @@ def handle_config_avro_custom_properties(field, imported_field):
         imported_field.config["avroDefault"] = field.default


-def import_record_fields(record_fields):
+def import_record_fields(record_fields: List[avro.schema.Field]) -> Dict[str, Field]:
+    """
+    Import Avro record fields and convert them to data contract fields.
+
+    Args:
+        record_fields: List of Avro record fields.
+
+    Returns:
+        A dictionary of imported fields.
+    """
     imported_fields = {}
     for field in record_fields:
         imported_field = Field()
@@ -83,6 +127,15 @@ def import_record_fields(record_fields):
         elif field.type.type == "array":
             imported_field.type = "array"
             imported_field.items = import_avro_array_items(field.type)
+        elif field.type.type == "map":
+            imported_field.type = "map"
+            imported_field.values = import_avro_map_values(field.type)
+        elif field.type.type == "enum":
+            imported_field.type = "string"
+            imported_field.enum = field.type.symbols
+            if not imported_field.config:
+                imported_field.config = {}
+            imported_field.config["avroType"] = "enum"
         else:  # primitive type
             imported_field.type = map_type_from_avro(field.type.type)

@@ -91,7 +144,16 @@ def import_record_fields(record_fields):
     return imported_fields


-def import_avro_array_items(array_schema):
+def import_avro_array_items(array_schema: avro.schema.ArraySchema) -> Field:
+    """
+    Import Avro array items and convert them to a data contract field.
+
+    Args:
+        array_schema: The Avro array schema.
+
+    Returns:
+        Field: The imported field representing the array items.
+    """
     items = Field()
     for prop in array_schema.other_props:
         items.__setattr__(prop, array_schema.other_props[prop])
@@ -108,7 +170,45 @@ def import_avro_array_items(array_schema):
     return items


-def import_type_of_optional_field(field):
+def import_avro_map_values(map_schema: avro.schema.MapSchema) -> Field:
+    """
+    Import Avro map values and convert them to a data contract field.
+
+    Args:
+        map_schema: The Avro map schema.
+
+    Returns:
+        Field: The imported field representing the map values.
+    """
+    values = Field()
+    for prop in map_schema.other_props:
+        values.__setattr__(prop, map_schema.other_props[prop])
+
+    if map_schema.values.type == "record":
+        values.type = "object"
+        values.fields = import_record_fields(map_schema.values.fields)
+    elif map_schema.values.type == "array":
+        values.type = "array"
+        values.items = import_avro_array_items(map_schema.values)
+    else:  # primitive type
+        values.type = map_type_from_avro(map_schema.values.type)
+
+    return values
+
+
+def import_type_of_optional_field(field: avro.schema.Field) -> str:
+    """
+    Determine the type of optional field in an Avro union.
+
+    Args:
+        field: The Avro field with a union type.
+
+    Returns:
+        str: The mapped type of the non-null field in the union.
+
+    Raises:
+        DataContractException: If no non-null type is found in the union.
+    """
     for field_type in field.type.schemas:
         if field_type.type != "null":
             return map_type_from_avro(field_type.type)
@@ -121,21 +221,51 @@ def import_type_of_optional_field(field):
    )


-def get_record_from_union_field(field):
+def get_record_from_union_field(field: avro.schema.Field) -> avro.schema.RecordSchema | None:
+    """
+    Get the record schema from a union field.
+
+    Args:
+        field: The Avro field with a union type.
+
+    Returns:
+        The record schema if found, None otherwise.
+    """
     for field_type in field.type.schemas:
         if field_type.type == "record":
             return field_type
     return None


-def get_array_from_union_field(field):
+def get_array_from_union_field(field: avro.schema.Field) -> avro.schema.ArraySchema | None:
+    """
+    Get the array schema from a union field.
+
+    Args:
+        field: The Avro field with a union type.
+
+    Returns:
+        The array schema if found, None otherwise.
+    """
     for field_type in field.type.schemas:
         if field_type.type == "array":
             return field_type
     return None


-def map_type_from_avro(avro_type_str: str):
+def map_type_from_avro(avro_type_str: str) -> str:
+    """
+    Map Avro type strings to data contract type strings.
+
+    Args:
+        avro_type_str (str): The Avro type string.
+
+    Returns:
+        str: The corresponding data contract type string.
+
+    Raises:
+        DataContractException: If the Avro type is unsupported.
+    """
     # TODO: ambiguous mapping in the export
     if avro_type_str == "null":
         return "null"
@@ -155,6 +285,10 @@ def map_type_from_avro(avro_type_str: str):
         return "record"
     elif avro_type_str == "array":
         return "array"
+    elif avro_type_str == "map":
+        return "map"
+    elif avro_type_str == "enum":
+        return "string"
     else:
         raise DataContractException(
             type="schema",
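The importer now handles Avro map and enum types: maps become type "map" with a values field, and enums become type "string" with the symbols kept under enum and config.avroType set to "enum". A hedged sketch of a schema that exercises both (schema content is invented):

import json

import avro.schema

avro_schema = avro.schema.parse(
    json.dumps(
        {
            "type": "record",
            "name": "Order",
            "fields": [
                {
                    "name": "status",
                    "type": {"type": "enum", "name": "Status", "symbols": ["PLACED", "SHIPPED"]},
                },
                {"name": "attributes", "type": {"type": "map", "values": "string"}},
            ],
        }
    )
)

for field in avro_schema.fields:
    print(field.name, field.type.type)
# status enum     -> imported as type "string", enum ["PLACED", "SHIPPED"], config.avroType = "enum"
# attributes map  -> imported as type "map" with string values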

datacontract/imports/dbt_importer.py ADDED
@@ -0,0 +1,117 @@
+import json
+
+from typing import (
+    List,
+)
+
+from datacontract.imports.importer import Importer
+from datacontract.model.data_contract_specification import DataContractSpecification, Field, Model
+
+
+class DbtManifestImporter(Importer):
+    def import_source(
+        self, data_contract_specification: DataContractSpecification, source: str, import_args: dict
+    ) -> dict:
+        data = read_dbt_manifest(manifest_path=source)
+        return import_dbt_manifest(
+            data_contract_specification, manifest_dict=data, dbt_models=import_args.get("dbt_model")
+        )
+
+
+def import_dbt_manifest(
+    data_contract_specification: DataContractSpecification, manifest_dict: dict, dbt_models: List[str]
+):
+    data_contract_specification.info.title = manifest_dict.get("info").get("project_name")
+    data_contract_specification.info.dbt_version = manifest_dict.get("info").get("dbt_version")
+
+    if data_contract_specification.models is None:
+        data_contract_specification.models = {}
+
+    for model in manifest_dict.get("models", []):
+        if dbt_models and model.name not in dbt_models:
+            continue
+
+        dc_model = Model(
+            description=model.description,
+            tags=model.tags,
+            fields=create_fields(model.columns),
+        )
+
+        data_contract_specification.models[model.name] = dc_model
+
+    return data_contract_specification
+
+
+def create_fields(columns: List):
+    fields = {}
+    for column in columns:
+        field = Field(
+            description=column.description, type=column.data_type if column.data_type else "", tags=column.tags
+        )
+        fields[column.name] = field
+
+    return fields
+
+
+def read_dbt_manifest(manifest_path: str):
+    with open(manifest_path, "r", encoding="utf-8") as f:
+        manifest = json.load(f)
+        return {"info": manifest.get("metadata"), "models": create_manifest_models(manifest)}
+
+
+def create_manifest_models(manifest: dict) -> List:
+    models = []
+    nodes = manifest.get("nodes")
+
+    for node in nodes.values():
+        if node["resource_type"] != "model":
+            continue
+
+        models.append(DbtModel(node))
+    return models
+
+
+class DbtColumn:
+    name: str
+    description: str
+    data_type: str
+    meta: dict
+    tags: List
+
+    def __init__(self, node_column: dict):
+        self.name = node_column.get("name")
+        self.description = node_column.get("description")
+        self.data_type = node_column.get("data_type")
+        self.meta = node_column.get("meta", {})
+        self.tags = node_column.get("tags", [])
+
+    def __repr__(self) -> str:
+        return self.name
+
+
+class DbtModel:
+    name: str
+    database: str
+    schema: str
+    description: str
+    unique_id: str
+    tags: List
+
+    def __init__(self, node: dict):
+        self.name = node.get("name")
+        self.database = node.get("database")
+        self.schema = node.get("schema")
+        self.description = node.get("description")
+        self.display_name = node.get("display_name")
+        self.unique_id = node.get("unique_id")
+        self.columns = []
+        self.tags = node.get("tags")
+        if node.get("columns"):
+            self.add_columns(node.get("columns").values())
+
+    def add_columns(self, model_columns: List):
+        for column in model_columns:
+            self.columns.append(DbtColumn(column))
+
+    def __repr__(self) -> str:
+        return self.name
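The dbt importer reads a dbt manifest.json, keeps only nodes with resource_type == "model", and turns their columns into contract fields. A minimal sketch of the manifest shape it consumes (a reduced, hypothetical subset of a real manifest):

manifest = {
    "metadata": {"project_name": "jaffle_shop", "dbt_version": "1.8.0"},
    "nodes": {
        "model.jaffle_shop.orders": {
            "resource_type": "model",
            "name": "orders",
            "description": "All orders",
            "tags": [],
            "columns": {
                "order_id": {"name": "order_id", "description": "PK", "data_type": "bigint", "tags": []},
            },
        },
        "test.jaffle_shop.not_null_orders_order_id": {"resource_type": "test"},
    },
}

# Mirror of the importer's filtering: tests, seeds, sources, etc. are skipped.
models = [n for n in manifest["nodes"].values() if n["resource_type"] == "model"]
for node in models:
    print(node["name"], list(node.get("columns", {}).keys()))
# orders ['order_id']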