datacontract-cli 0.10.27__py3-none-any.whl → 0.10.29__py3-none-any.whl

This diff compares the publicly released contents of two package versions as they appear in their public registry. It is provided for informational purposes only.

Files changed (37)
  1. datacontract/api.py +1 -1
  2. datacontract/cli.py +37 -5
  3. datacontract/data_contract.py +122 -29
  4. datacontract/engines/data_contract_checks.py +2 -0
  5. datacontract/engines/soda/connections/duckdb_connection.py +1 -1
  6. datacontract/export/html_exporter.py +28 -23
  7. datacontract/export/mermaid_exporter.py +78 -13
  8. datacontract/export/odcs_v3_exporter.py +7 -9
  9. datacontract/export/rdf_converter.py +2 -2
  10. datacontract/export/sql_type_converter.py +2 -2
  11. datacontract/imports/excel_importer.py +7 -2
  12. datacontract/imports/importer.py +11 -1
  13. datacontract/imports/importer_factory.py +7 -0
  14. datacontract/imports/json_importer.py +325 -0
  15. datacontract/imports/odcs_importer.py +2 -2
  16. datacontract/imports/odcs_v3_importer.py +9 -9
  17. datacontract/imports/spark_importer.py +38 -16
  18. datacontract/imports/sql_importer.py +4 -2
  19. datacontract/imports/unity_importer.py +77 -37
  20. datacontract/init/init_template.py +1 -1
  21. datacontract/integration/datamesh_manager.py +16 -2
  22. datacontract/lint/resolve.py +61 -7
  23. datacontract/lint/schema.py +1 -1
  24. datacontract/schemas/datacontract-1.1.0.init.yaml +1 -1
  25. datacontract/schemas/datacontract-1.2.0.init.yaml +91 -0
  26. datacontract/schemas/datacontract-1.2.0.schema.json +2029 -0
  27. datacontract/templates/datacontract.html +4 -0
  28. datacontract/templates/datacontract_odcs.html +666 -0
  29. datacontract/templates/index.html +2 -0
  30. datacontract/templates/partials/server.html +2 -0
  31. datacontract/templates/style/output.css +319 -145
  32. {datacontract_cli-0.10.27.dist-info → datacontract_cli-0.10.29.dist-info}/METADATA +98 -62
  33. {datacontract_cli-0.10.27.dist-info → datacontract_cli-0.10.29.dist-info}/RECORD +37 -33
  34. {datacontract_cli-0.10.27.dist-info → datacontract_cli-0.10.29.dist-info}/WHEEL +1 -1
  35. {datacontract_cli-0.10.27.dist-info → datacontract_cli-0.10.29.dist-info}/entry_points.txt +0 -0
  36. {datacontract_cli-0.10.27.dist-info → datacontract_cli-0.10.29.dist-info}/licenses/LICENSE +0 -0
  37. {datacontract_cli-0.10.27.dist-info → datacontract_cli-0.10.29.dist-info}/top_level.txt +0 -0
datacontract/imports/importer.py

@@ -12,7 +12,7 @@ class Importer(ABC):
     @abstractmethod
     def import_source(
         self,
-        data_contract_specification: DataContractSpecification,
+        data_contract_specification: DataContractSpecification | OpenDataContractStandard,
         source: str,
         import_args: dict,
     ) -> DataContractSpecification | OpenDataContractStandard:
@@ -26,6 +26,7 @@ class ImportFormat(str, Enum):
     dbml = "dbml"
     glue = "glue"
     jsonschema = "jsonschema"
+    json = "json"
     bigquery = "bigquery"
     odcs = "odcs"
     unity = "unity"
@@ -39,3 +40,12 @@ class ImportFormat(str, Enum):
     @classmethod
     def get_supported_formats(cls):
         return list(map(lambda c: c.value, cls))
+
+
+class Spec(str, Enum):
+    datacontract_specification = "datacontract_specification"
+    odcs = "odcs"
+
+    @classmethod
+    def get_supported_types(cls):
+        return list(map(lambda c: c.value, cls))
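
The new json format value and the Spec enum extend the public import surface. A minimal sketch of what these additions expose, with expected values taken directly from the enum members above:

    from datacontract.imports.importer import ImportFormat, Spec

    # "json" now appears among the supported import formats
    assert "json" in ImportFormat.get_supported_formats()

    # Spec enumerates the two specification types an import can target
    assert Spec.get_supported_types() == ["datacontract_specification", "odcs"]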
datacontract/imports/importer_factory.py

@@ -119,3 +119,10 @@ importer_factory.register_lazy_importer(
     module_path="datacontract.imports.excel_importer",
     class_name="ExcelImporter",
 )
+
+
+importer_factory.register_lazy_importer(
+    name=ImportFormat.json,
+    module_path="datacontract.imports.json_importer",
+    class_name="JsonImporter",
+)
datacontract/imports/json_importer.py (new file)

@@ -0,0 +1,325 @@
+import json
+import os
+import re
+from typing import Any, Dict, List, Optional, Tuple
+
+from datacontract.imports.importer import Importer
+from datacontract.model.data_contract_specification import DataContractSpecification, Model, Server
+
+
+class JsonImporter(Importer):
+    def import_source(
+        self, data_contract_specification: DataContractSpecification, source: str, import_args: dict
+    ) -> DataContractSpecification:
+        return import_json(data_contract_specification, source)
+
+
+def is_ndjson(file_path: str) -> bool:
+    """Check if a file contains newline-delimited JSON."""
+    with open(file_path, "r", encoding="utf-8") as file:
+        for _ in range(5):
+            line = file.readline().strip()
+            if not line:
+                continue
+            try:
+                json.loads(line)
+                return True
+            except json.JSONDecodeError:
+                break
+    return False
+
+
+def import_json(
+    data_contract_specification: DataContractSpecification, source: str, include_examples: bool = False
+) -> DataContractSpecification:
+    # use the file name as base model name
+    base_model_name = os.path.splitext(os.path.basename(source))[0]
+
+    # check if file is newline-delimited JSON
+    if is_ndjson(source):
+        # load NDJSON data
+        json_data = []
+        with open(source, "r", encoding="utf-8") as file:
+            for line in file:
+                line = line.strip()
+                if line:
+                    try:
+                        json_data.append(json.loads(line))
+                    except json.JSONDecodeError:
+                        continue
+    else:
+        # load regular JSON data
+        with open(source, "r", encoding="utf-8") as file:
+            json_data = json.load(file)
+
+    if data_contract_specification.servers is None:
+        data_contract_specification.servers = {}
+
+    data_contract_specification.servers["production"] = Server(type="local", path=source, format="json")
+
+    # initialisation
+    models = {}
+
+    if isinstance(json_data, list) and json_data:
+        # Array of items
+        if all(isinstance(item, dict) for item in json_data[:5]):
+            # Array of objects, as table
+            fields = {}
+            for item in json_data[:20]:
+                for key, value in item.items():
+                    field_def = generate_field_definition(value, key, base_model_name, models)
+                    if key in fields:
+                        fields[key] = merge_field_definitions(fields[key], field_def)
+                    else:
+                        fields[key] = field_def
+
+            models[base_model_name] = {
+                "type": "table",
+                "description": f"Generated from JSON array in {source}",
+                "fields": fields,
+                "examples": json_data[:3] if include_examples else None,
+            }
+        else:
+            # Simple array
+            item_type, item_format = infer_array_type(json_data[:20])
+            models[base_model_name] = {
+                "type": "array",
+                "description": f"Generated from JSON array in {source}",
+                "items": {"type": item_type, "format": item_format} if item_format else {"type": item_type},
+                "examples": [json_data[:5]] if include_examples else None,
+            }
+    elif isinstance(json_data, dict):
+        # Single object
+        fields = {}
+        for key, value in json_data.items():
+            fields[key] = generate_field_definition(value, key, base_model_name, models)
+
+        models[base_model_name] = {
+            "type": "object",
+            "description": f"Generated from JSON object in {source}",
+            "fields": fields,
+            "examples": [json_data] if include_examples else None,
+        }
+    else:
+        # Primitive value
+        field_type, field_format = determine_type_and_format(json_data)
+        models[base_model_name] = {
+            "type": field_type,
+            "description": f"Generated from JSON primitive in {source}",
+            "format": field_format,
+            "examples": [json_data] if include_examples and field_type != "boolean" else None,
+        }
+
+    for model_name, model_def in models.items():
+        model_type = model_def.pop("type")
+        data_contract_specification.models[model_name] = Model(type=model_type, **model_def)
+
+    return data_contract_specification
+
+
+def generate_field_definition(
+    value: Any, field_name: str, parent_model: str, models: Dict[str, Dict[str, Any]]
+) -> Dict[str, Any]:
+    """Generate a field definition for a JSON value, creating nested models."""
+
+    if isinstance(value, dict):
+        # Handle object fields
+        fields = {}
+        for key, nested_value in value.items():
+            fields[key] = generate_field_definition(nested_value, key, parent_model, models)
+
+        return {"type": "object", "fields": fields}
+
+    elif isinstance(value, list):
+        # Handle array fields
+        if not value:
+            return {"type": "array", "items": {"type": "string"}}
+
+        if all(isinstance(item, dict) for item in value):
+            # Array of objects
+            fields = {}
+            for item in value:
+                for key, nested_value in item.items():
+                    field_def = generate_field_definition(nested_value, key, parent_model, models)
+                    if key in fields:
+                        fields[key] = merge_field_definitions(fields[key], field_def)
+                    else:
+                        fields[key] = field_def
+
+            return {"type": "array", "items": {"type": "object", "fields": fields}}
+
+        elif all(isinstance(item, list) for item in value):
+            # Array of arrays
+            inner_type, inner_format = infer_array_type(value[0])
+            return {
+                "type": "array",
+                "items": {
+                    "type": "array",
+                    "items": {"type": inner_type, "format": inner_format} if inner_format else {"type": inner_type},
+                },
+                "examples": value[:5],  # Include examples for nested arrays
+            }
+
+        else:
+            # Array of simple or mixed types
+            item_type, item_format = infer_array_type(value)
+            items_def = {"type": item_type}
+            if item_format:
+                items_def["format"] = item_format
+
+            field_def = {"type": "array", "items": items_def}
+
+            # Add examples if appropriate
+            sample_values = [item for item in value[:5] if item is not None]
+            if sample_values:
+                field_def["examples"] = sample_values
+
+            return field_def
+
+    else:
+        # Handle primitive types
+        field_type, field_format = determine_type_and_format(value)
+        field_def = {"type": field_type}
+        if field_format:
+            field_def["format"] = field_format
+
+        # Add examples
+        if value is not None and field_type != "boolean":
+            field_def["examples"] = [value]
+
+        return field_def
+
+
+def infer_array_type(array: List) -> Tuple[str, Optional[str]]:
+    """Infer the common type of items in an array."""
+    if not array:
+        return "string", None
+
+    # if all items are dictionaries with the same structure
+    if all(isinstance(item, dict) for item in array):
+        return "object", None
+
+    # if all items are of the same primitive type
+    non_null_items = [item for item in array if item is not None]
+    if not non_null_items:
+        return "null", None
+
+    types_and_formats = [determine_type_and_format(item) for item in non_null_items]
+    types = {t for t, _ in types_and_formats}
+    formats = {f for _, f in types_and_formats if f is not None}
+
+    # simplify type combinations
+    if types == {"integer", "number"}:
+        return "number", None
+    if len(types) == 1:
+        type_name = next(iter(types))
+        format_name = next(iter(formats)) if len(formats) == 1 else None
+        return type_name, format_name
+    if all(t in {"string", "integer", "number", "boolean", "null"} for t in types):
+        # If all string values have the same format, keep it
+        if len(formats) == 1 and "string" in types:
+            return "string", next(iter(formats))
+        return "string", None
+
+    # Mixed types
+    return "string", None
+
+
+def determine_type_and_format(value: Any) -> Tuple[str, Optional[str]]:
+    """Determine the datacontract type and format for a JSON value."""
+    if value is None:
+        return "null", None
+    elif isinstance(value, bool):
+        return "boolean", None
+    elif isinstance(value, int):
+        return "integer", None
+    elif isinstance(value, float):
+        return "number", None
+    elif isinstance(value, str):
+        try:
+            if re.match(r"^\d{4}-\d{2}-\d{2}$", value):
+                return "string", "date"
+            elif re.match(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(\.\d+)?(Z|[+-]\d{2}:\d{2})?$", value):
+                return "string", "date-time"
+            elif re.match(r"^[\w\.-]+@([\w-]+\.)+[\w-]{2,4}$", value):
+                return "string", "email"
+            elif re.match(r"^[a-f0-9]{8}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{12}$", value.lower()):
+                return "string", "uuid"
+            else:
+                return "string", None
+        except re.error:
+            return "string", None
+    elif isinstance(value, dict):
+        return "object", None
+    elif isinstance(value, list):
+        return "array", None
+    else:
+        return "string", None
+
+
+def merge_field_definitions(field1: Dict[str, Any], field2: Dict[str, Any]) -> Dict[str, Any]:
+    """Merge two field definitions."""
+    result = field1.copy()
+    if field1.get("type") == "object" and field2.get("type") != "object":
+        return field1
+    if field2.get("type") == "object" and field1.get("type") != "object":
+        return field2
+    # Handle type differences
+    if field1.get("type") != field2.get("type"):
+        type1, _ = field1.get("type", "string"), field1.get("format")
+        type2, _ = field2.get("type", "string"), field2.get("format")
+
+        if type1 == "integer" and type2 == "number" or type1 == "number" and type2 == "integer":
+            common_type = "number"
+            common_format = None
+        elif "string" in [type1, type2]:
+            common_type = "string"
+            common_format = None
+        elif all(t in ["string", "integer", "number", "boolean", "null"] for t in [type1, type2]):
+            common_type = "string"
+            common_format = None
+        elif type1 == "array" and type2 == "array":
+            # Handle mixed array types
+            items1 = field1.get("items", {})
+            items2 = field2.get("items", {})
+            if items1.get("type") == "object" or items2.get("type") == "object":
+                if items1.get("type") == "object" and items2.get("type") == "object":
+                    merged_items = merge_field_definitions(items1, items2)
+                else:
+                    merged_items = items1 if items1.get("type") == "object" else items2
+                return {"type": "array", "items": merged_items}
+            else:
+                merged_items = merge_field_definitions(items1, items2)
+                return {"type": "array", "items": merged_items}
+        else:
+            common_type = "array" if "array" in [type1, type2] else "object"
+            common_format = None
+
+        result["type"] = common_type
+        if common_format:
+            result["format"] = common_format
+        elif "format" in result:
+            del result["format"]
+
+    # Merge examples
+    if "examples" in field2:
+        if "examples" in result:
+            combined = result["examples"] + [ex for ex in field2["examples"] if ex not in result["examples"]]
+            result["examples"] = combined[:5]  # Limit to 5 examples
+        else:
+            result["examples"] = field2["examples"]
+
+    # Handle nested structures
+    if result.get("type") == "array" and "items" in field1 and "items" in field2:
+        result["items"] = merge_field_definitions(field1["items"], field2["items"])
+    elif result.get("type") == "object" and "fields" in field1 and "fields" in field2:
+        # Merge fields from both objects
+        merged_fields = field1["fields"].copy()
+        for key, field_def in field2["fields"].items():
+            if key in merged_fields:
+                merged_fields[key] = merge_field_definitions(merged_fields[key], field_def)
+            else:
+                merged_fields[key] = field_def
+        result["fields"] = merged_fields
+
+    return result
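
A minimal usage sketch of the new importer. The file name orders.json is hypothetical, and DataContractSpecification() is assumed to be default-constructible:

    from datacontract.imports.json_importer import import_json
    from datacontract.model.data_contract_specification import DataContractSpecification

    # Infers a single model named "orders" from orders.json; field types and
    # string formats (date, date-time, email, uuid) are detected from sample values.
    spec = import_json(DataContractSpecification(), "orders.json")
    print(spec.models["orders"].fields)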
datacontract/imports/odcs_importer.py

@@ -48,9 +48,9 @@ def import_odcs(data_contract_specification: DataContractSpecification, source:
             engine="datacontract",
         )
     elif odcs_api_version.startswith("v3."):
-        from datacontract.imports.odcs_v3_importer import import_odcs_v3
+        from datacontract.imports.odcs_v3_importer import import_odcs_v3_as_dcs

-        return import_odcs_v3(data_contract_specification, source)
+        return import_odcs_v3_as_dcs(data_contract_specification, source)
     else:
         raise DataContractException(
             type="schema",

datacontract/imports/odcs_v3_importer.py

@@ -29,17 +29,18 @@ class OdcsImporter(Importer):
     def import_source(
         self, data_contract_specification: DataContractSpecification, source: str, import_args: dict
     ) -> DataContractSpecification:
-        return import_odcs_v3(data_contract_specification, source)
+        return import_odcs_v3_as_dcs(data_contract_specification, source)


-def import_odcs_v3(data_contract_specification: DataContractSpecification, source: str) -> DataContractSpecification:
+def import_odcs_v3_as_dcs(
+    data_contract_specification: DataContractSpecification, source: str
+) -> DataContractSpecification:
     source_str = read_resource(source)
-    return import_odcs_v3_from_str(data_contract_specification, source_str)
+    odcs = parse_odcs_v3_from_str(source_str)
+    return import_from_odcs(data_contract_specification, odcs)


-def import_odcs_v3_from_str(
-    data_contract_specification: DataContractSpecification, source_str: str
-) -> DataContractSpecification:
+def parse_odcs_v3_from_str(source_str):
     try:
         odcs = OpenDataContractStandard.from_string(source_str)
     except Exception as e:
@@ -50,11 +51,10 @@ def import_odcs_v3_from_str(
             engine="datacontract",
             original_exception=e,
         )
-
-    return import_from_odcs_model(data_contract_specification, odcs)
+    return odcs


-def import_from_odcs_model(data_contract_specification, odcs):
+def import_from_odcs(data_contract_specification: DataContractSpecification, odcs: OpenDataContractStandard):
     data_contract_specification.id = odcs.id
     data_contract_specification.info = import_info(odcs)
     data_contract_specification.servers = import_servers(odcs)
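
The rename splits the old import_odcs_v3_from_str into a parse step and a mapping step, so the parsed OpenDataContractStandard can be reused on its own. A sketch of the resulting two-step flow using the names from the hunks above (the file name is hypothetical, and the import path of read_resource is an assumption):

    from datacontract.imports.odcs_v3_importer import import_from_odcs, parse_odcs_v3_from_str
    from datacontract.lint.resolve import read_resource  # assumed location
    from datacontract.model.data_contract_specification import DataContractSpecification

    # parse_odcs_v3_from_str raises a DataContractException on malformed input
    odcs = parse_odcs_v3_from_str(read_resource("orders.odcs.yaml"))
    spec = import_from_odcs(DataContractSpecification(), odcs)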
datacontract/imports/spark_importer.py

@@ -28,34 +28,52 @@ class SparkImporter(Importer):
             data_contract_specification: The data contract specification object.
             source: The source string indicating the Spark tables to read.
             import_args: Additional arguments for the import process.
-
         Returns:
             dict: The updated data contract specification.
         """
-        return import_spark(data_contract_specification, source)
+        dataframe = import_args.get("dataframe", None)
+        description = import_args.get("description", None)
+        return import_spark(data_contract_specification, source, dataframe, description)


-def import_spark(data_contract_specification: DataContractSpecification, source: str) -> DataContractSpecification:
+def import_spark(
+    data_contract_specification: DataContractSpecification,
+    source: str,
+    dataframe: DataFrame | None = None,
+    description: str | None = None,
+) -> DataContractSpecification:
     """
-    Reads Spark tables and updates the data contract specification with their schemas.
+    Imports schema(s) from Spark into a Data Contract Specification.

     Args:
-        data_contract_specification: The data contract specification to update.
-        source: A comma-separated string of Spark temporary views to read.
+        data_contract_specification (DataContractSpecification): The contract spec to update.
+        source (str): Comma-separated Spark table/view names.
+        dataframe (DataFrame | None): Optional Spark DataFrame to import.
+        description (str | None): Optional table-level description.

     Returns:
-        DataContractSpecification: The updated data contract specification.
+        DataContractSpecification: The updated contract spec with imported models.
     """
     spark = SparkSession.builder.getOrCreate()
     data_contract_specification.servers["local"] = Server(type="dataframe")
-    for temp_view in source.split(","):
-        temp_view = temp_view.strip()
-        df = spark.read.table(temp_view)
-        data_contract_specification.models[temp_view] = import_from_spark_df(spark, source, df)
+
+    if dataframe is not None:
+        if not isinstance(dataframe, DataFrame):
+            raise TypeError("Expected 'dataframe' to be a pyspark.sql.DataFrame")
+        data_contract_specification.models[source] = import_from_spark_df(spark, source, dataframe, description)
+        return data_contract_specification
+
+    if not source:
+        raise ValueError("Either 'dataframe' or a valid 'source' must be provided")
+
+    for table_name in map(str.strip, source.split(",")):
+        df = spark.read.table(table_name)
+        data_contract_specification.models[table_name] = import_from_spark_df(spark, table_name, df, description)
+
     return data_contract_specification


-def import_from_spark_df(spark: SparkSession, source: str, df: DataFrame) -> Model:
+def import_from_spark_df(spark: SparkSession, source: str, df: DataFrame, description: str) -> Model:
     """
     Converts a Spark DataFrame into a Model.

@@ -63,6 +81,7 @@ def import_from_spark_df(spark: SparkSession, source: str, df: DataFrame) -> Mod
         spark: SparkSession
         source: A comma-separated string of Spark temporary views to read.
         df: The Spark DataFrame to convert.
+        description: Table level comment

     Returns:
         Model: The generated data contract model.
@@ -70,7 +89,10 @@ def import_from_spark_df(spark: SparkSession, source: str, df: DataFrame) -> Mod
     model = Model()
     schema = df.schema

-    model.description = _table_comment_from_spark(spark, source)
+    if description is None:
+        model.description = _table_comment_from_spark(spark, source)
+    else:
+        model.description = description

     for field in schema:
         model.fields[field.name] = _field_from_struct_type(field)
@@ -199,7 +221,7 @@ def _table_comment_from_spark(spark: SparkSession, source: str):
         workspace_client = WorkspaceClient()
         created_table = workspace_client.tables.get(full_name=f"{source}")
         table_comment = created_table.comment
-        print(f"'{source}' table comment retrieved using 'WorkspaceClient.tables.get({source})'")
+        logger.info(f"'{source}' table comment retrieved using 'WorkspaceClient.tables.get({source})'")
        return table_comment
     except Exception:
         pass
@@ -207,7 +229,7 @@ def _table_comment_from_spark(spark: SparkSession, source: str):
     # Fallback to Spark Catalog API for Hive Metastore or Non-UC Tables
     try:
         table_comment = spark.catalog.getTable(f"{source}").description
-        print(f"'{source}' table comment retrieved using 'spark.catalog.getTable({source}).description'")
+        logger.info(f"'{source}' table comment retrieved using 'spark.catalog.getTable({source}).description'")
         return table_comment
     except Exception:
         pass
@@ -219,7 +241,7 @@ def _table_comment_from_spark(spark: SparkSession, source: str):
             if row.col_name.strip().lower() == "comment":
                 table_comment = row.data_type
                 break
-        print(f"'{source}' table comment retrieved using 'DESCRIBE TABLE EXTENDED {source}'")
+        logger.info(f"'{source}' table comment retrieved using 'DESCRIBE TABLE EXTENDED {source}'")
         return table_comment
     except Exception:
         pass
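
With the new import_args handling, a Spark DataFrame can be imported directly instead of reading a registered table, and a caller-supplied description skips the catalog lookups in _table_comment_from_spark. A sketch under those assumptions (the DataFrame contents are illustrative):

    from pyspark.sql import SparkSession
    from datacontract.imports.spark_importer import import_spark
    from datacontract.model.data_contract_specification import DataContractSpecification

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, "a")], ["id", "label"])

    # "orders" becomes the model name; passing description avoids the
    # WorkspaceClient / catalog / DESCRIBE TABLE fallbacks entirely.
    spec = import_spark(DataContractSpecification(), "orders", dataframe=df, description="Orders exported daily")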
datacontract/imports/sql_importer.py

@@ -105,7 +105,7 @@ def to_dialect(import_args: dict) -> Dialects | None:
     return None


-def to_physical_type_key(dialect: Dialects | None) -> str:
+def to_physical_type_key(dialect: Dialects | str | None) -> str:
     dialect_map = {
         Dialects.TSQL: "sqlserverType",
         Dialects.POSTGRES: "postgresType",
@@ -116,6 +116,8 @@ def to_physical_type_key(dialect: Dialects | None) -> str:
         Dialects.MYSQL: "mysqlType",
         Dialects.DATABRICKS: "databricksType",
     }
+    if isinstance(dialect, str):
+        dialect = Dialects[dialect.upper()] if dialect.upper() in Dialects.__members__ else None
     return dialect_map.get(dialect, "physicalType")
@@ -198,7 +200,7 @@ def get_precision_scale(column):
     return None, None


-def map_type_from_sql(sql_type: str):
+def map_type_from_sql(sql_type: str) -> str | None:
     if sql_type is None:
         return None
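
The new isinstance branch in to_physical_type_key lets callers pass the dialect as a plain string; names that do not match a Dialects member fall back to the generic key. Expected behavior, derived from the mapping above:

    to_physical_type_key(Dialects.POSTGRES)   # -> "postgresType"
    to_physical_type_key("postgres")          # coerced to Dialects.POSTGRES -> "postgresType"
    to_physical_type_key("no-such-dialect")   # not in Dialects.__members__ -> "physicalType"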