datacontract-cli 0.10.0__py3-none-any.whl → 0.10.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. datacontract/__init__.py +13 -0
  2. datacontract/api.py +260 -0
  3. datacontract/breaking/breaking.py +242 -12
  4. datacontract/breaking/breaking_rules.py +37 -1
  5. datacontract/catalog/catalog.py +80 -0
  6. datacontract/cli.py +387 -117
  7. datacontract/data_contract.py +216 -353
  8. datacontract/engines/data_contract_checks.py +1041 -0
  9. datacontract/engines/data_contract_test.py +113 -0
  10. datacontract/engines/datacontract/check_that_datacontract_contains_valid_servers_configuration.py +2 -3
  11. datacontract/engines/datacontract/check_that_datacontract_file_exists.py +1 -1
  12. datacontract/engines/fastjsonschema/check_jsonschema.py +176 -42
  13. datacontract/engines/fastjsonschema/s3/s3_read_files.py +16 -1
  14. datacontract/engines/soda/check_soda_execute.py +100 -56
  15. datacontract/engines/soda/connections/athena.py +79 -0
  16. datacontract/engines/soda/connections/bigquery.py +8 -1
  17. datacontract/engines/soda/connections/databricks.py +12 -3
  18. datacontract/engines/soda/connections/duckdb_connection.py +241 -0
  19. datacontract/engines/soda/connections/kafka.py +206 -113
  20. datacontract/engines/soda/connections/snowflake.py +8 -5
  21. datacontract/engines/soda/connections/sqlserver.py +43 -0
  22. datacontract/engines/soda/connections/trino.py +26 -0
  23. datacontract/export/avro_converter.py +72 -8
  24. datacontract/export/avro_idl_converter.py +31 -25
  25. datacontract/export/bigquery_converter.py +130 -0
  26. datacontract/export/custom_converter.py +40 -0
  27. datacontract/export/data_caterer_converter.py +161 -0
  28. datacontract/export/dbml_converter.py +148 -0
  29. datacontract/export/dbt_converter.py +141 -54
  30. datacontract/export/dcs_exporter.py +6 -0
  31. datacontract/export/dqx_converter.py +126 -0
  32. datacontract/export/duckdb_type_converter.py +57 -0
  33. datacontract/export/excel_exporter.py +923 -0
  34. datacontract/export/exporter.py +100 -0
  35. datacontract/export/exporter_factory.py +216 -0
  36. datacontract/export/go_converter.py +105 -0
  37. datacontract/export/great_expectations_converter.py +257 -36
  38. datacontract/export/html_exporter.py +86 -0
  39. datacontract/export/iceberg_converter.py +188 -0
  40. datacontract/export/jsonschema_converter.py +71 -16
  41. datacontract/export/markdown_converter.py +337 -0
  42. datacontract/export/mermaid_exporter.py +110 -0
  43. datacontract/export/odcs_v3_exporter.py +375 -0
  44. datacontract/export/pandas_type_converter.py +40 -0
  45. datacontract/export/protobuf_converter.py +168 -68
  46. datacontract/export/pydantic_converter.py +6 -0
  47. datacontract/export/rdf_converter.py +13 -6
  48. datacontract/export/sodacl_converter.py +36 -188
  49. datacontract/export/spark_converter.py +245 -0
  50. datacontract/export/sql_converter.py +37 -3
  51. datacontract/export/sql_type_converter.py +269 -8
  52. datacontract/export/sqlalchemy_converter.py +170 -0
  53. datacontract/export/terraform_converter.py +7 -2
  54. datacontract/imports/avro_importer.py +246 -26
  55. datacontract/imports/bigquery_importer.py +221 -0
  56. datacontract/imports/csv_importer.py +143 -0
  57. datacontract/imports/dbml_importer.py +112 -0
  58. datacontract/imports/dbt_importer.py +240 -0
  59. datacontract/imports/excel_importer.py +1111 -0
  60. datacontract/imports/glue_importer.py +288 -0
  61. datacontract/imports/iceberg_importer.py +172 -0
  62. datacontract/imports/importer.py +51 -0
  63. datacontract/imports/importer_factory.py +128 -0
  64. datacontract/imports/json_importer.py +325 -0
  65. datacontract/imports/jsonschema_importer.py +146 -0
  66. datacontract/imports/odcs_importer.py +60 -0
  67. datacontract/imports/odcs_v3_importer.py +516 -0
  68. datacontract/imports/parquet_importer.py +81 -0
  69. datacontract/imports/protobuf_importer.py +264 -0
  70. datacontract/imports/spark_importer.py +262 -0
  71. datacontract/imports/sql_importer.py +274 -35
  72. datacontract/imports/unity_importer.py +219 -0
  73. datacontract/init/init_template.py +20 -0
  74. datacontract/integration/datamesh_manager.py +86 -0
  75. datacontract/lint/resolve.py +271 -49
  76. datacontract/lint/resources.py +21 -0
  77. datacontract/lint/schema.py +53 -17
  78. datacontract/lint/urls.py +32 -12
  79. datacontract/model/data_contract_specification/__init__.py +1 -0
  80. datacontract/model/exceptions.py +4 -1
  81. datacontract/model/odcs.py +24 -0
  82. datacontract/model/run.py +49 -29
  83. datacontract/output/__init__.py +0 -0
  84. datacontract/output/junit_test_results.py +135 -0
  85. datacontract/output/output_format.py +10 -0
  86. datacontract/output/test_results_writer.py +79 -0
  87. datacontract/py.typed +0 -0
  88. datacontract/schemas/datacontract-1.1.0.init.yaml +91 -0
  89. datacontract/schemas/datacontract-1.1.0.schema.json +1975 -0
  90. datacontract/schemas/datacontract-1.2.0.init.yaml +91 -0
  91. datacontract/schemas/datacontract-1.2.0.schema.json +2029 -0
  92. datacontract/schemas/datacontract-1.2.1.init.yaml +91 -0
  93. datacontract/schemas/datacontract-1.2.1.schema.json +2058 -0
  94. datacontract/schemas/odcs-3.0.1.schema.json +2634 -0
  95. datacontract/schemas/odcs-3.0.2.schema.json +2382 -0
  96. datacontract/templates/datacontract.html +139 -294
  97. datacontract/templates/datacontract_odcs.html +685 -0
  98. datacontract/templates/index.html +236 -0
  99. datacontract/templates/partials/datacontract_information.html +86 -0
  100. datacontract/templates/partials/datacontract_servicelevels.html +253 -0
  101. datacontract/templates/partials/datacontract_terms.html +51 -0
  102. datacontract/templates/partials/definition.html +25 -0
  103. datacontract/templates/partials/example.html +27 -0
  104. datacontract/templates/partials/model_field.html +144 -0
  105. datacontract/templates/partials/quality.html +49 -0
  106. datacontract/templates/partials/server.html +211 -0
  107. datacontract/templates/style/output.css +491 -72
  108. datacontract_cli-0.10.37.dist-info/METADATA +2235 -0
  109. datacontract_cli-0.10.37.dist-info/RECORD +119 -0
  110. {datacontract_cli-0.10.0.dist-info → datacontract_cli-0.10.37.dist-info}/WHEEL +1 -1
  111. {datacontract_cli-0.10.0.dist-info → datacontract_cli-0.10.37.dist-info/licenses}/LICENSE +1 -1
  112. datacontract/engines/datacontract/check_that_datacontract_str_is_valid.py +0 -48
  113. datacontract/engines/soda/connections/dask.py +0 -28
  114. datacontract/engines/soda/connections/duckdb.py +0 -76
  115. datacontract/export/csv_type_converter.py +0 -36
  116. datacontract/export/html_export.py +0 -66
  117. datacontract/export/odcs_converter.py +0 -102
  118. datacontract/init/download_datacontract_file.py +0 -17
  119. datacontract/integration/publish_datamesh_manager.py +0 -33
  120. datacontract/integration/publish_opentelemetry.py +0 -107
  121. datacontract/lint/lint.py +0 -141
  122. datacontract/lint/linters/description_linter.py +0 -34
  123. datacontract/lint/linters/example_model_linter.py +0 -91
  124. datacontract/lint/linters/field_pattern_linter.py +0 -34
  125. datacontract/lint/linters/field_reference_linter.py +0 -38
  126. datacontract/lint/linters/notice_period_linter.py +0 -55
  127. datacontract/lint/linters/quality_schema_linter.py +0 -52
  128. datacontract/lint/linters/valid_constraints_linter.py +0 -99
  129. datacontract/model/data_contract_specification.py +0 -141
  130. datacontract/web.py +0 -14
  131. datacontract_cli-0.10.0.dist-info/METADATA +0 -951
  132. datacontract_cli-0.10.0.dist-info/RECORD +0 -66
  133. /datacontract/{model → breaking}/breaking_change.py +0 -0
  134. /datacontract/{lint/linters → export}/__init__.py +0 -0
  135. {datacontract_cli-0.10.0.dist-info → datacontract_cli-0.10.37.dist-info}/entry_points.txt +0 -0
  136. {datacontract_cli-0.10.0.dist-info → datacontract_cli-0.10.37.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,375 @@
1
+ from typing import Any, Dict
2
+
3
+ from open_data_contract_standard.model import (
4
+ CustomProperty,
5
+ DataQuality,
6
+ Description,
7
+ OpenDataContractStandard,
8
+ Role,
9
+ SchemaObject,
10
+ SchemaProperty,
11
+ Server,
12
+ ServiceLevelAgreementProperty,
13
+ Support,
14
+ )
15
+
16
+ from datacontract.export.exporter import Exporter
17
+ from datacontract.model.data_contract_specification import DataContractSpecification, Field, Model
18
+
19
+
20
class OdcsV3Exporter(Exporter):
    """Exporter that renders a data contract in the ODCS v3 format."""

    def export(self, data_contract, model, server, sql_server_type, export_args) -> dict:
        """Serialize ``data_contract`` as ODCS v3 YAML.

        Only ``data_contract`` is used; ``model``, ``server``,
        ``sql_server_type`` and ``export_args`` are accepted but not read here.

        NOTE(review): the signature is annotated ``-> dict`` while
        ``to_odcs_v3_yaml`` is annotated as returning ``str`` -- confirm which
        one callers rely on.
        """
        odcs_yaml = to_odcs_v3_yaml(data_contract)
        return odcs_yaml
23
+
24
+
25
def to_odcs_v3_yaml(data_contract_spec: DataContractSpecification) -> str:
    """Convert a data contract specification to ODCS v3 and return its YAML form."""
    return to_odcs_v3(data_contract_spec).to_yaml()
29
+
30
+
31
def to_odcs_v3(data_contract_spec: DataContractSpecification) -> OpenDataContractStandard:
    """Build an ``OpenDataContractStandard`` object from a data contract specification.

    The result always carries ``apiVersion="v3.0.1"`` and ``kind="DataContract"``
    plus the contract id, title, version and normalized status. Terms, models,
    service levels, contact details, servers and custom properties are copied
    over only when they are present on ``data_contract_spec``.
    """
    info = data_contract_spec.info
    odcs = OpenDataContractStandard(
        apiVersion="v3.0.1",
        kind="DataContract",
        id=data_contract_spec.id,
        name=info.title,
        version=info.version,
        status=to_status(info.status),
    )

    # Terms -> description (each text is stripped of surrounding whitespace).
    terms = data_contract_spec.terms
    if terms is not None:
        odcs.description = Description(
            purpose=None if terms.description is None else terms.description.strip(),
            usage=None if terms.usage is None else terms.usage.strip(),
            limitations=None if terms.limitations is None else terms.limitations.strip(),
        )

    # Models -> schema objects; the list is set even when there are no models.
    odcs.schema_ = [
        to_odcs_schema(model_name, model) for model_name, model in data_contract_spec.models.items()
    ]

    # Service levels -> SLA properties (only availability and retention are mapped).
    service_levels = data_contract_spec.servicelevels
    if service_levels is not None:
        sla_properties = []
        availability = service_levels.availability
        if availability is not None:
            sla_properties.append(
                ServiceLevelAgreementProperty(property="generalAvailability", value=availability.description)
            )
        retention = service_levels.retention
        if retention is not None:
            sla_properties.append(ServiceLevelAgreementProperty(property="retention", value=retention.period))
        if sla_properties:
            odcs.slaProperties = sla_properties

    # Contact -> support channels.
    contact = info.contact
    if contact is not None:
        support_channels = []
        if contact.email is not None:
            support_channels.append(Support(channel="email", url="mailto:" + contact.email))
        if contact.url is not None:
            support_channels.append(Support(channel="other", url=contact.url))
        if support_channels:
            odcs.support = support_channels

    # Servers: optional attributes are copied one-to-one when they are not None.
    if data_contract_spec.servers:
        optional_server_attrs = (
            "environment",
            "account",
            "database",
            "schema_",
            "format",
            "project",
            "dataset",
            "path",
            "delimiter",
            "endpointUrl",
            "location",
            "host",
            "port",
            "catalog",
            "topic",
            "http_path",
            "token",
            "driver",
        )
        odcs_servers = []
        for server_name, spec_server in data_contract_spec.servers.items():
            odcs_server = Server(server=server_name, type=spec_server.type or "")
            for attr in optional_server_attrs:
                attr_value = getattr(spec_server, attr)
                if attr_value is not None:
                    setattr(odcs_server, attr, attr_value)
            if spec_server.roles is not None:
                odcs_server.roles = [
                    Role(role=spec_role.name, description=spec_role.description) for spec_role in spec_server.roles
                ]
            odcs_servers.append(odcs_server)
        if odcs_servers:
            odcs.servers = odcs_servers

    # Owner and any extra ``info`` keys become custom properties.
    extra_properties = []
    if info.owner is not None:
        extra_properties.append(CustomProperty(property="owner", value=info.owner))
    if info.model_extra is not None:
        extra_properties.extend(
            CustomProperty(property=extra_key, value=extra_value)
            for extra_key, extra_value in info.model_extra.items()
        )
    if extra_properties:
        odcs.customProperties = extra_properties

    return odcs
139
+
140
+
141
def to_odcs_schema(model_key, model_value: Model) -> SchemaObject:
    """Convert one data contract model into an ODCS ``SchemaObject``.

    ``model_key`` is used for both ``name`` and ``physicalName``; the logical
    type is always ``"object"`` and the physical type is the model's ``type``.
    Description, properties, quality checks and custom properties are set only
    when there is something to set.
    """
    schema = SchemaObject(
        name=model_key,
        physicalName=model_key,
        logicalType="object",
        physicalType=model_value.type,
    )

    description = model_value.description
    if description is not None:
        schema.description = description

    field_properties = to_properties(model_value.fields)
    if field_properties:
        schema.properties = field_properties

    quality_checks = to_odcs_quality_list(model_value.quality)
    if len(quality_checks) > 0:
        schema.quality = quality_checks

    # Extra (non-declared) model keys are carried over as custom properties.
    extras = model_value.model_extra
    if extras is not None:
        extra_properties = [
            CustomProperty(property=extra_key, value=extra_value) for extra_key, extra_value in extras.items()
        ]
        if extra_properties:
            schema.customProperties = extra_properties

    return schema
166
+
167
+
168
def to_properties(fields: Dict[str, Field]) -> list:
    """Convert a mapping of field name -> field into a list of ODCS properties."""
    return [to_property(name, spec_field) for name, spec_field in fields.items()]
174
+
175
+
176
+ def to_logical_type(type: str) -> str | None:
177
+ if type is None:
178
+ return None
179
+ if type.lower() in ["string", "varchar", "text"]:
180
+ return "string"
181
+ if type.lower() in ["timestamp", "timestamp_tz"]:
182
+ return "date"
183
+ if type.lower() in ["timestamp_ntz"]:
184
+ return "date"
185
+ if type.lower() in ["date"]:
186
+ return "date"
187
+ if type.lower() in ["time"]:
188
+ return "string"
189
+ if type.lower() in ["number", "decimal", "numeric"]:
190
+ return "number"
191
+ if type.lower() in ["float", "double"]:
192
+ return "number"
193
+ if type.lower() in ["integer", "int", "long", "bigint"]:
194
+ return "integer"
195
+ if type.lower() in ["boolean"]:
196
+ return "boolean"
197
+ if type.lower() in ["object", "record", "struct"]:
198
+ return "object"
199
+ if type.lower() in ["bytes"]:
200
+ return "array"
201
+ if type.lower() in ["array"]:
202
+ return "array"
203
+ if type.lower() in ["variant"]:
204
+ return "variant"
205
+ if type.lower() in ["null"]:
206
+ return None
207
+ return None
208
+
209
+
210
+ def to_physical_type(config: Dict[str, Any]) -> str | None:
211
+ if config is None:
212
+ return None
213
+ if "postgresType" in config:
214
+ return config["postgresType"]
215
+ elif "bigqueryType" in config:
216
+ return config["bigqueryType"]
217
+ elif "snowflakeType" in config:
218
+ return config["snowflakeType"]
219
+ elif "redshiftType" in config:
220
+ return config["redshiftType"]
221
+ elif "sqlserverType" in config:
222
+ return config["sqlserverType"]
223
+ elif "databricksType" in config:
224
+ return config["databricksType"]
225
+ elif "physicalType" in config:
226
+ return config["physicalType"]
227
+ return None
228
+
229
+
230
def to_property(field_name: str, field: Field) -> SchemaProperty:
    """Convert one data contract field into an ODCS ``SchemaProperty``.

    Nested ``fields`` and array ``items`` are converted recursively. Attributes
    are copied only when set on ``field``. A single ``example`` replaces any
    ``examples`` list, and a truthy ``primaryKey``/``primary`` flag always
    yields ``primaryKeyPosition = 1``.
    """
    prop = SchemaProperty(name=field_name)

    if field.fields:
        prop.properties = [to_property(child_name, child) for child_name, child in field.fields.items()]

    if field.items:
        item_prop = to_property(field_name, field.items)
        item_prop.name = None  # Clear the name for items
        prop.items = item_prop

    if field.title is not None:
        prop.businessName = field.title

    if field.type is not None:
        prop.logicalType = to_logical_type(field.type)
        # A server-specific type from the config wins over the generic type.
        prop.physicalType = to_physical_type(field.config) or field.type

    # Attributes that share the same name on both sides.
    for attr in ("description", "required", "unique", "classification"):
        attr_value = getattr(field, attr)
        if attr_value is not None:
            setattr(prop, attr, attr_value)

    if field.examples is not None:
        prop.examples = field.examples.copy()
    if field.example is not None:
        # Applied second on purpose: a single example overrides the list.
        prop.examples = [field.example]

    for key_flag in (field.primaryKey, field.primary):
        if key_flag:
            prop.primaryKey = key_flag
            prop.primaryKeyPosition = 1

    extra_properties = []
    if field.model_extra is not None:
        extra_properties.extend(
            CustomProperty(property=extra_key, value=extra_value)
            for extra_key, extra_value in field.model_extra.items()
        )
    if field.pii is not None:
        extra_properties.append(CustomProperty(property="pii", value=field.pii))
    if extra_properties:
        prop.customProperties = extra_properties

    if field.tags:
        prop.tags = field.tags

    option_names = (
        "minLength",
        "maxLength",
        "pattern",
        "minimum",
        "maximum",
        "exclusiveMinimum",
        "exclusiveMaximum",
    )
    type_options = {}
    for option_name in option_names:
        option_value = getattr(field, option_name)
        if option_value is not None:
            type_options[option_name] = option_value
    if type_options:
        prop.logicalTypeOptions = type_options

    if field.quality is not None:
        quality_checks = to_odcs_quality_list(field.quality)
        if quality_checks:
            prop.quality = quality_checks

    return prop
318
+
319
+
320
def to_odcs_quality_list(quality_list):
    """Convert an iterable of quality definitions into a list of ODCS quality objects.

    ``None`` is treated as "no quality checks" and yields an empty list, so
    callers that pass an unset quality attribute straight through do not fail
    while iterating.
    """
    if quality_list is None:
        return []
    return [to_odcs_quality(quality) for quality in quality_list]
325
+
326
+
327
def to_odcs_quality(quality):
    """Convert one quality definition into an ODCS ``DataQuality`` object.

    The ``type`` is always copied; every other attribute is copied only when it
    is not ``None``. Two attributes are renamed on the way:
    ``mustBeGreaterThanOrEqualTo`` -> ``mustBeGreaterOrEqualTo`` and
    ``mustBeLessThanOrEqualTo`` -> ``mustBeLessOrEqualTo``.
    """
    # (source attribute, target attribute) pairs, in copy order.
    # `dialect` is deliberately absent: the original note says it is not
    # supported in v3.0.0.
    attribute_map = (
        ("description", "description"),
        ("query", "query"),
        ("mustBe", "mustBe"),
        ("mustNotBe", "mustNotBe"),
        ("mustBeGreaterThan", "mustBeGreaterThan"),
        ("mustBeGreaterThanOrEqualTo", "mustBeGreaterOrEqualTo"),
        ("mustBeLessThan", "mustBeLessThan"),
        ("mustBeLessThanOrEqualTo", "mustBeLessOrEqualTo"),
        ("mustBeBetween", "mustBeBetween"),
        ("mustNotBeBetween", "mustNotBeBetween"),
        ("engine", "engine"),
        ("implementation", "implementation"),
    )

    odcs_quality = DataQuality(type=quality.type)
    for source_attr, target_attr in attribute_map:
        source_value = getattr(quality, source_attr)
        if source_value is not None:
            setattr(odcs_quality, target_attr, source_value)
    return odcs_quality
357
+
358
+
359
def to_status(status):
    """Convert the data contract status to ODCS v3 format.

    Returns the lower-cased status when it is one of the values listed as valid
    for ODCS v3.0.1 below; a missing or non-standard status becomes ``"draft"``.
    """
    fallback = "draft"
    if status is None:
        return fallback

    # Valid status values according to ODCS v3.0.1 spec
    accepted = ("proposed", "draft", "active", "deprecated", "retired")
    normalized = status.lower()
    return normalized if normalized in accepted else fallback
@@ -0,0 +1,40 @@
1
+ """
2
+ Module for converting data contract field types to corresponding pandas data types.
3
+ """
4
+
5
+ from datacontract.model.data_contract_specification import Field
6
+
7
+
8
def convert_to_pandas_type(field: Field) -> str:
    """
    Convert a data contract field type to the equivalent pandas data type.

    Parameters:
    ----------
    field : Field
        A Field object containing metadata about the data type of the field.

    Returns:
    -------
    str
        The corresponding pandas data type as a string. Types without a
        dedicated mapping (including ``bytes`` and a missing type) map to
        ``"object"``.
    """
    pandas_types = {
        "string": "str",
        "varchar": "str",
        "text": "str",
        "integer": "int32",
        "int": "int32",
        "long": "int64",
        # `bigint` is the 64-bit alias of `long`; without this entry it fell
        # through to "object".
        "bigint": "int64",
        "float": "float32",
        "number": "float64",
        "decimal": "float64",
        "numeric": "float64",
        "double": "float64",
        "boolean": "bool",
        "timestamp": "datetime64[ns]",
        "timestamp_tz": "datetime64[ns]",
        "timestamp_ntz": "datetime64[ns]",
        "date": "datetime64[ns]",
    }
    return pandas_types.get(field.type, "object")
@@ -1,99 +1,199 @@
1
+ from datacontract.export.exporter import Exporter
1
2
  from datacontract.model.data_contract_specification import DataContractSpecification
2
3
 
3
4
 
4
- def to_protobuf(data_contract_spec: DataContractSpecification):
5
- messages = ""
6
- for model_name, model in data_contract_spec.models.items():
7
- messages += to_protobuf_message(model_name, model.fields, model.description, 0)
8
- messages += "\n"
5
class ProtoBufExporter(Exporter):
    """Exporter that renders a data contract as a Protobuf definition."""

    def export(self, data_contract, model, server, sql_server_type, export_args) -> dict:
        """Return ``{"protobuf": <generated text>}`` for ``data_contract``.

        Only ``data_contract`` is used; the other parameters are not read here.
        """
        return {"protobuf": to_protobuf(data_contract)}
9
10
 
10
- result = f"""syntax = "proto3";
11
11
 
12
- {messages}
13
- """
14
-
15
- return result
16
-
17
-
18
- def _to_protobuf_message_name(model_name):
19
- return model_name[0].upper() + model_name[1:]
12
+ def to_protobuf(data_contract_spec: DataContractSpecification) -> str:
13
+ """
14
+ Generates a Protobuf file from the data contract specification.
15
+ Scans all models for enum fields (even if the type is "string") by checking for a "values" property.
16
+ """
17
+ messages = ""
18
+ enum_definitions = {}
20
19
 
20
+ # Iterate over all models to generate messages and collect enum definitions.
21
+ for model_name, model in data_contract_spec.models.items():
22
+ for field_name, field in model.fields.items():
23
+ # If the field has enum values, collect them.
24
+ if _is_enum_field(field):
25
+ enum_name = _get_enum_name(field, field_name)
26
+ enum_values = _get_enum_values(field)
27
+ if enum_values and enum_name not in enum_definitions:
28
+ enum_definitions[enum_name] = enum_values
29
+
30
+ messages += to_protobuf_message(model_name, model.fields, getattr(model, "description", ""), 0)
31
+ messages += "\n"
21
32
 
22
- def to_protobuf_message(model_name, fields, description, indent_level: int = 0):
33
+ # Build header with syntax and package declarations.
34
+ header = 'syntax = "proto3";\n\n'
35
+ package = getattr(data_contract_spec, "package", "example")
36
+ header += f"package {package};\n\n"
37
+
38
+ # Append enum definitions.
39
+ for enum_name, enum_values in enum_definitions.items():
40
+ header += f"// Enum for {enum_name}\n"
41
+ header += f"enum {enum_name} {{\n"
42
+ # Only iterate if enum_values is a dictionary.
43
+ if isinstance(enum_values, dict):
44
+ for enum_const, value in sorted(enum_values.items(), key=lambda item: item[1]):
45
+ normalized_const = enum_const.upper().replace(" ", "_")
46
+ header += f" {normalized_const} = {value};\n"
47
+ else:
48
+ header += f" // Warning: Enum values for {enum_name} are not a dictionary\n"
49
+ header += "}\n\n"
50
+ return header + messages
51
+
52
+
53
def _is_enum_field(field) -> bool:
    """
    Tell whether the field (dict or object) carries a non-empty "values" property.
    """
    values = field.get("values") if isinstance(field, dict) else getattr(field, "values", None)
    return bool(values)
60
+
61
+
62
def _get_enum_name(field, field_name: str) -> str:
    """
    Returns the enum name either from the field's "enum_name" or derived from the field name.

    Dict and object fields are handled the same way: a missing, ``None`` or
    empty "enum_name" falls back to the capitalized field name, so an explicit
    ``"enum_name": None`` in a dict no longer leaks through as the enum name.
    """
    if isinstance(field, dict):
        explicit_name = field.get("enum_name")
    else:
        explicit_name = getattr(field, "enum_name", None)
    return explicit_name or _to_protobuf_message_name(field_name)
69
+
70
+
71
def _get_enum_values(field) -> dict:
    """
    Returns the enum values from the field.

    A dict "values" property is returned unchanged. For anything else, only
    upper-case names whose value is an ``int`` are kept, taken from the
    object's ``.dict()`` result when it has such a method and from its
    attributes otherwise.
    """
    raw_values = field.get("values", {}) if isinstance(field, dict) else getattr(field, "values", {})

    if isinstance(raw_values, dict):
        return raw_values

    # A BaseModel-like object exposing .dict(): filter its dumped content.
    if hasattr(raw_values, "dict") and callable(raw_values.dict):
        return {
            name: number
            for name, number in raw_values.dict().items()
            if name.isupper() and isinstance(number, int)
        }

    # Otherwise scan attributes that look like enum constants.
    return {
        name: getattr(raw_values, name)
        for name in dir(raw_values)
        if name.isupper() and isinstance(getattr(raw_values, name), int)
    }
94
+
95
+
96
def _to_protobuf_message_name(name: str) -> str:
    """
    Capitalize the first letter of ``name`` for use as a Protobuf message/enum name.

    Empty (falsy) input is returned unchanged; the rest of the name is not altered.
    """
    if not name:
        return name
    first, rest = name[0], name[1:]
    return first.upper() + rest
101
+
102
+
103
+ def to_protobuf_message(model_name: str, fields: dict, description: str, indent_level: int = 0) -> str:
104
+ """
105
+ Generates a Protobuf message definition from the model's fields.
106
+ Handles nested messages for complex types.
107
+ """
23
108
  result = ""
109
+ if description:
110
+ result += f"{indent(indent_level)}// {description}\n"
24
111
 
25
- if description is not None:
26
- result += f"""{indent(indent_level)}/* {description} */\n"""
27
-
28
- fields_protobuf = ""
112
+ result += f"message {_to_protobuf_message_name(model_name)} {{\n"
29
113
  number = 1
30
114
  for field_name, field in fields.items():
31
- if field.type in ["object", "record", "struct"]:
32
- fields_protobuf += (
33
- "\n".join(
34
- map(
35
- lambda x: " " + x,
36
- to_protobuf_message(field_name, field.fields, field.description, indent_level + 1).splitlines(),
37
- )
38
- )
39
- + "\n"
40
- )
41
-
42
- fields_protobuf += to_protobuf_field(field_name, field, field.description, number, 1) + "\n"
115
+ # For nested objects, generate a nested message.
116
+ field_type = _get_field_type(field)
117
+ if field_type in ["object", "record", "struct"]:
118
+ nested_desc = field.get("description", "") if isinstance(field, dict) else getattr(field, "description", "")
119
+ nested_fields = field.get("fields", {}) if isinstance(field, dict) else field.fields
120
+ nested_message = to_protobuf_message(field_name, nested_fields, nested_desc, indent_level + 1)
121
+ result += nested_message + "\n"
122
+
123
+ field_desc = field.get("description", "") if isinstance(field, dict) else getattr(field, "description", "")
124
+ result += to_protobuf_field(field_name, field, field_desc, number, indent_level + 1) + "\n"
43
125
  number += 1
44
- result += f"message {_to_protobuf_message_name(model_name)} {{\n{fields_protobuf}}}\n"
45
126
 
127
+ result += f"{indent(indent_level)}}}\n"
46
128
  return result
47
129
 
48
130
 
49
- def to_protobuf_field(field_name, field, description, number: int, indent_level: int = 0):
50
- optional = ""
51
- if not field.required:
52
- optional = "optional "
53
-
131
+ def to_protobuf_field(field_name: str, field, description: str, number: int, indent_level: int = 0) -> str:
132
+ """
133
+ Generates a field definition within a Protobuf message.
134
+ """
54
135
  result = ""
55
-
56
- if description is not None:
57
- result += f"""{indent(indent_level)}/* {description} */\n"""
58
-
59
- result += f"{indent(indent_level)}{optional}{_convert_type(field_name, field)} {field_name} = {number};"
60
-
136
+ if description:
137
+ result += f"{indent(indent_level)}// {description}\n"
138
+ result += f"{indent(indent_level)}{_convert_type(field_name, field)} {field_name} = {number};"
61
139
  return result
62
140
 
63
141
 
64
- def indent(indent_level):
142
+ def indent(indent_level: int) -> str:
65
143
  return " " * indent_level
66
144
 
67
145
 
68
- def _convert_type(field_name, field) -> None | str:
69
- type = field.type
70
- if type is None:
71
- return None
72
- if type.lower() in ["string", "varchar", "text"]:
73
- return "string"
74
- if type.lower() in ["timestamp", "timestamp_tz"]:
75
- return "string"
76
- if type.lower() in ["timestamp_ntz"]:
77
- return "string"
78
- if type.lower() in ["date"]:
146
def _get_field_type(field) -> str:
    """
    Retrieves the lower-cased field type from the field definition.

    Works for dict and object fields. A missing type, or a type that is present
    but ``None``/empty, yields ``""`` instead of raising ``AttributeError`` on
    ``None.lower()``.
    """
    if isinstance(field, dict):
        raw_type = field.get("type")
    else:
        raw_type = getattr(field, "type", None)
    return (raw_type or "").lower()
153
+
154
+
155
+ def _convert_type(field_name: str, field) -> str:
156
+ """
157
+ Converts a field's type (from the data contract) to a Protobuf type.
158
+ Prioritizes enum conversion if a non-empty "values" property exists.
159
+ """
160
+ # For debugging purposes
161
+ print("Converting field:", field_name)
162
+ # If the field should be treated as an enum, return its enum name.
163
+ if _is_enum_field(field):
164
+ return _get_enum_name(field, field_name)
165
+
166
+ lower_type = _get_field_type(field)
167
+ if lower_type in ["string", "varchar", "text"]:
79
168
  return "string"
80
- if type.lower() in ["time"]:
169
+ if lower_type in ["timestamp", "timestamp_tz", "timestamp_ntz", "date", "time"]:
81
170
  return "string"
82
- if type.lower() in ["number", "decimal", "numeric"]:
171
+ if lower_type in ["number", "decimal", "numeric"]:
83
172
  return "double"
84
- if type.lower() in ["float", "double"]:
85
- return type.lower()
86
- if type.lower() in ["integer", "int"]:
173
+ if lower_type in ["float", "double"]:
174
+ return lower_type
175
+ if lower_type in ["integer", "int"]:
87
176
  return "int32"
88
- if type.lower() in ["long", "bigint"]:
177
+ if lower_type in ["long", "bigint"]:
89
178
  return "int64"
90
- if type.lower() in ["boolean"]:
179
+ if lower_type in ["boolean"]:
91
180
  return "bool"
92
- if type.lower() in ["bytes"]:
181
+ if lower_type in ["bytes"]:
93
182
  return "bytes"
94
- if type.lower() in ["object", "record", "struct"]:
183
+ if lower_type in ["object", "record", "struct"]:
95
184
  return _to_protobuf_message_name(field_name)
96
- if type.lower() in ["array"]:
97
- # TODO spec is missing arrays
98
- return "repeated string"
99
- return None
185
+ if lower_type == "array":
186
+ # Handle array types. Check for an "items" property.
187
+ items = field.get("items") if isinstance(field, dict) else getattr(field, "items", None)
188
+ if items and isinstance(items, dict) and items.get("type"):
189
+ item_type = items.get("type", "").lower()
190
+ if item_type in ["object", "record", "struct"]:
191
+ # Singularize the field name (a simple approach).
192
+ singular = field_name[:-1] if field_name.endswith("s") else field_name
193
+ return "repeated " + _to_protobuf_message_name(singular)
194
+ else:
195
+ return "repeated " + _convert_type(field_name, items)
196
+ else:
197
+ return "repeated string"
198
+ # Fallback for unrecognized types.
199
+ return "string"