datacontract-cli 0.10.22__py3-none-any.whl → 0.10.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (39)
  1. datacontract/__init__.py +13 -0
  2. datacontract/catalog/catalog.py +2 -2
  3. datacontract/cli.py +20 -72
  4. datacontract/data_contract.py +5 -3
  5. datacontract/engines/data_contract_test.py +32 -7
  6. datacontract/engines/datacontract/check_that_datacontract_contains_valid_servers_configuration.py +2 -3
  7. datacontract/engines/fastjsonschema/s3/s3_read_files.py +3 -2
  8. datacontract/engines/soda/check_soda_execute.py +17 -4
  9. datacontract/engines/soda/connections/{duckdb.py → duckdb_connection.py} +66 -9
  10. datacontract/engines/soda/connections/kafka.py +3 -2
  11. datacontract/export/avro_converter.py +10 -3
  12. datacontract/export/bigquery_converter.py +1 -1
  13. datacontract/export/dbt_converter.py +13 -10
  14. datacontract/export/duckdb_type_converter.py +57 -0
  15. datacontract/export/odcs_v3_exporter.py +27 -7
  16. datacontract/export/protobuf_converter.py +163 -69
  17. datacontract/imports/avro_importer.py +31 -6
  18. datacontract/imports/csv_importer.py +111 -57
  19. datacontract/imports/importer.py +1 -0
  20. datacontract/imports/importer_factory.py +5 -0
  21. datacontract/imports/odcs_v3_importer.py +49 -7
  22. datacontract/imports/protobuf_importer.py +266 -0
  23. datacontract/lint/resolve.py +40 -12
  24. datacontract/model/data_contract_specification.py +2 -2
  25. datacontract/model/run.py +3 -0
  26. datacontract/output/__init__.py +0 -0
  27. datacontract/output/junit_test_results.py +135 -0
  28. datacontract/output/output_format.py +10 -0
  29. datacontract/output/test_results_writer.py +79 -0
  30. datacontract/templates/datacontract.html +2 -1
  31. datacontract/templates/index.html +2 -1
  32. {datacontract_cli-0.10.22.dist-info → datacontract_cli-0.10.24.dist-info}/METADATA +279 -193
  33. {datacontract_cli-0.10.22.dist-info → datacontract_cli-0.10.24.dist-info}/RECORD +37 -33
  34. {datacontract_cli-0.10.22.dist-info → datacontract_cli-0.10.24.dist-info}/WHEEL +1 -1
  35. datacontract/export/csv_type_converter.py +0 -36
  36. datacontract/lint/linters/quality_schema_linter.py +0 -52
  37. {datacontract_cli-0.10.22.dist-info → datacontract_cli-0.10.24.dist-info}/entry_points.txt +0 -0
  38. {datacontract_cli-0.10.22.dist-info → datacontract_cli-0.10.24.dist-info/licenses}/LICENSE +0 -0
  39. {datacontract_cli-0.10.22.dist-info → datacontract_cli-0.10.24.dist-info}/top_level.txt +0 -0
datacontract/export/duckdb_type_converter.py (new file)

@@ -0,0 +1,57 @@
+from typing import Dict
+
+from datacontract.model.data_contract_specification import Field
+
+
+# https://duckdb.org/docs/data/csv/overview.html
+# ['SQLNULL', 'BOOLEAN', 'BIGINT', 'DOUBLE', 'TIME', 'DATE', 'TIMESTAMP', 'VARCHAR']
+def convert_to_duckdb_csv_type(field) -> None | str:
+    datacontract_type = field.type
+    if datacontract_type is None:
+        return "VARCHAR"
+    if datacontract_type.lower() in ["string", "varchar", "text"]:
+        return "VARCHAR"
+    if datacontract_type.lower() in ["timestamp", "timestamp_tz"]:
+        return "TIMESTAMP"
+    if datacontract_type.lower() in ["timestamp_ntz"]:
+        return "TIMESTAMP"
+    if datacontract_type.lower() in ["date"]:
+        return "DATE"
+    if datacontract_type.lower() in ["time"]:
+        return "TIME"
+    if datacontract_type.lower() in ["number", "decimal", "numeric"]:
+        # precision and scale not supported by data contract
+        return "VARCHAR"
+    if datacontract_type.lower() in ["float", "double"]:
+        return "DOUBLE"
+    if datacontract_type.lower() in ["integer", "int", "long", "bigint"]:
+        return "BIGINT"
+    if datacontract_type.lower() in ["boolean"]:
+        return "BOOLEAN"
+    if datacontract_type.lower() in ["object", "record", "struct"]:
+        # not supported in CSV
+        return "VARCHAR"
+    if datacontract_type.lower() in ["bytes"]:
+        # not supported in CSV
+        return "VARCHAR"
+    if datacontract_type.lower() in ["array"]:
+        return "VARCHAR"
+    if datacontract_type.lower() in ["null"]:
+        return "SQLNULL"
+    return "VARCHAR"
+
+
+def convert_to_duckdb_json_type(field: Field) -> None | str:
+    datacontract_type = field.type
+    if datacontract_type is None:
+        return "VARCHAR"
+    if datacontract_type.lower() in ["array"]:
+        return convert_to_duckdb_json_type(field.items) + "[]"  # type: ignore
+    if datacontract_type.lower() in ["object", "record", "struct"]:
+        return convert_to_duckdb_object(field.fields)
+    return convert_to_duckdb_csv_type(field)
+
+
+def convert_to_duckdb_object(fields: Dict[str, Field]):
+    columns = [f'"{x[0]}" {convert_to_duckdb_json_type(x[1])}' for x in fields.items()]
+    return f"STRUCT({', '.join(columns)})"
datacontract/export/odcs_v3_exporter.py

@@ -13,13 +13,12 @@ class OdcsV3Exporter(Exporter):
 
 def to_odcs_v3_yaml(data_contract_spec: DataContractSpecification) -> str:
     odcs = {
-        "apiVersion": "v3.0.0",
+        "apiVersion": "v3.0.1",
         "kind": "DataContract",
         "id": data_contract_spec.id,
         "name": data_contract_spec.info.title,
         "version": data_contract_spec.info.version,
-        "domain": data_contract_spec.info.owner,
-        "status": data_contract_spec.info.status,
+        "status": to_status(data_contract_spec.info.status),
     }
 
     if data_contract_spec.terms is not None:
@@ -126,13 +125,15 @@ def to_odcs_v3_yaml(data_contract_spec: DataContractSpecification) -> str:
         odcs["servers"] = servers
 
     odcs["customProperties"] = []
+    if data_contract_spec.info.owner is not None:
+        odcs["customProperties"].append({"property": "owner", "value": data_contract_spec.info.owner})
     if data_contract_spec.info.model_extra is not None:
         for key, value in data_contract_spec.info.model_extra.items():
             odcs["customProperties"].append({"property": key, "value": value})
     if len(odcs["customProperties"]) == 0:
         del odcs["customProperties"]
 
-    return yaml.dump(odcs, indent=2, sort_keys=False, allow_unicode=True)
+    return yaml.safe_dump(odcs, indent=2, sort_keys=False, allow_unicode=True)
 
 
 def to_odcs_schema(model_key, model_value: Model) -> dict:
@@ -217,13 +218,13 @@ def to_property(field_name: str, field: Field) -> dict:
     if field.description is not None:
         property["description"] = field.description
     if field.required is not None:
-        property["isNullable"] = not field.required
+        property["required"] = field.required
     if field.unique is not None:
-        property["isUnique"] = field.unique
+        property["unique"] = field.unique
     if field.classification is not None:
         property["classification"] = field.classification
     if field.examples is not None:
-        property["examples"] = field.examples
+        property["examples"] = field.examples.copy()
     if field.example is not None:
         property["examples"] = [field.example]
     if field.primaryKey is not None and field.primaryKey:
@@ -312,3 +313,22 @@ def to_odcs_quality(quality):
     if quality.implementation is not None:
         quality_dict["implementation"] = quality.implementation
     return quality_dict
+
+
+def to_status(status):
+    """Convert the data contract status to ODCS v3 format."""
+    if status is None:
+        return "draft"  # Default to draft if no status is provided
+
+    # Valid status values according to ODCS v3.0.1 spec
+    valid_statuses = ["proposed", "draft", "active", "deprecated", "retired"]
+
+    # Convert to lowercase for comparison
+    status_lower = status.lower()
+
+    # If status is already valid, return it as is
+    if status_lower in valid_statuses:
+        return status_lower
+
+    # Default to "draft" for any non-standard status
+    return "draft"
datacontract/export/protobuf_converter.py

@@ -4,102 +4,196 @@ from datacontract.model.data_contract_specification import DataContractSpecifica
 
 class ProtoBufExporter(Exporter):
     def export(self, data_contract, model, server, sql_server_type, export_args) -> dict:
-        return to_protobuf(data_contract)
+        # Returns a dict containing the protobuf representation.
+        proto = to_protobuf(data_contract)
+        return {"protobuf": proto}
 
 
-def to_protobuf(data_contract_spec: DataContractSpecification):
+def to_protobuf(data_contract_spec: DataContractSpecification) -> str:
+    """
+    Generates a Protobuf file from the data contract specification.
+    Scans all models for enum fields (even if the type is "string") by checking for a "values" property.
+    """
     messages = ""
+    enum_definitions = {}
+
+    # Iterate over all models to generate messages and collect enum definitions.
     for model_name, model in data_contract_spec.models.items():
-        messages += to_protobuf_message(model_name, model.fields, model.description, 0)
+        for field_name, field in model.fields.items():
+            # If the field has enum values, collect them.
+            if _is_enum_field(field):
+                enum_name = _get_enum_name(field, field_name)
+                enum_values = _get_enum_values(field)
+                if enum_values and enum_name not in enum_definitions:
+                    enum_definitions[enum_name] = enum_values
+
+        messages += to_protobuf_message(model_name, model.fields, getattr(model, "description", ""), 0)
         messages += "\n"
 
-    result = f"""syntax = "proto3";
-
-    {messages}
-    """
-
-    return result
-
-
-def _to_protobuf_message_name(model_name):
-    return model_name[0].upper() + model_name[1:]
-
-
-def to_protobuf_message(model_name, fields, description, indent_level: int = 0):
+    # Build header with syntax and package declarations.
+    header = 'syntax = "proto3";\n\n'
+    package = getattr(data_contract_spec, "package", "example")
+    header += f"package {package};\n\n"
+
+    # Append enum definitions.
+    for enum_name, enum_values in enum_definitions.items():
+        header += f"// Enum for {enum_name}\n"
+        header += f"enum {enum_name} {{\n"
+        # Only iterate if enum_values is a dictionary.
+        if isinstance(enum_values, dict):
+            for enum_const, value in sorted(enum_values.items(), key=lambda item: item[1]):
+                normalized_const = enum_const.upper().replace(" ", "_")
+                header += f"  {normalized_const} = {value};\n"
+        else:
+            header += f"  // Warning: Enum values for {enum_name} are not a dictionary\n"
+        header += "}\n\n"
+    return header + messages
+
+
+def _is_enum_field(field) -> bool:
+    """
+    Returns True if the field (dict or object) has a non-empty "values" property.
+    """
+    if isinstance(field, dict):
+        return bool(field.get("values"))
+    return bool(getattr(field, "values", None))
+
+
+def _get_enum_name(field, field_name: str) -> str:
+    """
+    Returns the enum name either from the field's "enum_name" or derived from the field name.
+    """
+    if isinstance(field, dict):
+        return field.get("enum_name", _to_protobuf_message_name(field_name))
+    return getattr(field, "enum_name", None) or _to_protobuf_message_name(field_name)
+
+
+def _get_enum_values(field) -> dict:
+    """
+    Returns the enum values from the field.
+    If the values are not a dictionary, attempts to extract enum attributes.
+    """
+    if isinstance(field, dict):
+        values = field.get("values", {})
+    else:
+        values = getattr(field, "values", {})
+
+    if not isinstance(values, dict):
+        # If values is a BaseModel (or similar) with a .dict() method, use it.
+        if hasattr(values, "dict") and callable(values.dict):
+            values_dict = values.dict()
+            return {k: v for k, v in values_dict.items() if k.isupper() and isinstance(v, int)}
+        else:
+            # Otherwise, iterate over attributes that look like enums.
+            return {
+                key: getattr(values, key)
+                for key in dir(values)
+                if key.isupper() and isinstance(getattr(values, key), int)
+            }
+    return values
+
+
+def _to_protobuf_message_name(name: str) -> str:
+    """
+    Returns a valid Protobuf message/enum name by capitalizing the first letter.
+    """
+    return name[0].upper() + name[1:] if name else name
+
+
+def to_protobuf_message(model_name: str, fields: dict, description: str, indent_level: int = 0) -> str:
+    """
+    Generates a Protobuf message definition from the model's fields.
+    Handles nested messages for complex types.
+    """
     result = ""
+    if description:
+        result += f"{indent(indent_level)}// {description}\n"
 
-    if description is not None:
-        result += f"""{indent(indent_level)}/* {description} */\n"""
-
-    fields_protobuf = ""
+    result += f"message {_to_protobuf_message_name(model_name)} {{\n"
     number = 1
     for field_name, field in fields.items():
-        if field.type in ["object", "record", "struct"]:
-            fields_protobuf += (
-                "\n".join(
-                    map(
-                        lambda x: "  " + x,
-                        to_protobuf_message(field_name, field.fields, field.description, indent_level + 1).splitlines(),
-                    )
-                )
-                + "\n"
-            )
-
-        fields_protobuf += to_protobuf_field(field_name, field, field.description, number, 1) + "\n"
+        # For nested objects, generate a nested message.
+        field_type = _get_field_type(field)
+        if field_type in ["object", "record", "struct"]:
+            nested_desc = field.get("description", "") if isinstance(field, dict) else getattr(field, "description", "")
+            nested_fields = field.get("fields", {}) if isinstance(field, dict) else field.fields
+            nested_message = to_protobuf_message(field_name, nested_fields, nested_desc, indent_level + 1)
+            result += nested_message + "\n"
+
+        field_desc = field.get("description", "") if isinstance(field, dict) else getattr(field, "description", "")
+        result += to_protobuf_field(field_name, field, field_desc, number, indent_level + 1) + "\n"
         number += 1
-    result += f"message {_to_protobuf_message_name(model_name)} {{\n{fields_protobuf}}}\n"
 
+    result += f"{indent(indent_level)}}}\n"
     return result
 
 
-def to_protobuf_field(field_name, field, description, number: int, indent_level: int = 0):
-    optional = ""
-    if not field.required:
-        optional = "optional "
-
+def to_protobuf_field(field_name: str, field, description: str, number: int, indent_level: int = 0) -> str:
+    """
+    Generates a field definition within a Protobuf message.
+    """
     result = ""
-
-    if description is not None:
-        result += f"""{indent(indent_level)}/* {description} */\n"""
-
-    result += f"{indent(indent_level)}{optional}{_convert_type(field_name, field)} {field_name} = {number};"
-
+    if description:
+        result += f"{indent(indent_level)}// {description}\n"
+    result += f"{indent(indent_level)}{_convert_type(field_name, field)} {field_name} = {number};"
    return result
 
 
-def indent(indent_level):
+def indent(indent_level: int) -> str:
     return "  " * indent_level
 
 
-def _convert_type(field_name, field) -> None | str:
-    type = field.type
-    if type is None:
-        return None
-    if type.lower() in ["string", "varchar", "text"]:
-        return "string"
-    if type.lower() in ["timestamp", "timestamp_tz"]:
-        return "string"
-    if type.lower() in ["timestamp_ntz"]:
-        return "string"
-    if type.lower() in ["date"]:
+def _get_field_type(field) -> str:
+    """
+    Retrieves the field type from the field definition.
+    """
+    if isinstance(field, dict):
+        return field.get("type", "").lower()
+    return getattr(field, "type", "").lower()
+
+
+def _convert_type(field_name: str, field) -> str:
+    """
+    Converts a field's type (from the data contract) to a Protobuf type.
+    Prioritizes enum conversion if a non-empty "values" property exists.
+    """
+    # For debugging purposes
+    print("Converting field:", field_name)
+    # If the field should be treated as an enum, return its enum name.
+    if _is_enum_field(field):
+        return _get_enum_name(field, field_name)
+
+    lower_type = _get_field_type(field)
+    if lower_type in ["string", "varchar", "text"]:
         return "string"
-    if type.lower() in ["time"]:
+    if lower_type in ["timestamp", "timestamp_tz", "timestamp_ntz", "date", "time"]:
         return "string"
-    if type.lower() in ["number", "decimal", "numeric"]:
+    if lower_type in ["number", "decimal", "numeric"]:
         return "double"
-    if type.lower() in ["float", "double"]:
-        return type.lower()
-    if type.lower() in ["integer", "int"]:
+    if lower_type in ["float", "double"]:
+        return lower_type
+    if lower_type in ["integer", "int"]:
         return "int32"
-    if type.lower() in ["long", "bigint"]:
+    if lower_type in ["long", "bigint"]:
         return "int64"
-    if type.lower() in ["boolean"]:
+    if lower_type in ["boolean"]:
         return "bool"
-    if type.lower() in ["bytes"]:
+    if lower_type in ["bytes"]:
         return "bytes"
-    if type.lower() in ["object", "record", "struct"]:
+    if lower_type in ["object", "record", "struct"]:
         return _to_protobuf_message_name(field_name)
-    if type.lower() in ["array"]:
-        # TODO spec is missing arrays
-        return "repeated string"
-    return None
+    if lower_type == "array":
+        # Handle array types. Check for an "items" property.
+        items = field.get("items") if isinstance(field, dict) else getattr(field, "items", None)
+        if items and isinstance(items, dict) and items.get("type"):
+            item_type = items.get("type", "").lower()
+            if item_type in ["object", "record", "struct"]:
+                # Singularize the field name (a simple approach).
+                singular = field_name[:-1] if field_name.endswith("s") else field_name
+                return "repeated " + _to_protobuf_message_name(singular)
+            else:
+                return "repeated " + _convert_type(field_name, items)
+        else:
+            return "repeated string"
+    # Fallback for unrecognized types.
+    return "string"
datacontract/imports/avro_importer.py

@@ -55,8 +55,7 @@ def import_avro(data_contract_specification: DataContractSpecification, source:
             engine="datacontract",
             original_exception=e,
         )
-
-    # type record is being used for both the table and the object types in data contract
+    # type record is being used for both the table and the object types in data contract
     # -> CONSTRAINT: one table per .avsc input, all nested records are interpreted as objects
     fields = import_record_fields(avro_schema.fields)
 
@@ -92,6 +91,20 @@ def handle_config_avro_custom_properties(field: avro.schema.Field, imported_fiel
         imported_field.config["avroDefault"] = field.default
 
 
+LOGICAL_TYPE_MAPPING = {
+    "decimal": "decimal",
+    "date": "date",
+    "time-millis": "time",
+    "time-micros": "time",
+    "timestamp-millis": "timestamp_tz",
+    "timestamp-micros": "timestamp_tz",
+    "local-timestamp-micros": "timestamp_ntz",
+    "local-timestamp-millis": "timestamp_ntz",
+    "duration": "string",
+    "uuid": "string",
+}
+
+
 def import_record_fields(record_fields: List[avro.schema.Field]) -> Dict[str, Field]:
     """
     Import Avro record fields and convert them to data contract fields.
@@ -137,9 +150,15 @@ def import_record_fields(record_fields: List[avro.schema.Field]) -> Dict[str, Fi
             if not imported_field.config:
                 imported_field.config = {}
             imported_field.config["avroType"] = "enum"
-        else:  # primitive type
-            imported_field.type = map_type_from_avro(field.type.type)
-
+        else:
+            logical_type = field.type.get_prop("logicalType")
+            if logical_type in LOGICAL_TYPE_MAPPING:
+                imported_field.type = LOGICAL_TYPE_MAPPING[logical_type]
+                if logical_type == "decimal":
+                    imported_field.precision = field.type.precision
+                    imported_field.scale = field.type.scale
+            else:
+                imported_field.type = map_type_from_avro(field.type.type)
         imported_fields[field.name] = imported_field
 
     return imported_fields
@@ -212,7 +231,11 @@ def import_type_of_optional_field(field: avro.schema.Field) -> str:
     """
     for field_type in field.type.schemas:
        if field_type.type != "null":
-            return map_type_from_avro(field_type.type)
+            logical_type = field_type.get_prop("logicalType")
+            if logical_type and logical_type in LOGICAL_TYPE_MAPPING:
+                return LOGICAL_TYPE_MAPPING[logical_type]
+            else:
+                return map_type_from_avro(field_type.type)
     raise DataContractException(
         type="schema",
         result="failed",
@@ -276,6 +299,8 @@ def map_type_from_avro(avro_type_str: str) -> str:
         return "binary"
     elif avro_type_str == "double":
         return "double"
+    elif avro_type_str == "float":
+        return "float"
    elif avro_type_str == "int":
        return "int"
    elif avro_type_str == "long":
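A sketch of what the logical-type handling changes for an import, using the avro library's schema parser (the schema below is illustrative):

    import avro.schema

    # "amount" carries a decimal logical type; "created" a timestamp-millis one.
    schema = avro.schema.parse('''
    {
      "type": "record", "name": "Payment",
      "fields": [
        {"name": "amount", "type": {"type": "bytes", "logicalType": "decimal", "precision": 10, "scale": 2}},
        {"name": "created", "type": {"type": "long", "logicalType": "timestamp-millis"}}
      ]
    }
    ''')
    fields = import_record_fields(schema.fields)
    # Previously both fields imported as their physical types (bytes/long);
    # with LOGICAL_TYPE_MAPPING: fields["amount"].type == "decimal" with
    # precision=10 and scale=2, and fields["created"].type == "timestamp_tz".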
datacontract/imports/csv_importer.py

@@ -1,89 +1,143 @@
 import os
+from typing import Any, Dict, List
 
-import clevercsv
+import duckdb
 
 from datacontract.imports.importer import Importer
-from datacontract.model.data_contract_specification import DataContractSpecification, Example, Field, Model, Server
+from datacontract.model.data_contract_specification import DataContractSpecification, Model, Server
 
 
 class CsvImporter(Importer):
     def import_source(
         self, data_contract_specification: DataContractSpecification, source: str, import_args: dict
     ) -> DataContractSpecification:
-        return import_csv(data_contract_specification, self.import_format, source)
+        return import_csv(data_contract_specification, source)
 
 
-def import_csv(data_contract_specification: DataContractSpecification, format: str, source: str):
-    include_example = False
-
-    # detect encoding and dialect
-    encoding = clevercsv.encoding.get_encoding(source)
-    with open(source, "r", newline="") as fp:
-        dialect = clevercsv.Sniffer().sniff(fp.read(10000))
-
-    # using auto detecting of the format and encoding
-    df = clevercsv.read_dataframe(source)
-
-    if data_contract_specification.models is None:
-        data_contract_specification.models = {}
-
+def import_csv(
+    data_contract_specification: DataContractSpecification, source: str, include_examples: bool = False
+) -> DataContractSpecification:
     # use the file name as table name
     table_name = os.path.splitext(os.path.basename(source))[0]
 
+    # use duckdb to auto detect format, columns, etc.
+    con = duckdb.connect(database=":memory:")
+    con.sql(
+        f"""CREATE VIEW "{table_name}" AS SELECT * FROM read_csv_auto('{source}', hive_partitioning=1, auto_type_candidates = ['BOOLEAN', 'INTEGER', 'BIGINT', 'DOUBLE', 'VARCHAR']);"""
+    )
+    dialect = con.sql(f"SELECT * FROM sniff_csv('{source}', sample_size = 1000);").fetchnumpy()
+    tbl = con.table(table_name)
+
     if data_contract_specification.servers is None:
         data_contract_specification.servers = {}
 
+    delimiter = None if dialect is None else dialect['Delimiter'][0]
+
+    if dialect is not None:
+        dc_types = [map_type_from_duckdb(x["type"]) for x in dialect['Columns'][0]]
+    else:
+        dc_types = [map_type_from_duckdb(str(x)) for x in tbl.dtypes]
+
     data_contract_specification.servers["production"] = Server(
-        type="local", path=source, format="csv", delimiter=dialect.delimiter
+        type="local", path=source, format="csv", delimiter=delimiter
     )
 
+    rowcount = tbl.shape[0]
+
+    tallies = dict()
+    for row in tbl.describe().fetchall():
+        if row[0] not in ["count", "max", "min"]:
+            continue
+        for i in range(tbl.shape[1]):
+            tallies[(row[0], tbl.columns[i])] = row[i + 1] if row[0] != "count" else int(row[i + 1])
+
+    samples: Dict[str, List] = dict()
+    for i in range(tbl.shape[1]):
+        field_name = tbl.columns[i]
+        if tallies[("count", field_name)] > 0 and tbl.dtypes[i] not in ["BOOLEAN", "BLOB"]:
+            sql = f"""SELECT DISTINCT "{field_name}" FROM "{table_name}" WHERE "{field_name}" IS NOT NULL USING SAMPLE 5 ROWS;"""
+            samples[field_name] = [x[0] for x in con.sql(sql).fetchall()]
+
+    formats: Dict[str, str] = dict()
+    for i in range(tbl.shape[1]):
+        field_name = tbl.columns[i]
+        if tallies[("count", field_name)] > 0 and tbl.dtypes[i] == "VARCHAR":
+            sql = f"""SELECT
+                count_if("{field_name}" IS NOT NULL) as count,
+                count_if(regexp_matches("{field_name}", '^[\\w-\\.]+@([\\w-]+\\.)+[\\w-]{{2,4}}$')) as email,
+                count_if(regexp_matches("{field_name}", '^[[a-z0-9]{{8}}-?[a-z0-9]{{4}}-?[a-z0-9]{{4}}-?[a-z0-9]{{4}}-?[a-z0-9]{{12}}]')) as uuid
+                FROM "{table_name}";
+            """
+            res = con.sql(sql).fetchone()
+            if res[1] == res[0]:
+                formats[field_name] = "email"
+            elif res[2] == res[0]:
+                formats[field_name] = "uuid"
+
     fields = {}
-    for column, dtype in df.dtypes.items():
-        field = Field()
-        field.type = map_type_from_pandas(dtype.name)
-        fields[column] = field
+    for i in range(tbl.shape[1]):
+        field_name = tbl.columns[i]
+        dc_type = dc_types[i]
+
+        ## specifying "integer" rather than "bigint" looks nicer
+        if (
+            dc_type == "bigint"
+            and tallies[("max", field_name)] <= 2147483647
+            and tallies[("min", field_name)] >= -2147483648
+        ):
+            dc_type = "integer"
+
+        field: Dict[str, Any] = {"type": dc_type, "format": formats.get(field_name, None)}
+
+        if tallies[("count", field_name)] == rowcount:
+            field["required"] = True
+        if dc_type not in ["boolean", "bytes"]:
+            distinct_values = tbl.count(f'DISTINCT "{field_name}"').fetchone()[0]  # type: ignore
+            if distinct_values > 0 and distinct_values == tallies[("count", field_name)]:
+                field["unique"] = True
+        s = samples.get(field_name, None)
+        if s is not None:
+            field["examples"] = s
+        if dc_type in ["integer", "bigint", "float", "double"]:
+            field["minimum"] = tallies[("min", field_name)]
+            field["maximum"] = tallies[("max", field_name)]
+
+        fields[field_name] = field
+
+    model_examples = None
+    if include_examples:
+        model_examples = con.sql(f"""SELECT DISTINCT * FROM "{table_name}" USING SAMPLE 5 ROWS;""").fetchall()
 
     data_contract_specification.models[table_name] = Model(
-        type="table",
-        description=f"Csv file with encoding {encoding}",
-        fields=fields,
+        type="table", description="Generated model of " + source, fields=fields, examples=model_examples
    )
 
-    # multiline data is not correctly handled by yaml dump
-    if include_example:
-        if data_contract_specification.examples is None:
-            data_contract_specification.examples = []
-
-        # read first 10 lines with the detected encoding
-        with open(source, "r", encoding=encoding) as csvfile:
-            lines = csvfile.readlines()[:10]
-
-        data_contract_specification.examples.append(Example(type="csv", model=table_name, data="".join(lines)))
-
     return data_contract_specification
 
 
-def map_type_from_pandas(sql_type: str):
+_duck_db_types = {
+    "BOOLEAN": "boolean",
+    "BLOB": "bytes",
+    "TINYINT": "integer",
+    "SMALLINT": "integer",
+    "INTEGER": "integer",
+    "BIGINT": "bigint",
+    "UTINYINT": "integer",
+    "USMALLINT": "integer",
+    "UINTEGER": "integer",
+    "UBIGINT": "bigint",
+    "FLOAT": "float",
+    "DOUBLE": "double",
+    "VARCHAR": "string",
+    "TIMESTAMP": "timestamp",
+    "DATE": "date",
+    # TODO: Add support for NULL
+}
+
+
+def map_type_from_duckdb(sql_type: None | str):
     if sql_type is None:
         return None
 
-    sql_type_normed = sql_type.lower().strip()
-
-    if sql_type_normed == "object":
-        return "string"
-    elif sql_type_normed.startswith("str"):
-        return "string"
-    elif sql_type_normed.startswith("int"):
-        return "integer"
-    elif sql_type_normed.startswith("float"):
-        return "float"
-    elif sql_type_normed.startswith("bool"):
-        return "boolean"
-    elif sql_type_normed.startswith("timestamp"):
-        return "timestamp"
-    elif sql_type_normed == "datetime64":
-        return "date"
-    elif sql_type_normed == "timedelta[ns]":
-        return "timestamp_ntz"
-    else:
-        return "variant"
+    sql_type_normed = sql_type.upper().strip()
+    return _duck_db_types.get(sql_type_normed, "string")
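Taken together, the importer now profiles the file via DuckDB (row counts, min/max, distinct values, samples, email/uuid format detection) instead of just sniffing pandas dtypes. A minimal usage sketch, assuming a local orders.csv with an order_id column (both names illustrative):

    from datacontract.imports.csv_importer import import_csv
    from datacontract.model.data_contract_specification import DataContractSpecification

    spec = import_csv(DataContractSpecification(), "orders.csv")
    model = spec.models["orders"]     # model name is derived from the file name
    field = model.fields["order_id"]
    # Depending on the data, the profiled metadata might look like:
    #   type=integer, required=True, unique=True, minimum=1, maximum=9999,
    #   examples=[...up to 5 sampled values...]
    print(field.type, field.required, field.unique)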