datacontract-cli 0.10.23__py3-none-any.whl → 0.10.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. datacontract/__init__.py +13 -0
  2. datacontract/api.py +12 -5
  3. datacontract/catalog/catalog.py +5 -3
  4. datacontract/cli.py +116 -10
  5. datacontract/data_contract.py +143 -65
  6. datacontract/engines/data_contract_checks.py +366 -60
  7. datacontract/engines/data_contract_test.py +50 -4
  8. datacontract/engines/fastjsonschema/check_jsonschema.py +37 -19
  9. datacontract/engines/fastjsonschema/s3/s3_read_files.py +3 -2
  10. datacontract/engines/soda/check_soda_execute.py +22 -3
  11. datacontract/engines/soda/connections/athena.py +79 -0
  12. datacontract/engines/soda/connections/duckdb_connection.py +65 -6
  13. datacontract/engines/soda/connections/kafka.py +4 -2
  14. datacontract/export/avro_converter.py +20 -3
  15. datacontract/export/bigquery_converter.py +1 -1
  16. datacontract/export/dbt_converter.py +36 -7
  17. datacontract/export/dqx_converter.py +126 -0
  18. datacontract/export/duckdb_type_converter.py +57 -0
  19. datacontract/export/excel_exporter.py +923 -0
  20. datacontract/export/exporter.py +3 -0
  21. datacontract/export/exporter_factory.py +17 -1
  22. datacontract/export/great_expectations_converter.py +55 -5
  23. datacontract/export/{html_export.py → html_exporter.py} +31 -20
  24. datacontract/export/markdown_converter.py +134 -5
  25. datacontract/export/mermaid_exporter.py +110 -0
  26. datacontract/export/odcs_v3_exporter.py +187 -145
  27. datacontract/export/protobuf_converter.py +163 -69
  28. datacontract/export/rdf_converter.py +2 -2
  29. datacontract/export/sodacl_converter.py +9 -1
  30. datacontract/export/spark_converter.py +31 -4
  31. datacontract/export/sql_converter.py +6 -2
  32. datacontract/export/sql_type_converter.py +20 -8
  33. datacontract/imports/avro_importer.py +63 -12
  34. datacontract/imports/csv_importer.py +111 -57
  35. datacontract/imports/excel_importer.py +1111 -0
  36. datacontract/imports/importer.py +16 -3
  37. datacontract/imports/importer_factory.py +17 -0
  38. datacontract/imports/json_importer.py +325 -0
  39. datacontract/imports/odcs_importer.py +2 -2
  40. datacontract/imports/odcs_v3_importer.py +351 -151
  41. datacontract/imports/protobuf_importer.py +264 -0
  42. datacontract/imports/spark_importer.py +117 -13
  43. datacontract/imports/sql_importer.py +32 -16
  44. datacontract/imports/unity_importer.py +84 -38
  45. datacontract/init/init_template.py +1 -1
  46. datacontract/integration/datamesh_manager.py +16 -2
  47. datacontract/lint/resolve.py +112 -23
  48. datacontract/lint/schema.py +24 -15
  49. datacontract/model/data_contract_specification/__init__.py +1 -0
  50. datacontract/model/odcs.py +13 -0
  51. datacontract/model/run.py +3 -0
  52. datacontract/output/junit_test_results.py +3 -3
  53. datacontract/schemas/datacontract-1.1.0.init.yaml +1 -1
  54. datacontract/schemas/datacontract-1.2.0.init.yaml +91 -0
  55. datacontract/schemas/datacontract-1.2.0.schema.json +2029 -0
  56. datacontract/schemas/datacontract-1.2.1.init.yaml +91 -0
  57. datacontract/schemas/datacontract-1.2.1.schema.json +2058 -0
  58. datacontract/schemas/odcs-3.0.2.schema.json +2382 -0
  59. datacontract/templates/datacontract.html +54 -3
  60. datacontract/templates/datacontract_odcs.html +685 -0
  61. datacontract/templates/index.html +5 -2
  62. datacontract/templates/partials/server.html +2 -0
  63. datacontract/templates/style/output.css +319 -145
  64. {datacontract_cli-0.10.23.dist-info → datacontract_cli-0.10.37.dist-info}/METADATA +656 -431
  65. datacontract_cli-0.10.37.dist-info/RECORD +119 -0
  66. {datacontract_cli-0.10.23.dist-info → datacontract_cli-0.10.37.dist-info}/WHEEL +1 -1
  67. {datacontract_cli-0.10.23.dist-info → datacontract_cli-0.10.37.dist-info/licenses}/LICENSE +1 -1
  68. datacontract/export/csv_type_converter.py +0 -36
  69. datacontract/lint/lint.py +0 -142
  70. datacontract/lint/linters/description_linter.py +0 -35
  71. datacontract/lint/linters/field_pattern_linter.py +0 -34
  72. datacontract/lint/linters/field_reference_linter.py +0 -48
  73. datacontract/lint/linters/notice_period_linter.py +0 -55
  74. datacontract/lint/linters/quality_schema_linter.py +0 -52
  75. datacontract/lint/linters/valid_constraints_linter.py +0 -100
  76. datacontract/model/data_contract_specification.py +0 -327
  77. datacontract_cli-0.10.23.dist-info/RECORD +0 -113
  78. /datacontract/{lint/linters → output}/__init__.py +0 -0
  79. {datacontract_cli-0.10.23.dist-info → datacontract_cli-0.10.37.dist-info}/entry_points.txt +0 -0
  80. {datacontract_cli-0.10.23.dist-info → datacontract_cli-0.10.37.dist-info}/top_level.txt +0 -0
datacontract/export/protobuf_converter.py
@@ -4,102 +4,196 @@ from datacontract.model.data_contract_specification import DataContractSpecifica
 
 class ProtoBufExporter(Exporter):
     def export(self, data_contract, model, server, sql_server_type, export_args) -> dict:
-        return to_protobuf(data_contract)
+        # Returns a dict containing the protobuf representation.
+        proto = to_protobuf(data_contract)
+        return {"protobuf": proto}
 
 
-def to_protobuf(data_contract_spec: DataContractSpecification):
+def to_protobuf(data_contract_spec: DataContractSpecification) -> str:
+    """
+    Generates a Protobuf file from the data contract specification.
+    Scans all models for enum fields (even if the type is "string") by checking for a "values" property.
+    """
     messages = ""
+    enum_definitions = {}
+
+    # Iterate over all models to generate messages and collect enum definitions.
     for model_name, model in data_contract_spec.models.items():
-        messages += to_protobuf_message(model_name, model.fields, model.description, 0)
+        for field_name, field in model.fields.items():
+            # If the field has enum values, collect them.
+            if _is_enum_field(field):
+                enum_name = _get_enum_name(field, field_name)
+                enum_values = _get_enum_values(field)
+                if enum_values and enum_name not in enum_definitions:
+                    enum_definitions[enum_name] = enum_values
+
+        messages += to_protobuf_message(model_name, model.fields, getattr(model, "description", ""), 0)
         messages += "\n"
 
-    result = f"""syntax = "proto3";
-
-    {messages}
-    """
-
-    return result
-
-
-def _to_protobuf_message_name(model_name):
-    return model_name[0].upper() + model_name[1:]
-
-
-def to_protobuf_message(model_name, fields, description, indent_level: int = 0):
+    # Build header with syntax and package declarations.
+    header = 'syntax = "proto3";\n\n'
+    package = getattr(data_contract_spec, "package", "example")
+    header += f"package {package};\n\n"
+
+    # Append enum definitions.
+    for enum_name, enum_values in enum_definitions.items():
+        header += f"// Enum for {enum_name}\n"
+        header += f"enum {enum_name} {{\n"
+        # Only iterate if enum_values is a dictionary.
+        if isinstance(enum_values, dict):
+            for enum_const, value in sorted(enum_values.items(), key=lambda item: item[1]):
+                normalized_const = enum_const.upper().replace(" ", "_")
+                header += f"  {normalized_const} = {value};\n"
+        else:
+            header += f"  // Warning: Enum values for {enum_name} are not a dictionary\n"
+        header += "}\n\n"
+    return header + messages
+
+
+def _is_enum_field(field) -> bool:
+    """
+    Returns True if the field (dict or object) has a non-empty "values" property.
+    """
+    if isinstance(field, dict):
+        return bool(field.get("values"))
+    return bool(getattr(field, "values", None))
+
+
+def _get_enum_name(field, field_name: str) -> str:
+    """
+    Returns the enum name either from the field's "enum_name" or derived from the field name.
+    """
+    if isinstance(field, dict):
+        return field.get("enum_name", _to_protobuf_message_name(field_name))
+    return getattr(field, "enum_name", None) or _to_protobuf_message_name(field_name)
+
+
+def _get_enum_values(field) -> dict:
+    """
+    Returns the enum values from the field.
+    If the values are not a dictionary, attempts to extract enum attributes.
+    """
+    if isinstance(field, dict):
+        values = field.get("values", {})
+    else:
+        values = getattr(field, "values", {})
+
+    if not isinstance(values, dict):
+        # If values is a BaseModel (or similar) with a .dict() method, use it.
+        if hasattr(values, "dict") and callable(values.dict):
+            values_dict = values.dict()
+            return {k: v for k, v in values_dict.items() if k.isupper() and isinstance(v, int)}
+        else:
+            # Otherwise, iterate over attributes that look like enums.
+            return {
+                key: getattr(values, key)
+                for key in dir(values)
+                if key.isupper() and isinstance(getattr(values, key), int)
+            }
+    return values
+
+
+def _to_protobuf_message_name(name: str) -> str:
+    """
+    Returns a valid Protobuf message/enum name by capitalizing the first letter.
+    """
+    return name[0].upper() + name[1:] if name else name
+
+
+def to_protobuf_message(model_name: str, fields: dict, description: str, indent_level: int = 0) -> str:
+    """
+    Generates a Protobuf message definition from the model's fields.
+    Handles nested messages for complex types.
+    """
     result = ""
+    if description:
+        result += f"{indent(indent_level)}// {description}\n"
 
-    if description is not None:
-        result += f"""{indent(indent_level)}/* {description} */\n"""
-
-    fields_protobuf = ""
+    result += f"message {_to_protobuf_message_name(model_name)} {{\n"
     number = 1
     for field_name, field in fields.items():
-        if field.type in ["object", "record", "struct"]:
-            fields_protobuf += (
-                "\n".join(
-                    map(
-                        lambda x: "  " + x,
-                        to_protobuf_message(field_name, field.fields, field.description, indent_level + 1).splitlines(),
-                    )
-                )
-                + "\n"
-            )
-
-        fields_protobuf += to_protobuf_field(field_name, field, field.description, number, 1) + "\n"
+        # For nested objects, generate a nested message.
+        field_type = _get_field_type(field)
+        if field_type in ["object", "record", "struct"]:
+            nested_desc = field.get("description", "") if isinstance(field, dict) else getattr(field, "description", "")
+            nested_fields = field.get("fields", {}) if isinstance(field, dict) else field.fields
+            nested_message = to_protobuf_message(field_name, nested_fields, nested_desc, indent_level + 1)
+            result += nested_message + "\n"
+
+        field_desc = field.get("description", "") if isinstance(field, dict) else getattr(field, "description", "")
+        result += to_protobuf_field(field_name, field, field_desc, number, indent_level + 1) + "\n"
         number += 1
-    result += f"message {_to_protobuf_message_name(model_name)} {{\n{fields_protobuf}}}\n"
 
+    result += f"{indent(indent_level)}}}\n"
     return result
 
 
-def to_protobuf_field(field_name, field, description, number: int, indent_level: int = 0):
-    optional = ""
-    if not field.required:
-        optional = "optional "
-
+def to_protobuf_field(field_name: str, field, description: str, number: int, indent_level: int = 0) -> str:
+    """
+    Generates a field definition within a Protobuf message.
+    """
     result = ""
-
-    if description is not None:
-        result += f"""{indent(indent_level)}/* {description} */\n"""
-
-    result += f"{indent(indent_level)}{optional}{_convert_type(field_name, field)} {field_name} = {number};"
-
+    if description:
+        result += f"{indent(indent_level)}// {description}\n"
+    result += f"{indent(indent_level)}{_convert_type(field_name, field)} {field_name} = {number};"
    return result
 
 
-def indent(indent_level):
+def indent(indent_level: int) -> str:
     return "  " * indent_level
 
 
-def _convert_type(field_name, field) -> None | str:
-    type = field.type
-    if type is None:
-        return None
-    if type.lower() in ["string", "varchar", "text"]:
-        return "string"
-    if type.lower() in ["timestamp", "timestamp_tz"]:
-        return "string"
-    if type.lower() in ["timestamp_ntz"]:
-        return "string"
-    if type.lower() in ["date"]:
+def _get_field_type(field) -> str:
+    """
+    Retrieves the field type from the field definition.
+    """
+    if isinstance(field, dict):
+        return field.get("type", "").lower()
+    return getattr(field, "type", "").lower()
+
+
+def _convert_type(field_name: str, field) -> str:
+    """
+    Converts a field's type (from the data contract) to a Protobuf type.
+    Prioritizes enum conversion if a non-empty "values" property exists.
+    """
+    # For debugging purposes
+    print("Converting field:", field_name)
+    # If the field should be treated as an enum, return its enum name.
+    if _is_enum_field(field):
+        return _get_enum_name(field, field_name)
+
+    lower_type = _get_field_type(field)
+    if lower_type in ["string", "varchar", "text"]:
         return "string"
-    if type.lower() in ["time"]:
+    if lower_type in ["timestamp", "timestamp_tz", "timestamp_ntz", "date", "time"]:
         return "string"
-    if type.lower() in ["number", "decimal", "numeric"]:
+    if lower_type in ["number", "decimal", "numeric"]:
         return "double"
-    if type.lower() in ["float", "double"]:
-        return type.lower()
-    if type.lower() in ["integer", "int"]:
+    if lower_type in ["float", "double"]:
+        return lower_type
+    if lower_type in ["integer", "int"]:
         return "int32"
-    if type.lower() in ["long", "bigint"]:
+    if lower_type in ["long", "bigint"]:
         return "int64"
-    if type.lower() in ["boolean"]:
+    if lower_type in ["boolean"]:
         return "bool"
-    if type.lower() in ["bytes"]:
+    if lower_type in ["bytes"]:
         return "bytes"
-    if type.lower() in ["object", "record", "struct"]:
+    if lower_type in ["object", "record", "struct"]:
         return _to_protobuf_message_name(field_name)
-    if type.lower() in ["array"]:
-        # TODO spec is missing arrays
-        return "repeated string"
-    return None
+    if lower_type == "array":
+        # Handle array types. Check for an "items" property.
+        items = field.get("items") if isinstance(field, dict) else getattr(field, "items", None)
+        if items and isinstance(items, dict) and items.get("type"):
+            item_type = items.get("type", "").lower()
+            if item_type in ["object", "record", "struct"]:
+                # Singularize the field name (a simple approach).
+                singular = field_name[:-1] if field_name.endswith("s") else field_name
+                return "repeated " + _to_protobuf_message_name(singular)
+            else:
+                return "repeated " + _convert_type(field_name, items)
+        else:
+            return "repeated string"
+    # Fallback for unrecognized types.
+    return "string"
datacontract/export/rdf_converter.py
@@ -57,8 +57,8 @@ def to_rdf(data_contract_spec: DataContractSpecification, base) -> Graph:
     else:
         g = Graph(base=Namespace(""))
 
-    dc = Namespace("https://datacontract.com/DataContractSpecification/1.1.0/")
-    dcx = Namespace("https://datacontract.com/DataContractSpecification/1.1.0/Extension/")
+    dc = Namespace("https://datacontract.com/DataContractSpecification/1.2.1/")
+    dcx = Namespace("https://datacontract.com/DataContractSpecification/1.2.1/Extension/")
 
     g.bind("dc", dc)
     g.bind("dcx", dcx)
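Only the specification version in the two namespace IRIs changes. A minimal rdflib sketch of the updated bindings (graph construction elided):

    from rdflib import Graph, Namespace

    g = Graph()
    dc = Namespace("https://datacontract.com/DataContractSpecification/1.2.1/")
    dcx = Namespace("https://datacontract.com/DataContractSpecification/1.2.1/Extension/")
    g.bind("dc", dc)   # terms emitted by the exporter now resolve under 1.2.1
    g.bind("dcx", dcx)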
datacontract/export/sodacl_converter.py
@@ -2,12 +2,14 @@ import yaml
 
 from datacontract.engines.data_contract_checks import create_checks
 from datacontract.export.exporter import Exporter
+from datacontract.model.data_contract_specification import DataContractSpecification, Server
 from datacontract.model.run import Run
 
 
 class SodaExporter(Exporter):
-    def export(self, data_contract, model, server, sql_server_type, export_args) -> dict:
+    def export(self, data_contract, model, server, sql_server_type, export_args) -> str:
         run = Run.create_run()
+        server = get_server(data_contract, server)
         run.checks.extend(create_checks(data_contract, server))
         return to_sodacl_yaml(run)
 
@@ -28,3 +30,9 @@ def to_sodacl_yaml(run: Run) -> str:
         else:
             sodacl_dict[key] = value
     return yaml.dump(sodacl_dict)
+
+
+def get_server(data_contract_specification: DataContractSpecification, server_name: str = None) -> Server | None:
+    if server_name is None:
+        return None
+    return data_contract_specification.servers.get(server_name)
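A quick sketch of how the new helper behaves (server names invented; spec is a parsed DataContractSpecification):

    from datacontract.export.sodacl_converter import get_server

    # Assume spec.servers == {"prod": Server(type="snowflake", ...)}
    get_server(spec, "prod")     # -> that Server, so checks can be server-specific
    get_server(spec, None)       # -> None; checks are created without a server
    get_server(spec, "unknown")  # -> None (a plain dict.get lookup)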
datacontract/export/spark_converter.py
@@ -1,3 +1,5 @@
+import json
+
 from pyspark.sql import types
 
 from datacontract.export.exporter import Exporter
@@ -104,7 +106,8 @@ def to_struct_field(field: Field, field_name: str) -> types.StructField:
         types.StructField: The corresponding Spark StructField.
     """
     data_type = to_spark_data_type(field)
-    return types.StructField(name=field_name, dataType=data_type, nullable=not field.required)
+    metadata = to_spark_metadata(field)
+    return types.StructField(name=field_name, dataType=data_type, nullable=not field.required, metadata=metadata)
 
 
 def to_spark_data_type(field: Field) -> types.DataType:
@@ -126,6 +129,8 @@ def to_spark_data_type(field: Field) -> types.DataType:
         return types.StructType(to_struct_type(field.fields))
     if field_type == "map":
         return types.MapType(to_spark_data_type(field.keys), to_spark_data_type(field.values))
+    if field_type == "variant":
+        return types.VariantType()
     if field_type in ["string", "varchar", "text"]:
         return types.StringType()
     if field_type in ["number", "decimal", "numeric"]:
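A small sketch of the new branch (assumes a PySpark build that ships types.VariantType, i.e. Spark 4.x; the Field construction is illustrative):

    from datacontract.export.spark_converter import to_spark_data_type
    from datacontract.model.data_contract_specification import Field

    to_spark_data_type(Field(type="variant"))  # expected: types.VariantType()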
@@ -150,7 +155,25 @@ def to_spark_data_type(field: Field) -> types.DataType:
         return types.DateType()
     if field_type == "bytes":
         return types.BinaryType()
-    return types.BinaryType()
+    return types.StringType()  # default if no condition is met
+
+
+def to_spark_metadata(field: Field) -> dict[str, str]:
+    """
+    Convert a field to a Spark metadata dictionary.
+
+    Args:
+        field (Field): The field to convert.
+
+    Returns:
+        dict: dictionary that can be supplied to Spark as metadata for a StructField
+    """
+
+    metadata = {}
+    if field.description:
+        metadata["comment"] = field.description
+
+    return metadata
 
 
 def print_schema(dtype: types.DataType) -> str:
@@ -175,7 +198,7 @@ def print_schema(dtype: types.DataType) -> str:
         Returns:
             str: The indented text.
         """
-        return "\n".join([f'{"  " * level}{line}' for line in text.split("\n")])
+        return "\n".join([f"{'  ' * level}{line}" for line in text.split("\n")])
 
     def repr_column(column: types.StructField) -> str:
         """
@@ -190,7 +213,11 @@ def print_schema(dtype: types.DataType) -> str:
         name = f'"{column.name}"'
         data_type = indent(print_schema(column.dataType), 1)
         nullable = indent(f"{column.nullable}", 1)
-        return f"StructField({name},\n{data_type},\n{nullable}\n)"
+        if column.metadata:
+            metadata = indent(f"{json.dumps(column.metadata)}", 1)
+            return f"StructField({name},\n{data_type},\n{nullable},\n{metadata}\n)"
+        else:
+            return f"StructField({name},\n{data_type},\n{nullable}\n)"
 
     def format_struct_type(struct_type: types.StructType) -> str:
         """
datacontract/export/sql_converter.py
@@ -4,7 +4,7 @@ from datacontract.model.data_contract_specification import DataContractSpecifica
 
 
 class SqlExporter(Exporter):
-    def export(self, data_contract, model, server, sql_server_type, export_args) -> dict:
+    def export(self, data_contract, model, server, sql_server_type, export_args) -> str:
         server_type = _determine_sql_server_type(
             data_contract,
             sql_server_type,
@@ -13,7 +13,7 @@ class SqlExporter:
 
 
 class SqlQueryExporter(Exporter):
-    def export(self, data_contract, model, server, sql_server_type, export_args) -> dict:
+    def export(self, data_contract, model, server, sql_server_type, export_args) -> str:
         model_name, model_value = _check_models_for_export(data_contract, model, self.export_format)
         server_type = _determine_sql_server_type(data_contract, sql_server_type, export_args.get("server"))
         return to_sql_query(
@@ -117,6 +117,8 @@ def _to_sql_table(model_name, model, server_type="snowflake"):
             result += " primary key"
         if server_type == "databricks" and field.description is not None:
             result += f' COMMENT "{_escape(field.description)}"'
+        if server_type == "snowflake" and field.description is not None:
+            result += f" COMMENT '{_escape(field.description)}'"
         if current_field_index < fields:
             result += ","
         result += "\n"
@@ -124,6 +126,8 @@ def _to_sql_table(model_name, model, server_type="snowflake"):
     result += ")"
     if server_type == "databricks" and model.description is not None:
         result += f' COMMENT "{_escape(model.description)}"'
+    if server_type == "snowflake" and model.description is not None:
+        result += f" COMMENT='{_escape(model.description)}'"
     result += ";\n"
     return result
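With the two Snowflake branches in place, DDL generated for a Snowflake server carries the descriptions as comments, roughly in this shape (table, column, and type are invented; exact rendering depends on the model):

    CREATE TABLE orders (
      order_id TEXT not null primary key COMMENT 'Order identifier'
    ) COMMENT='All incoming orders';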
datacontract/export/sql_type_converter.py
@@ -3,6 +3,9 @@ from datacontract.model.data_contract_specification import Field
 
 
 def convert_to_sql_type(field: Field, server_type: str) -> str:
+    if field.config and "physicalType" in field.config:
+        return field.config["physicalType"]
+
     if server_type == "snowflake":
         return convert_to_snowflake(field)
     elif server_type == "postgres":
@@ -19,6 +22,7 @@ def convert_to_sql_type(field: Field, server_type: str) -> str:
         return convert_type_to_bigquery(field)
     elif server_type == "trino":
         return convert_type_to_trino(field)
+
     return field.type
 
 
@@ -129,8 +133,9 @@ def convert_to_dataframe(field: Field) -> None | str:
     if type.lower() in ["time"]:
         return "STRING"
     if type.lower() in ["number", "decimal", "numeric"]:
-        # precision and scale not supported by data contract
-        return "DECIMAL"
+        precision = field.precision if field.precision is not None else 38
+        scale = field.scale if field.scale is not None else 0
+        return f"DECIMAL({precision},{scale})"
     if type.lower() in ["float"]:
         return "FLOAT"
     if type.lower() in ["double"]:
@@ -158,9 +163,13 @@ def convert_to_dataframe(field: Field) -> None | str:
 # databricks data types:
 # https://docs.databricks.com/en/sql/language-manual/sql-ref-datatypes.html
 def convert_to_databricks(field: Field) -> None | str:
-    if field.config and "databricksType" in field.config:
-        return field.config["databricksType"]
     type = field.type
+    if (
+        field.config
+        and "databricksType" in field.config
+        and type.lower() not in ["array", "object", "record", "struct"]
+    ):
+        return field.config["databricksType"]
     if type is None:
         return None
     if type.lower() in ["string", "varchar", "text"]:
@@ -174,8 +183,9 @@ def convert_to_databricks(field: Field) -> None | str:
     if type.lower() in ["time"]:
         return "STRING"
     if type.lower() in ["number", "decimal", "numeric"]:
-        # precision and scale not supported by data contract
-        return "DECIMAL"
+        precision = field.precision if field.precision is not None else 38
+        scale = field.scale if field.scale is not None else 0
+        return f"DECIMAL({precision},{scale})"
     if type.lower() in ["float"]:
         return "FLOAT"
     if type.lower() in ["double"]:
@@ -190,13 +200,15 @@ def convert_to_databricks(field: Field) -> None | str:
         nested_fields = []
         for nested_field_name, nested_field in field.fields.items():
             nested_field_type = convert_to_databricks(nested_field)
-            nested_fields.append(f"{nested_field_name} {nested_field_type}")
-        return f"STRUCT<{', '.join(nested_fields)}>"
+            nested_fields.append(f"{nested_field_name}:{nested_field_type}")
+        return f"STRUCT<{','.join(nested_fields)}>"
     if type.lower() in ["bytes"]:
         return "BINARY"
     if type.lower() in ["array"]:
         item_type = convert_to_databricks(field.items)
         return f"ARRAY<{item_type}>"
+    if type.lower() in ["variant"]:
+        return "VARIANT"
     return None
 
 
datacontract/imports/avro_importer.py
@@ -55,7 +55,6 @@ def import_avro(data_contract_specification: DataContractSpecification, source:
             engine="datacontract",
             original_exception=e,
         )
-
     # type record is being used for both the table and the object types in data contract
     # -> CONSTRAINT: one table per .avsc input, all nested records are interpreted as objects
     fields = import_record_fields(avro_schema.fields)
@@ -92,6 +91,20 @@ def handle_config_avro_custom_properties(field: avro.schema.Field, imported_fiel
         imported_field.config["avroDefault"] = field.default
 
 
+LOGICAL_TYPE_MAPPING = {
+    "decimal": "decimal",
+    "date": "date",
+    "time-millis": "time",
+    "time-micros": "time",
+    "timestamp-millis": "timestamp_tz",
+    "timestamp-micros": "timestamp_tz",
+    "local-timestamp-micros": "timestamp_ntz",
+    "local-timestamp-millis": "timestamp_ntz",
+    "duration": "string",
+    "uuid": "string",
+}
+
+
 def import_record_fields(record_fields: List[avro.schema.Field]) -> Dict[str, Field]:
     """
     Import Avro record fields and convert them to data contract fields.
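The new module-level table turns logical-type translation into a single lookup, e.g.:

    LOGICAL_TYPE_MAPPING["timestamp-millis"]        # -> "timestamp_tz"
    LOGICAL_TYPE_MAPPING["local-timestamp-micros"]  # -> "timestamp_ntz"
    LOGICAL_TYPE_MAPPING.get("unknown-logical")     # -> None; callers fall back to map_type_from_avro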
@@ -117,13 +130,23 @@ def import_record_fields(record_fields: List[avro.schema.Field]) -> Dict[str, Fi
             imported_field.fields = import_record_fields(field.type.fields)
         elif field.type.type == "union":
             imported_field.required = False
-            type = import_type_of_optional_field(field)
-            imported_field.type = type
-            if type == "record":
-                imported_field.fields = import_record_fields(get_record_from_union_field(field).fields)
-            elif type == "array":
-                imported_field.type = "array"
-                imported_field.items = import_avro_array_items(get_array_from_union_field(field))
+            # Check for enum in union first, since it needs special handling
+            enum_schema = get_enum_from_union_field(field)
+            if enum_schema:
+                imported_field.type = "string"
+                imported_field.enum = enum_schema.symbols
+                imported_field.title = enum_schema.name
+                if not imported_field.config:
+                    imported_field.config = {}
+                imported_field.config["avroType"] = "enum"
+            else:
+                type = import_type_of_optional_field(field)
+                imported_field.type = type
+                if type == "record":
+                    imported_field.fields = import_record_fields(get_record_from_union_field(field).fields)
+                elif type == "array":
+                    imported_field.type = "array"
+                    imported_field.items = import_avro_array_items(get_array_from_union_field(field))
         elif field.type.type == "array":
             imported_field.type = "array"
             imported_field.items = import_avro_array_items(field.type)
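Sketch of the new union handling for an Avro field declared as ["null", {"type": "enum", "name": "OrderStatus", "symbols": ["PLACED", "SHIPPED"]}] (schema invented for illustration); the importer now preserves the enum instead of collapsing it:

    imported_field.required  # False (union with null)
    imported_field.type      # "string"
    imported_field.enum      # ["PLACED", "SHIPPED"]
    imported_field.title     # "OrderStatus"
    imported_field.config    # {"avroType": "enum"}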
@@ -137,9 +160,15 @@ def import_record_fields(record_fields: List[avro.schema.Field]) -> Dict[str, Fi
             if not imported_field.config:
                 imported_field.config = {}
             imported_field.config["avroType"] = "enum"
-        else:  # primitive type
-            imported_field.type = map_type_from_avro(field.type.type)
-
+        else:
+            logical_type = field.type.get_prop("logicalType")
+            if logical_type in LOGICAL_TYPE_MAPPING:
+                imported_field.type = LOGICAL_TYPE_MAPPING[logical_type]
+                if logical_type == "decimal":
+                    imported_field.precision = field.type.precision
+                    imported_field.scale = field.type.scale
+            else:
+                imported_field.type = map_type_from_avro(field.type.type)
         imported_fields[field.name] = imported_field
 
     return imported_fields
@@ -212,7 +241,11 @@ def import_type_of_optional_field(field: avro.schema.Field) -> str:
     """
     for field_type in field.type.schemas:
         if field_type.type != "null":
-            return map_type_from_avro(field_type.type)
+            logical_type = field_type.get_prop("logicalType")
+            if logical_type and logical_type in LOGICAL_TYPE_MAPPING:
+                return LOGICAL_TYPE_MAPPING[logical_type]
+            else:
+                return map_type_from_avro(field_type.type)
     raise DataContractException(
         type="schema",
         result="failed",
@@ -254,6 +287,22 @@ def get_array_from_union_field(field: avro.schema.Field) -> avro.schema.ArraySch
     return None
 
 
+def get_enum_from_union_field(field: avro.schema.Field) -> avro.schema.EnumSchema | None:
+    """
+    Get the enum schema from a union field.
+
+    Args:
+        field: The Avro field with a union type.
+
+    Returns:
+        The enum schema if found, None otherwise.
+    """
+    for field_type in field.type.schemas:
+        if field_type.type == "enum":
+            return field_type
+    return None
+
+
 def map_type_from_avro(avro_type_str: str) -> str:
     """
     Map Avro type strings to data contract type strings.
@@ -276,6 +325,8 @@ def map_type_from_avro(avro_type_str: str) -> str:
         return "binary"
     elif avro_type_str == "double":
         return "double"
+    elif avro_type_str == "float":
+        return "float"
     elif avro_type_str == "int":
         return "int"
     elif avro_type_str == "long":