datacontract-cli 0.10.33__py3-none-any.whl → 0.10.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of datacontract-cli has been flagged as potentially problematic by the registry.

@@ -0,0 +1,79 @@
+import os
+
+import yaml
+
+from datacontract.model.exceptions import DataContractException
+
+
+def to_athena_soda_configuration(server):
+    s3_region = os.getenv("DATACONTRACT_S3_REGION")
+    s3_access_key_id = os.getenv("DATACONTRACT_S3_ACCESS_KEY_ID")
+    s3_secret_access_key = os.getenv("DATACONTRACT_S3_SECRET_ACCESS_KEY")
+    s3_session_token = os.getenv("DATACONTRACT_S3_SESSION_TOKEN")
+
+    # Validate required parameters
+    if not s3_access_key_id:
+        raise DataContractException(
+            type="athena-connection",
+            name="missing_access_key_id",
+            reason="AWS access key ID is required. Set the DATACONTRACT_S3_ACCESS_KEY_ID environment variable.",
+            engine="datacontract",
+        )
+
+    if not s3_secret_access_key:
+        raise DataContractException(
+            type="athena-connection",
+            name="missing_secret_access_key",
+            reason="AWS secret access key is required. Set the DATACONTRACT_S3_SECRET_ACCESS_KEY environment variable.",
+            engine="datacontract",
+        )
+
+    if not hasattr(server, "schema_") or not server.schema_:
+        raise DataContractException(
+            type="athena-connection",
+            name="missing_schema",
+            reason="Schema is required for Athena connection. Specify the schema where your tables exist in the server configuration.",
+            engine="datacontract",
+        )
+
+    if not hasattr(server, "stagingDir") or not server.stagingDir:
+        raise DataContractException(
+            type="athena-connection",
+            name="missing_s3_staging_dir",
+            reason="S3 staging directory is required for Athena connection. This should be the Amazon S3 Query Result Location (e.g., 's3://my-bucket/athena-results/').",
+            engine="datacontract",
+        )
+
+    # Validate S3 staging directory format
+    if not server.stagingDir.startswith("s3://"):
+        raise DataContractException(
+            type="athena-connection",
+            name="invalid_s3_staging_dir",
+            reason=f"S3 staging directory must start with 's3://'. Got: {server.s3_staging_dir}. Example: 's3://my-bucket/athena-results/'",
+            engine="datacontract",
+        )
+
+    data_source = {
+        "type": "athena",
+        "access_key_id": s3_access_key_id,
+        "secret_access_key": s3_secret_access_key,
+        "schema": server.schema_,
+        "staging_dir": server.stagingDir,
+    }
+
+    if s3_region:
+        data_source["region_name"] = s3_region
+    elif server.region_name:
+        data_source["region_name"] = server.region_name
+
+    if server.catalog:
+        # Optional, Identify the name of the Data Source, also referred to as a Catalog. The default value is `awsdatacatalog`.
+        data_source["catalog"] = server.catalog
+
+    if s3_session_token:
+        data_source["aws_session_token"] = s3_session_token
+
+    soda_configuration = {f"data_source {server.type}": data_source}
+
+    soda_configuration_str = yaml.dump(soda_configuration)
+    return soda_configuration_str
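
The new Athena helper above reads credentials from the DATACONTRACT_S3_* environment variables, validates the server's schema_ and stagingDir, and serializes a Soda data source block with yaml.dump. A minimal sketch of the shape it returns, using invented credentials and bucket names purely for illustration:

import yaml

# Hypothetical stand-ins for the DATACONTRACT_S3_* environment variables and the server definition.
data_source = {
    "type": "athena",
    "access_key_id": "AKIA...",                        # DATACONTRACT_S3_ACCESS_KEY_ID
    "secret_access_key": "<secret>",                   # DATACONTRACT_S3_SECRET_ACCESS_KEY
    "schema": "sales",                                 # server.schema_
    "staging_dir": "s3://my-bucket/athena-results/",   # server.stagingDir
    "region_name": "eu-central-1",                     # DATACONTRACT_S3_REGION or server.region_name
}

# The helper nests the block under "data_source <server.type>", e.g. "data_source athena".
print(yaml.dump({"data_source athena": data_source}))

Note that the validation also rejects staging directories that do not start with s3://, although that error message interpolates server.s3_staging_dir rather than server.stagingDir.
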
@@ -71,6 +71,9 @@ def get_duckdb_connection(
         elif server.format == "delta":
             con.sql("update extensions;")  # Make sure we have the latest delta extension
             con.sql(f"""CREATE VIEW "{model_name}" AS SELECT * FROM delta_scan('{model_path}');""")
+            table_info = con.sql(f"PRAGMA table_info('{model_name}');").fetchdf()
+            if table_info is not None and not table_info.empty:
+                run.log_info(f"DuckDB Table Info: {table_info.to_string(index=False)}")
     return con


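
The added logging relies on DuckDB's PRAGMA table_info, which returns one row per column of the newly created view; fetchdf() materializes it as a pandas DataFrame. A standalone sketch of that pattern, independent of the data contract code and using a throwaway in-memory table:

import duckdb

con = duckdb.connect(database=":memory:")
con.sql("CREATE TABLE orders (order_id INTEGER, amount DOUBLE, note VARCHAR)")

# Same pattern as the diff: fetch column metadata as a DataFrame and render it for the run log.
table_info = con.sql("PRAGMA table_info('orders');").fetchdf()
if table_info is not None and not table_info.empty:
    print(f"DuckDB Table Info: {table_info.to_string(index=False)}")
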
@@ -44,12 +44,18 @@ def to_avro_field(field, field_name):
     avro_type = to_avro_type(field, field_name)
     avro_field["type"] = avro_type if is_required_avro else ["null", avro_type]

-    if avro_field["type"] == "enum":
-        avro_field["type"] = {
+    # Handle enum types - both required and optional
+    if avro_type == "enum" or (isinstance(avro_field["type"], list) and "enum" in avro_field["type"]):
+        enum_def = {
             "type": "enum",
             "name": field.title,
             "symbols": field.enum,
         }
+        if is_required_avro:
+            avro_field["type"] = enum_def
+        else:
+            # Replace "enum" with the full enum definition in the union
+            avro_field["type"] = ["null", enum_def]

     if field.config:
         if "avroDefault" in field.config:
@@ -77,6 +83,10 @@ def to_avro_type(field: Field, field_name: str) -> str | dict:
         if "avroType" in field.config:
             return field.config["avroType"]

+    # Check for enum fields based on presence of enum list and avroType config
+    if field.enum and field.config and field.config.get("avroType") == "enum":
+        return "enum"
+
     if field.type is None:
         return "null"
     if field.type in ["string", "varchar", "text"]:
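
Taken together, the two avro_converter changes mean a field that carries an enum list plus config avroType: enum is exported as a named Avro enum, and the enum definition is placed inside the union when the field is optional. A sketch of the resulting field entries, with a hypothetical OrderStatus enum:

import json

# Built the same way the exporter does: name comes from field.title, symbols from field.enum.
enum_def = {
    "type": "enum",
    "name": "OrderStatus",
    "symbols": ["PLACED", "SHIPPED", "DELIVERED"],
}

required_field = {"name": "status", "type": enum_def}
optional_field = {"name": "status", "type": ["null", enum_def]}

print(json.dumps(optional_field, indent=2))
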
@@ -0,0 +1,121 @@
+from typing import Any, Dict, List, Union
+
+import yaml
+
+from datacontract.export.exporter import Exporter, _check_models_for_export
+from datacontract.model.data_contract_specification import DataContractSpecification, Field, Model, Quality
+
+
+class DqxKeys:
+    CHECK = "check"
+    ARGUMENTS = "arguments"
+    SPECIFICATION = "specification"
+    COL_NAME = "column"
+    COL_NAMES = "for_each_column"
+    COLUMNS = "columns"
+    FUNCTION = "function"
+
+
+class DqxExporter(Exporter):
+    """Exporter implementation for converting data contracts to DQX YAML file."""
+
+    def export(
+        self,
+        data_contract: DataContractSpecification,
+        model: Model,
+        server: str,
+        sql_server_type: str,
+        export_args: Dict[str, Any],
+    ) -> str:
+        """Exports a data contract to DQX format."""
+        model_name, model_value = _check_models_for_export(data_contract, model, self.export_format)
+        return to_dqx_yaml(model_value)
+
+
+def to_dqx_yaml(model_value: Model) -> str:
+    """
+    Converts the data contract's quality checks to DQX YAML format.
+
+    Args:
+        model_value (Model): The data contract to convert.
+
+    Returns:
+        str: YAML representation of the data contract's quality checks.
+    """
+    extracted_rules = extract_quality_rules(model_value)
+    return yaml.dump(extracted_rules, sort_keys=False, allow_unicode=True, default_flow_style=False)
+
+
+def process_quality_rule(rule: Quality, column_name: str) -> Dict[str, Any]:
+    """
+    Processes a single quality rule by injecting the column path into its arguments if absent.
+
+    Args:
+        rule (Quality): The quality rule to process.
+        column_name (str): The full path to the current column.
+
+    Returns:
+        dict: The processed quality rule specification.
+    """
+    rule_data = rule.model_extra
+    specification = rule_data[DqxKeys.SPECIFICATION]
+    check = specification[DqxKeys.CHECK]
+
+    arguments = check.setdefault(DqxKeys.ARGUMENTS, {})
+
+    if DqxKeys.COL_NAME not in arguments and DqxKeys.COL_NAMES not in arguments and DqxKeys.COLUMNS not in arguments:
+        if check[DqxKeys.FUNCTION] not in ("is_unique", "foreign_key"):
+            arguments[DqxKeys.COL_NAME] = column_name
+        else:
+            arguments[DqxKeys.COLUMNS] = [column_name]
+
+    return specification
+
+
+def extract_quality_rules(data: Union[Model, Field, Quality], column_path: str = "") -> List[Dict[str, Any]]:
+    """
+    Recursively extracts all quality rules from a data contract structure.
+
+    Args:
+        data (Union[Model, Field, Quality]): The data contract model, field, or quality rule.
+        column_path (str, optional): The current path in the schema hierarchy. Defaults to "".
+
+    Returns:
+        List[Dict[str, Any]]: A list of quality rule specifications.
+    """
+    quality_rules = []
+
+    if isinstance(data, Quality):
+        return [process_quality_rule(data, column_path)]
+
+    if isinstance(data, (Model, Field)):
+        for key, field in data.fields.items():
+            current_path = build_column_path(column_path, key)
+
+            if field.fields:
+                # Field is a struct-like object, recurse deeper
+                quality_rules.extend(extract_quality_rules(field, current_path))
+            else:
+                # Process quality rules at leaf fields
+                for rule in field.quality:
+                    quality_rules.append(process_quality_rule(rule, current_path))

+    # Process any quality rules attached directly to this level
+    for rule in data.quality:
+        quality_rules.append(process_quality_rule(rule, column_path))
+
+    return quality_rules
+
+
+def build_column_path(current_path: str, key: str) -> str:
+    """
+    Builds the full column path by concatenating parent path with current key.
+
+    Args:
+        current_path (str): The current path prefix.
+        key (str): The current field's key.
+
+    Returns:
+        str: The full path.
+    """
+    return f"{current_path}.{key}" if current_path else key
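
The new DqxExporter walks the model's fields, collects each quality entry's specification, and injects the dotted column path into the check arguments when none of column, for_each_column, or columns is present (falling back to columns for is_unique and foreign_key checks). Since the format is registered as dqx further down, it should be reachable through the regular export command. A rough sketch of the kind of YAML this produces, with a hypothetical is_not_null check on a nested field:

import yaml

# Hypothetical check specification as it might appear under a field's quality section;
# the exporter fills in "column" with the path to the leaf field (here customer.email).
specification = {
    "criticality": "error",
    "check": {
        "function": "is_not_null",
        "arguments": {"column": "customer.email"},
    },
}

print(yaml.dump([specification], sort_keys=False, allow_unicode=True, default_flow_style=False))
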
@@ -46,6 +46,7 @@ class ExportFormat(str, Enum):
     iceberg = "iceberg"
     custom = "custom"
    excel = "excel"
+    dqx = "dqx"

    @classmethod
    def get_supported_formats(cls):
@@ -197,6 +197,12 @@ exporter_factory.register_lazy_exporter(
    class_name="MarkdownExporter",
 )

+exporter_factory.register_lazy_exporter(
+    name=ExportFormat.dqx,
+    module_path="datacontract.export.dqx_converter",
+    class_name="DqxExporter",
+)
+
 exporter_factory.register_lazy_exporter(
     name=ExportFormat.iceberg, module_path="datacontract.export.iceberg_converter", class_name="IcebergExporter"
 )
@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Dict, List

 from pydantic import BaseModel

@@ -12,6 +12,9 @@ from datacontract.model.data_contract_specification import (
     ServiceLevel,
 )

+TAB = " "
+ARROW = "↳"
+

 class MarkdownExporter(Exporter):
     """Exporter implementation for converting data contracts to Markdown."""
@@ -70,7 +73,8 @@ def obj_attributes_to_markdown(obj: BaseModel, excluded_fields: set = set(), is_
     else:
         bullet_char = "-"
         newline_char = "\n"
-    obj_model = obj.model_dump(exclude_unset=True, exclude=excluded_fields)
+    model_attributes_to_include = set(obj.__class__.model_fields.keys())
+    obj_model = obj.model_dump(exclude_unset=True, include=model_attributes_to_include, exclude=excluded_fields)
     description_value = obj_model.pop("description", None)
     attributes = [
         (f"{bullet_char} `{attr}`" if value is True else f"{bullet_char} **{attr}:** {value}")
@@ -78,7 +82,8 @@ def obj_attributes_to_markdown(obj: BaseModel, excluded_fields: set = set(), is_
         if value
     ]
     description = f"*{description_to_markdown(description_value)}*"
-    return newline_char.join([description] + attributes)
+    extra = [extra_to_markdown(obj)] if obj.model_extra else []
+    return newline_char.join([description] + attributes + extra)


 def servers_to_markdown(servers: Dict[str, Server]) -> str:
@@ -153,8 +158,8 @@ def field_to_markdown(field_name: str, field: Field, level: int = 0) -> str:
     Returns:
         str: A Markdown table rows for the field.
     """
-    tabs = " " * level
-    arrow = "↳" if level > 0 else ""
+    tabs = TAB * level
+    arrow = ARROW if level > 0 else ""
     column_name = f"{tabs}{arrow} {field_name}"

     attributes = obj_attributes_to_markdown(field, {"type", "fields", "items", "keys", "values"}, True)
@@ -206,3 +211,108 @@ def service_level_to_markdown(service_level: ServiceLevel | None) -> str:

 def description_to_markdown(description: str | None) -> str:
     return (description or "No description.").replace("\n", "<br>")
+
+
+def array_of_dict_to_markdown(array: List[Dict[str, str]]) -> str:
+    """
+    Convert a list of dictionaries to a Markdown table.
+
+    Args:
+        array (List[Dict[str, str]]): A list of dictionaries where each dictionary represents a row in the table.
+
+    Returns:
+        str: A Markdown formatted table.
+    """
+    if not array:
+        return ""
+
+    headers = []
+
+    for item in array:
+        headers += item.keys()
+    headers = list(dict.fromkeys(headers))  # Preserve order and remove duplicates
+
+    markdown_parts = [
+        "| " + " | ".join(headers) + " |",
+        "| " + " | ".join(["---"] * len(headers)) + " |",
+    ]
+
+    for row in array:
+        element = row
+        markdown_parts.append(
+            "| "
+            + " | ".join(
+                f"{str(element.get(header, ''))}".replace("\n", "<br>").replace("\t", TAB) for header in headers
+            )
+            + " |"
+        )
+
+    return "\n".join(markdown_parts) + "\n"
+
+
+def array_to_markdown(array: List[str]) -> str:
+    """
+    Convert a list of strings to a Markdown formatted list.
+
+    Args:
+        array (List[str]): A list of strings to convert.
+
+    Returns:
+        str: A Markdown formatted list.
+    """
+    if not array:
+        return ""
+    return "\n".join(f"- {item}" for item in array) + "\n"
+
+
+def dict_to_markdown(dictionary: Dict[str, str]) -> str:
+    """
+    Convert a dictionary to a Markdown formatted list.
+
+    Args:
+        dictionary (Dict[str, str]): A dictionary where keys are item names and values are item descriptions.
+
+    Returns:
+        str: A Markdown formatted list of items.
+    """
+    if not dictionary:
+        return ""
+
+    markdown_parts = []
+    for key, value in dictionary.items():
+        if isinstance(value, dict):
+            markdown_parts.append(f"- {key}")
+            nested_markdown = dict_to_markdown(value)
+            if nested_markdown:
+                nested_lines = nested_markdown.split("\n")
+                for line in nested_lines:
+                    if line.strip():
+                        markdown_parts.append(f" {line}")
+        else:
+            markdown_parts.append(f"- {key}: {value}")
+    return "\n".join(markdown_parts) + "\n"
+
+
+def extra_to_markdown(obj: BaseModel) -> str:
+    """
+    Convert the extra attributes of a data contract to Markdown format.
+    Args:
+        obj (BaseModel): The data contract object containing extra attributes.
+    Returns:
+        str: A Markdown formatted string representing the extra attributes of the data contract.
+    """
+    markdown_part = ""
+    extra = obj.model_extra
+    if extra:
+        for key_extra, value_extra in extra.items():
+            markdown_part += f"\n### {key_extra.capitalize()}\n"
+            if isinstance(value_extra, list) and len(value_extra):
+                if isinstance(value_extra[0], dict):
+                    markdown_part += array_of_dict_to_markdown(value_extra)
+                elif isinstance(value_extra[0], str):
+                    markdown_part += array_to_markdown(value_extra)
+            elif isinstance(value_extra, dict):
+                markdown_part += dict_to_markdown(value_extra)
+            else:
+                markdown_part += f"{str(value_extra)}\n"
+    return markdown_part
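
These new helpers render model_extra attributes (keys that are not declared on the pydantic models) into Markdown: lists of dicts become tables, lists of strings become bullet lists, and nested dicts become indented bullets. A small sketch mirroring the table branch of array_of_dict_to_markdown on sample data:

# Hypothetical extra attribute, e.g. a custom "stakeholders" list on the data contract.
rows = [
    {"name": "finance", "contact": "finance@example.com"},
    {"name": "marketing"},
]

# Collect headers across all rows, preserving order, then emit a Markdown table.
headers = list(dict.fromkeys(key for row in rows for key in row))
lines = [
    "| " + " | ".join(headers) + " |",
    "| " + " | ".join(["---"] * len(headers)) + " |",
]
for row in rows:
    lines.append("| " + " | ".join(str(row.get(h, "")) for h in headers) + " |")
print("\n".join(lines))
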
@@ -27,31 +27,33 @@ def dcs_to_mermaid(data_contract_spec: DataContractSpecification) -> str | None:
         mmd_references = []

         for model_name, model in data_contract_spec.models.items():
+            clean_model = _sanitize_name(model_name)
             entity_block = ""

             for field_name, field in model.fields.items():
                 clean_name = _sanitize_name(field_name)
-                indicators = ""
+                field_type = field.type or "unknown"

-                if field.primaryKey or (field.unique and field.required):
-                    indicators += "🔑"
-                if field.references:
-                    indicators += "⌘"
+                is_pk = bool(field.primaryKey or (field.unique and field.required))
+                is_fk = bool(field.references)

-                field_type = field.type or "unknown"
-                entity_block += f"\t{clean_name}{indicators} {field_type}\n"
+                entity_block += _field_line(clean_name, field_type, pk=is_pk, uk=bool(field.unique), fk=is_fk)

                 if field.references:
-                    referenced_model = field.references.split(".")[0] if "." in field.references else ""
+                    references = field.references.replace(".", "·")
+                    parts = references.split("·")
+                    referenced_model = _sanitize_name(parts[0]) if len(parts) > 0 else ""
+                    referenced_field = _sanitize_name(parts[1]) if len(parts) > 1 else ""
                     if referenced_model:
-                        mmd_references.append(f'"📑{referenced_model}"' + "}o--{ ||" + f'"📑{model_name}"')
+                        label = referenced_field or clean_name
+                        mmd_references.append(f'"**{referenced_model}**" ||--o{{ "**{clean_model}**" : {label}')

-            mmd_entity += f'\t"**{model_name}**"' + "{\n" + entity_block + "}\n"
+            mmd_entity += f'\t"**{clean_model}**" {{\n{entity_block}}}\n'

         if mmd_references:
             mmd_entity += "\n" + "\n".join(mmd_references)

-        return f"{mmd_entity}\n"
+        return mmd_entity + "\n"

     except Exception as e:
         print(f"Error generating DCS mermaid diagram: {e}")
@@ -95,3 +97,14 @@ def odcs_to_mermaid(data_contract_spec: OpenDataContractStandard) -> str | None:

 def _sanitize_name(name: str) -> str:
     return name.replace("#", "Nb").replace(" ", "_").replace("/", "by")
+
+
+def _field_line(name: str, field_type: str, pk: bool = False, uk: bool = False, fk: bool = False) -> str:
+    indicators = ""
+    if pk:
+        indicators += "🔑"
+    if uk:
+        indicators += "🔒"
+    if fk:
+        indicators += "⌘"
+    return f"\t{name}{indicators} {field_type}\n"
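
With these changes the generated ER diagram marks primary keys (🔑), unique fields (🔒), and references (⌘), sanitizes model names, and emits relationships as ||--o{ edges labelled with the referenced field. A sketch of the kind of entity block and relationship line the generator now builds, reusing the _field_line logic from the diff on hypothetical orders/customers models:

def _field_line(name: str, field_type: str, pk: bool = False, uk: bool = False, fk: bool = False) -> str:
    # Same indicator logic as the new helper above.
    indicators = ("🔑" if pk else "") + ("🔒" if uk else "") + ("⌘" if fk else "")
    return f"\t{name}{indicators} {field_type}\n"

entity_block = _field_line("order_id", "string", pk=True, uk=True) + _field_line("customer_id", "string", fk=True)
print('\t"**orders**" {\n' + entity_block + "}\n")
print('"**customers**" ||--o{ "**orders**" : customer_id')
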
@@ -1,3 +1,5 @@
+import json
+
 from pyspark.sql import types

 from datacontract.export.exporter import Exporter
@@ -104,7 +106,8 @@ def to_struct_field(field: Field, field_name: str) -> types.StructField:
         types.StructField: The corresponding Spark StructField.
     """
     data_type = to_spark_data_type(field)
-    return types.StructField(name=field_name, dataType=data_type, nullable=not field.required)
+    metadata = to_spark_metadata(field)
+    return types.StructField(name=field_name, dataType=data_type, nullable=not field.required, metadata=metadata)


 def to_spark_data_type(field: Field) -> types.DataType:
@@ -152,7 +155,25 @@ def to_spark_data_type(field: Field) -> types.DataType:
         return types.DateType()
     if field_type == "bytes":
         return types.BinaryType()
-    return types.StringType()  # default if no condition is met
+    return types.StringType()  # default if no condition is met
+
+
+def to_spark_metadata(field: Field) -> dict[str, str]:
+    """
+    Convert a field to a Spark metadata dictonary.
+
+    Args:
+        field (Field): The field to convert.
+
+    Returns:
+        dict: dictionary that can be supplied to Spark as metadata for a StructField
+    """
+
+    metadata = {}
+    if field.description:
+        metadata["comment"] = field.description
+
+    return metadata


 def print_schema(dtype: types.DataType) -> str:
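
The exporter now carries the field description into Spark column metadata under the comment key, which is how column comments are attached to a StructField. A standalone sketch using pyspark types directly, with a hypothetical description:

from pyspark.sql import types

# Metadata built the way to_spark_metadata does it, from a hypothetical field description.
field_metadata = {"comment": "Unique identifier of the order"}
struct_field = types.StructField("order_id", types.StringType(), nullable=False, metadata=field_metadata)

print(struct_field.metadata)  # {'comment': 'Unique identifier of the order'}
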
@@ -192,7 +213,11 @@ def print_schema(dtype: types.DataType) -> str:
         name = f'"{column.name}"'
         data_type = indent(print_schema(column.dataType), 1)
         nullable = indent(f"{column.nullable}", 1)
-        return f"StructField({name},\n{data_type},\n{nullable}\n)"
+        if column.metadata:
+            metadata = indent(f"{json.dumps(column.metadata)}", 1)
+            return f"StructField({name},\n{data_type},\n{nullable},\n{metadata}\n)"
+        else:
+            return f"StructField({name},\n{data_type},\n{nullable}\n)"

     def format_struct_type(struct_type: types.StructType) -> str:
         """
@@ -3,6 +3,9 @@ from datacontract.model.data_contract_specification import Field


 def convert_to_sql_type(field: Field, server_type: str) -> str:
+    if field.config and "physicalType" in field.config:
+        return field.config["physicalType"]
+
     if server_type == "snowflake":
         return convert_to_snowflake(field)
     elif server_type == "postgres":
@@ -19,6 +22,7 @@ def convert_to_sql_type(field: Field, server_type: str) -> str:
         return convert_type_to_bigquery(field)
     elif server_type == "trino":
         return convert_type_to_trino(field)
+
     return field.type


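
The physicalType config entry now short-circuits every dialect-specific mapping, so a contract author can pin the exact database type. A minimal sketch, assuming the converter and model classes are importable from datacontract.export.sql_type_converter and datacontract.model.data_contract_specification as in earlier releases:

from datacontract.export.sql_type_converter import convert_to_sql_type
from datacontract.model.data_contract_specification import Field

field = Field(type="string", config={"physicalType": "VARCHAR2(320)"})

# The override wins regardless of the target dialect.
print(convert_to_sql_type(field, "snowflake"))  # VARCHAR2(320)
print(convert_to_sql_type(field, "postgres"))   # VARCHAR2(320)
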
@@ -130,13 +130,23 @@ def import_record_fields(record_fields: List[avro.schema.Field]) -> Dict[str, Fi
             imported_field.fields = import_record_fields(field.type.fields)
         elif field.type.type == "union":
             imported_field.required = False
-            type = import_type_of_optional_field(field)
-            imported_field.type = type
-            if type == "record":
-                imported_field.fields = import_record_fields(get_record_from_union_field(field).fields)
-            elif type == "array":
-                imported_field.type = "array"
-                imported_field.items = import_avro_array_items(get_array_from_union_field(field))
+            # Check for enum in union first, since it needs special handling
+            enum_schema = get_enum_from_union_field(field)
+            if enum_schema:
+                imported_field.type = "string"
+                imported_field.enum = enum_schema.symbols
+                imported_field.title = enum_schema.name
+                if not imported_field.config:
+                    imported_field.config = {}
+                imported_field.config["avroType"] = "enum"
+            else:
+                type = import_type_of_optional_field(field)
+                imported_field.type = type
+                if type == "record":
+                    imported_field.fields = import_record_fields(get_record_from_union_field(field).fields)
+                elif type == "array":
+                    imported_field.type = "array"
+                    imported_field.items = import_avro_array_items(get_array_from_union_field(field))
         elif field.type.type == "array":
             imported_field.type = "array"
             imported_field.items = import_avro_array_items(field.type)
@@ -277,6 +287,22 @@ def get_array_from_union_field(field: avro.schema.Field) -> avro.schema.ArraySch
     return None


+def get_enum_from_union_field(field: avro.schema.Field) -> avro.schema.EnumSchema | None:
+    """
+    Get the enum schema from a union field.
+
+    Args:
+        field: The Avro field with a union type.
+
+    Returns:
+        The enum schema if found, None otherwise.
+    """
+    for field_type in field.type.schemas:
+        if field_type.type == "enum":
+            return field_type
+    return None
+
+
 def map_type_from_avro(avro_type_str: str) -> str:
     """
     Map Avro type strings to data contract type strings.
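
On the import side, a nullable enum such as ["null", {"type": "enum", ...}] is now detected inside the union and mapped to a string field with enum, title, and config avroType set. The lookup iterates over the union's member schemas, which the avro library exposes as field.type.schemas. A sketch of that traversal on a hypothetical schema:

import json

import avro.schema

record = {
    "type": "record",
    "name": "Order",
    "fields": [
        {
            "name": "status",
            "type": ["null", {"type": "enum", "name": "OrderStatus", "symbols": ["PLACED", "SHIPPED"]}],
            "default": None,
        }
    ],
}

schema = avro.schema.parse(json.dumps(record))
status_field = schema.fields[0]

# Same check as get_enum_from_union_field: find the enum member of the union, if any.
enum_schema = next((s for s in status_field.type.schemas if s.type == "enum"), None)
print(enum_schema.name, enum_schema.symbols)  # OrderStatus ['PLACED', 'SHIPPED']
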
@@ -131,14 +131,18 @@ def import_servers(odcs: OpenDataContractStandard) -> Dict[str, Server] | None:
         server.host = odcs_server.host
         server.port = odcs_server.port
         server.catalog = odcs_server.catalog
+        server.stagingDir = odcs_server.stagingDir
         server.topic = getattr(odcs_server, "topic", None)
         server.http_path = getattr(odcs_server, "http_path", None)
         server.token = getattr(odcs_server, "token", None)
         server.driver = getattr(odcs_server, "driver", None)
         server.roles = import_server_roles(odcs_server.roles)
         server.storageAccount = (
-            re.search(r"(?:@|://)([^.]+)\.", odcs_server.location, re.IGNORECASE) if server.type == "azure" else None
+            to_azure_storage_account(odcs_server.location)
+            if server.type == "azure" and "://" in server.location
+            else None
         )
+
         servers[server_name] = server
     return servers

@@ -413,3 +417,28 @@ def import_tags(odcs: OpenDataContractStandard) -> List[str] | None:
     if odcs.tags is None:
         return None
     return odcs.tags
+
+
+def to_azure_storage_account(location: str) -> str | None:
+    """
+    Converts a storage location string to extract the storage account name.
+    ODCS v3.0 has no explicit field for the storage account. It uses the location field, which is a URI.
+
+    This function parses a storage location string to identify and return the
+    storage account name. It handles two primary patterns:
+    1. Protocol://containerName@storageAccountName
+    2. Protocol://storageAccountName
+
+    :param location: The storage location string to parse, typically following
+        the format protocol://containerName@storageAccountName. or
+        protocol://storageAccountName.
+    :return: The extracted storage account name if found, otherwise None
+    """
+    # to catch protocol://containerName@storageAccountName. pattern from location
+    match = re.search(r"(?<=@)([^.]*)", location, re.IGNORECASE)
+    if match:
+        return match.group()
+    else:
+        # to catch protocol://storageAccountName. pattern from location
+        match = re.search(r"(?<=//)(?!@)([^.]*)", location, re.IGNORECASE)
+        return match.group() if match else None
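
The two lookbehind patterns cover both an authority with a container (container@account) and a bare account host. A quick check of both branches on hypothetical locations, reusing the same expressions:

import re


def to_azure_storage_account(location: str) -> str | None:
    # Same two-step match as the importer: container@account first, then a bare account host.
    match = re.search(r"(?<=@)([^.]*)", location, re.IGNORECASE)
    if match:
        return match.group()
    match = re.search(r"(?<=//)(?!@)([^.]*)", location, re.IGNORECASE)
    return match.group() if match else None


print(to_azure_storage_account("abfss://raw@mystorageaccount.dfs.core.windows.net/orders"))  # mystorageaccount
print(to_azure_storage_account("az://mystorageaccount.blob.core.windows.net/orders"))        # mystorageaccount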