dsgrid-toolkit 0.3.3__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157) hide show
  1. build_backend.py +93 -0
  2. dsgrid/__init__.py +22 -0
  3. dsgrid/api/__init__.py +0 -0
  4. dsgrid/api/api_manager.py +179 -0
  5. dsgrid/api/app.py +419 -0
  6. dsgrid/api/models.py +60 -0
  7. dsgrid/api/response_models.py +116 -0
  8. dsgrid/apps/__init__.py +0 -0
  9. dsgrid/apps/project_viewer/app.py +216 -0
  10. dsgrid/apps/registration_gui.py +444 -0
  11. dsgrid/chronify.py +32 -0
  12. dsgrid/cli/__init__.py +0 -0
  13. dsgrid/cli/common.py +120 -0
  14. dsgrid/cli/config.py +176 -0
  15. dsgrid/cli/download.py +13 -0
  16. dsgrid/cli/dsgrid.py +157 -0
  17. dsgrid/cli/dsgrid_admin.py +92 -0
  18. dsgrid/cli/install_notebooks.py +62 -0
  19. dsgrid/cli/query.py +729 -0
  20. dsgrid/cli/registry.py +1862 -0
  21. dsgrid/cloud/__init__.py +0 -0
  22. dsgrid/cloud/cloud_storage_interface.py +140 -0
  23. dsgrid/cloud/factory.py +31 -0
  24. dsgrid/cloud/fake_storage_interface.py +37 -0
  25. dsgrid/cloud/s3_storage_interface.py +156 -0
  26. dsgrid/common.py +36 -0
  27. dsgrid/config/__init__.py +0 -0
  28. dsgrid/config/annual_time_dimension_config.py +194 -0
  29. dsgrid/config/common.py +142 -0
  30. dsgrid/config/config_base.py +148 -0
  31. dsgrid/config/dataset_config.py +907 -0
  32. dsgrid/config/dataset_schema_handler_factory.py +46 -0
  33. dsgrid/config/date_time_dimension_config.py +136 -0
  34. dsgrid/config/dimension_config.py +54 -0
  35. dsgrid/config/dimension_config_factory.py +65 -0
  36. dsgrid/config/dimension_mapping_base.py +350 -0
  37. dsgrid/config/dimension_mappings_config.py +48 -0
  38. dsgrid/config/dimensions.py +1025 -0
  39. dsgrid/config/dimensions_config.py +71 -0
  40. dsgrid/config/file_schema.py +190 -0
  41. dsgrid/config/index_time_dimension_config.py +80 -0
  42. dsgrid/config/input_dataset_requirements.py +31 -0
  43. dsgrid/config/mapping_tables.py +209 -0
  44. dsgrid/config/noop_time_dimension_config.py +42 -0
  45. dsgrid/config/project_config.py +1462 -0
  46. dsgrid/config/registration_models.py +188 -0
  47. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  48. dsgrid/config/simple_models.py +49 -0
  49. dsgrid/config/supplemental_dimension.py +29 -0
  50. dsgrid/config/time_dimension_base_config.py +192 -0
  51. dsgrid/data_models.py +155 -0
  52. dsgrid/dataset/__init__.py +0 -0
  53. dsgrid/dataset/dataset.py +123 -0
  54. dsgrid/dataset/dataset_expression_handler.py +86 -0
  55. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  56. dsgrid/dataset/dataset_schema_handler_base.py +945 -0
  57. dsgrid/dataset/dataset_schema_handler_one_table.py +209 -0
  58. dsgrid/dataset/dataset_schema_handler_two_table.py +322 -0
  59. dsgrid/dataset/growth_rates.py +162 -0
  60. dsgrid/dataset/models.py +51 -0
  61. dsgrid/dataset/table_format_handler_base.py +257 -0
  62. dsgrid/dataset/table_format_handler_factory.py +17 -0
  63. dsgrid/dataset/unpivoted_table.py +121 -0
  64. dsgrid/dimension/__init__.py +0 -0
  65. dsgrid/dimension/base_models.py +230 -0
  66. dsgrid/dimension/dimension_filters.py +308 -0
  67. dsgrid/dimension/standard.py +252 -0
  68. dsgrid/dimension/time.py +352 -0
  69. dsgrid/dimension/time_utils.py +103 -0
  70. dsgrid/dsgrid_rc.py +88 -0
  71. dsgrid/exceptions.py +105 -0
  72. dsgrid/filesystem/__init__.py +0 -0
  73. dsgrid/filesystem/cloud_filesystem.py +32 -0
  74. dsgrid/filesystem/factory.py +32 -0
  75. dsgrid/filesystem/filesystem_interface.py +136 -0
  76. dsgrid/filesystem/local_filesystem.py +74 -0
  77. dsgrid/filesystem/s3_filesystem.py +118 -0
  78. dsgrid/loggers.py +132 -0
  79. dsgrid/minimal_patterns.cp313-win_amd64.pyd +0 -0
  80. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +949 -0
  81. dsgrid/notebooks/registration.ipynb +48 -0
  82. dsgrid/notebooks/start_notebook.sh +11 -0
  83. dsgrid/project.py +451 -0
  84. dsgrid/query/__init__.py +0 -0
  85. dsgrid/query/dataset_mapping_plan.py +142 -0
  86. dsgrid/query/derived_dataset.py +388 -0
  87. dsgrid/query/models.py +728 -0
  88. dsgrid/query/query_context.py +287 -0
  89. dsgrid/query/query_submitter.py +994 -0
  90. dsgrid/query/report_factory.py +19 -0
  91. dsgrid/query/report_peak_load.py +70 -0
  92. dsgrid/query/reports_base.py +20 -0
  93. dsgrid/registry/__init__.py +0 -0
  94. dsgrid/registry/bulk_register.py +165 -0
  95. dsgrid/registry/common.py +287 -0
  96. dsgrid/registry/config_update_checker_base.py +63 -0
  97. dsgrid/registry/data_store_factory.py +34 -0
  98. dsgrid/registry/data_store_interface.py +74 -0
  99. dsgrid/registry/dataset_config_generator.py +158 -0
  100. dsgrid/registry/dataset_registry_manager.py +950 -0
  101. dsgrid/registry/dataset_update_checker.py +16 -0
  102. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  103. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  104. dsgrid/registry/dimension_registry_manager.py +413 -0
  105. dsgrid/registry/dimension_update_checker.py +16 -0
  106. dsgrid/registry/duckdb_data_store.py +207 -0
  107. dsgrid/registry/filesystem_data_store.py +150 -0
  108. dsgrid/registry/filter_registry_manager.py +123 -0
  109. dsgrid/registry/project_config_generator.py +57 -0
  110. dsgrid/registry/project_registry_manager.py +1623 -0
  111. dsgrid/registry/project_update_checker.py +48 -0
  112. dsgrid/registry/registration_context.py +223 -0
  113. dsgrid/registry/registry_auto_updater.py +316 -0
  114. dsgrid/registry/registry_database.py +667 -0
  115. dsgrid/registry/registry_interface.py +446 -0
  116. dsgrid/registry/registry_manager.py +558 -0
  117. dsgrid/registry/registry_manager_base.py +367 -0
  118. dsgrid/registry/versioning.py +92 -0
  119. dsgrid/rust_ext/__init__.py +14 -0
  120. dsgrid/rust_ext/find_minimal_patterns.py +129 -0
  121. dsgrid/spark/__init__.py +0 -0
  122. dsgrid/spark/functions.py +589 -0
  123. dsgrid/spark/types.py +110 -0
  124. dsgrid/tests/__init__.py +0 -0
  125. dsgrid/tests/common.py +140 -0
  126. dsgrid/tests/make_us_data_registry.py +265 -0
  127. dsgrid/tests/register_derived_datasets.py +103 -0
  128. dsgrid/tests/utils.py +25 -0
  129. dsgrid/time/__init__.py +0 -0
  130. dsgrid/time/time_conversions.py +80 -0
  131. dsgrid/time/types.py +67 -0
  132. dsgrid/units/__init__.py +0 -0
  133. dsgrid/units/constants.py +113 -0
  134. dsgrid/units/convert.py +71 -0
  135. dsgrid/units/energy.py +145 -0
  136. dsgrid/units/power.py +87 -0
  137. dsgrid/utils/__init__.py +0 -0
  138. dsgrid/utils/dataset.py +830 -0
  139. dsgrid/utils/files.py +179 -0
  140. dsgrid/utils/filters.py +125 -0
  141. dsgrid/utils/id_remappings.py +100 -0
  142. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  143. dsgrid/utils/py_expression_eval/README.md +8 -0
  144. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  145. dsgrid/utils/py_expression_eval/tests.py +283 -0
  146. dsgrid/utils/run_command.py +70 -0
  147. dsgrid/utils/scratch_dir_context.py +65 -0
  148. dsgrid/utils/spark.py +918 -0
  149. dsgrid/utils/spark_partition.py +98 -0
  150. dsgrid/utils/timing.py +239 -0
  151. dsgrid/utils/utilities.py +221 -0
  152. dsgrid/utils/versioning.py +36 -0
  153. dsgrid_toolkit-0.3.3.dist-info/METADATA +193 -0
  154. dsgrid_toolkit-0.3.3.dist-info/RECORD +157 -0
  155. dsgrid_toolkit-0.3.3.dist-info/WHEEL +4 -0
  156. dsgrid_toolkit-0.3.3.dist-info/entry_points.txt +4 -0
  157. dsgrid_toolkit-0.3.3.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,413 @@
1
+ """Manages the registry for dimensions"""
2
+
3
+ import logging
4
+ from collections import defaultdict
5
+ from pathlib import Path
6
+ from typing import Any, Generator, Sequence, Union
7
+ from uuid import uuid4
8
+
9
+ from prettytable import PrettyTable
10
+ from sqlalchemy import Connection
11
+
12
+ from dsgrid.config.dimension_config_factory import get_dimension_config, load_dimension_config
13
+ from dsgrid.config.dimension_config import (
14
+ DimensionBaseConfig,
15
+ DimensionBaseConfigWithFiles,
16
+ DimensionConfig,
17
+ )
18
+ from dsgrid.config.dimensions_config import DimensionsConfig
19
+ from dsgrid.config.dimensions import (
20
+ DimensionBaseModel,
21
+ DimensionModel,
22
+ TimeDimensionBaseModel,
23
+ DimensionReferenceModel,
24
+ )
25
+ from dsgrid.dimension.base_models import DimensionType
26
+ from dsgrid.registry.common import ConfigKey, RegistryType, VersionUpdateType
27
+ from dsgrid.registry.registry_interface import DimensionRegistryInterface
28
+ from dsgrid.utils.filters import transform_and_validate_filters, matches_filters
29
+ from dsgrid.utils.utilities import display_table
30
+ from .registration_context import RegistrationContext
31
+ from .dimension_update_checker import DimensionUpdateChecker
32
+ from .registry_manager_base import RegistryManagerBase
33
+
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
class DimensionRegistryManager(RegistryManagerBase):
    """Manages registered dimensions."""

    def __init__(self, path, params):
        super().__init__(path, params)
        # Cache of loaded dimension configs. key = ConfigKey, value = DimensionConfig
        self._dimensions = {}

    @staticmethod
    def config_class():
        """Return the config class managed by this registry."""
        return DimensionConfig

    @property
    def db(self) -> DimensionRegistryInterface:
        """Return the registry database interface."""
        return self._db

    @db.setter
    def db(self, db: DimensionRegistryInterface) -> None:
        self._db = db

    @staticmethod
    def name() -> str:
        """Return the display name for this registry type."""
        return "Dimensions"

    def _replace_duplicates(
        self, config: DimensionsConfig, context: RegistrationContext
    ) -> set[str]:
        """Replace dimensions in ``config`` that duplicate already-registered dimensions.

        A non-time dimension counts as a duplicate when its records file hash,
        dimension type, name, and description all match an existing dimension.
        A time dimension counts as a duplicate when every field except
        dimension_id/version/id matches. Duplicate entries in
        ``config.model.dimensions`` are replaced in place with the existing models.

        Returns
        -------
        set[str]
            IDs of the existing dimensions that were substituted into the config.
        """
        hashes = defaultdict(list)
        time_dims = {}
        for dimension in self._db.iter_models(context.connection, all_versions=True):
            if isinstance(dimension, TimeDimensionBaseModel):
                time_dims[dimension.id] = dimension
            else:
                assert isinstance(dimension, DimensionModel)
                hashes[dimension.file_hash].append(dimension)

        existing_ids = set()
        for i, dim in enumerate(config.model.dimensions):
            replace_dim = False
            existing = None
            if isinstance(dim, TimeDimensionBaseModel):
                existing = self._get_matching_time_dimension(time_dims.values(), dim)
                if existing is not None:
                    replace_dim = True
            elif dim.file_hash in hashes:
                for existing in hashes[dim.file_hash]:
                    if (
                        dim.dimension_type == existing.dimension_type
                        and dim.name == existing.name
                        and dim.description == existing.description
                    ):
                        replace_dim = True
                        break
                if not replace_dim:
                    # Records match an existing dimension but the metadata differs,
                    # so register a brand-new dimension rather than re-using the old one.
                    logger.info(
                        "Register new dimension even though records are duplicate with "
                        "one or more existing dimensions. New name=%s",
                        dim.name,
                    )
            if replace_dim:
                assert existing is not None
                logger.info(
                    "Replace %s with existing dimension ID %s", dim.name, existing.dimension_id
                )
                config.model.dimensions[i] = existing
                existing_ids.add(existing.dimension_id)
        return existing_ids

    @staticmethod
    def _get_matching_time_dimension(existing_dims, new_dim):
        """Return an existing time dimension equal to ``new_dim``, ignoring IDs and version.

        Compares all pydantic model fields except dimension_id, version, and id.
        Returns None when no dimension of the same concrete type matches.
        """
        for time_dim in existing_dims:
            if type(time_dim) is not type(new_dim):
                continue
            match = True
            exclude = set(("dimension_id", "version", "id"))
            for field in type(new_dim).model_fields:
                if field not in exclude and getattr(new_dim, field) != getattr(time_dim, field):
                    match = False
                    break
            if match:
                return time_dim

        return None

    def get_by_id(
        self, config_id: str, version: str | None = None, conn: Connection | None = None
    ) -> DimensionBaseConfig:
        """Return the dimension config for ``config_id``, loading and caching it on demand.

        When ``version`` is None, the latest registered version is used.
        """
        if version is None:
            version = self._db.get_latest_version(conn, config_id)

        key = ConfigKey(config_id, version)
        dimension = self._dimensions.get(key)
        if dimension is not None:
            return dimension

        # NOTE(review): version was already resolved above, so this branch only runs
        # if get_latest_version can return None — confirm whether that is intended.
        if version is None:
            model = self.db.get_latest(conn, config_id)
        else:
            model = self.db.get_by_version(conn, config_id, version)

        config = get_dimension_config(model)
        self._dimensions[key] = config
        return config

    def list_ids(
        self,
        conn: Connection | None = None,
        dimension_type: DimensionType | None = None,
        **kwargs: Any,
    ) -> list[str]:
        """Return the dimension ids for the given type, sorted alphabetically.

        Parameters
        ----------
        dimension_type
            If not provided, return all dimension ids.

        Returns
        -------
        list

        """
        if dimension_type is None:
            ids = super().list_ids(conn)
        else:
            ids = [
                x.dimension_id  # type: ignore
                for x in self.db.iter_models(
                    conn, filter_config={"dimension_type": dimension_type}
                )
            ]
        ids.sort()
        return ids

    def load_dimensions(
        self,
        dimension_references: Sequence[DimensionReferenceModel],
        conn: Connection | None = None,
    ) -> dict[ConfigKey, DimensionBaseConfig]:
        """Load dimensions from the database.

        Parameters
        ----------
        dimension_references
        conn
            Connection to the database, optional. If not provided, a new connection will be created.

        Returns
        -------
        dict
            ConfigKey to DimensionConfig
        """
        dimensions = {}
        for dim in dimension_references:
            key = ConfigKey(dim.dimension_id, dim.version)
            dimensions[key] = self.get_by_id(dim.dimension_id, version=dim.version, conn=conn)

        return dimensions

    def register_from_config(
        self,
        config: DimensionsConfig,
        context: RegistrationContext,
    ) -> list[str]:
        """Register all dimensions in ``config`` using an existing registration context.

        Returns the dimension IDs in the same order as ``config.model.dimensions``.
        """
        return self._register(config, context)

    def register(self, config_file: Path, submitter: str, log_message: str) -> list[str]:
        """Register the dimensions defined in ``config_file`` and return their IDs."""
        with RegistrationContext(
            self.db, log_message, VersionUpdateType.MAJOR, submitter
        ) as context:
            config = DimensionsConfig.load(config_file)
            return self.register_from_config(config, context=context)

    def _register(self, config, context: RegistrationContext) -> list[str]:
        """Register new dimensions or re-use duplicates; return IDs in model order."""
        existing_ids = self._replace_duplicates(config, context)
        registered_dimension_ids = []

        # This function will either register the dimension specified by each model or re-use an
        # existing ID. The returned list must be in the same order as the list of models.
        final_dimension_ids = []
        for dim in config.model.dimensions:
            if dim.id is None:
                assert dim.dimension_id is None
                dim.dimension_id = str(uuid4())
                dim.version = "1.0.0"
                dim = self.db.insert(context.connection, dim, context.registration)
                assert isinstance(dim, DimensionBaseModel)
                final_dimension_ids.append(dim.dimension_id)
                registered_dimension_ids.append(dim.dimension_id)
                logger.info(
                    "%s Registered dimension id=%s type=%s version=%s name=%s",
                    self._log_offline_mode_prefix(),
                    dim.id,
                    dim.dimension_type.value,
                    dim.version,
                    dim.name,
                )
            else:
                # The model was swapped for an existing dimension by _replace_duplicates.
                if dim.dimension_id not in existing_ids:
                    msg = f"Bug: {dim.dimension_id=} should have been in existing_ids"
                    raise Exception(msg)
                final_dimension_ids.append(dim.dimension_id)

        logger.info("Registered %s dimensions", len(config.model.dimensions))
        context.add_ids(RegistryType.DIMENSION, registered_dimension_ids, self)
        return final_dimension_ids

    def make_dimension_references(self, conn: Connection, dimension_ids: list[str]):
        """Return a list of dimension references from a list of registered dimension IDs.
        This assumes that the latest version of the dimensions will be used because they were
        just created.

        Parameters
        ----------
        dimension_ids : list[str]

        """
        refs = []
        for dim_id in dimension_ids:
            dim = self.db.get_latest(conn, dim_id)
            assert isinstance(dim, DimensionBaseModel)
            assert isinstance(dim.version, str)
            refs.append(
                DimensionReferenceModel(
                    dimension_id=dim_id,
                    type=dim.dimension_type,
                    version=dim.version,
                )
            )
        return refs

    def find_matching_dimensions(
        self, sorted_record_ids: list[str], dimension_type: DimensionType
    ) -> Generator[DimensionBaseConfigWithFiles, None, None]:
        """Yield all dimensions that match the given record IDs and dimension type."""
        with self.db.engine.connect() as conn:
            filter_config = {"dimension_type": dimension_type}
            for model in self.db.iter_models(filter_config=filter_config, conn=conn):
                assert isinstance(model, DimensionBaseModel)
                config = self.get_by_id(model.dimension_id, conn=conn)
                # Compare the full, sorted record-ID lists for an exact match.
                if sorted_record_ids == sorted(config.get_unique_ids()):
                    yield config

    def show(
        self,
        conn: Connection | None = None,
        filters: list[str] | None = None,
        max_width: Union[int, dict] | None = None,
        drop_fields: list[str] | None = None,
        dimension_ids: set[str] | None = None,
        return_table: bool = False,
        **kwargs,
    ):
        """Show registry in PrettyTable

        Parameters
        ----------
        filters : list or tuple
            List of filter expressions for registry content (e.g., filters=["Submitter==USER", "Description contains comstock"])
        max_width
            Max column width in PrettyTable, specify as a single value or as a dict of values by field name
        drop_fields
            List of field names not to show

        """

        if filters:
            logger.info("List registered dimensions for: %s", filters)

        table = PrettyTable(title="Dimensions")
        all_field_names = (
            "Type",
            "Query Name",
            "ID",
            "Version",
            "Date",
            "Submitter",
            "Description",
        )
        if drop_fields is None:
            table.field_names = all_field_names
        else:
            table.field_names = tuple(x for x in all_field_names if x not in drop_fields)

        if max_width is None:
            table._max_width = {
                "ID": 40,
                "Date": 10,
                "Description": 40,
            }
        if isinstance(max_width, int):
            table.max_width = max_width
        elif isinstance(max_width, dict):
            table._max_width = max_width

        if filters:
            transformed_filters = transform_and_validate_filters(filters)
        field_to_index = {x: i for i, x in enumerate(table.field_names)}
        rows = []
        for model in self.db.iter_models(conn):
            registration = self.db.get_registration(conn, model)
            if dimension_ids and model.dimension_id not in dimension_ids:
                continue

            all_fields = (
                model.dimension_type.value,
                model.name,
                model.dimension_id,
                model.version,
                registration.timestamp.strftime("%Y-%m-%d %H:%M:%S"),
                registration.submitter,
                registration.log_message,
            )
            if drop_fields is None:
                row = all_fields
            else:
                row = tuple(
                    y for (x, y) in zip(all_field_names, all_fields) if x not in drop_fields
                )

            if not filters or matches_filters(row, field_to_index, transformed_filters):
                rows.append(row)

        # Sort by dimension type (the first column).
        rows.sort(key=lambda x: x[0])
        table.add_rows(rows)
        table.align = "l"
        if return_table:
            return table
        display_table(table)

    def update_from_file(
        self,
        config_file: Path,
        dimension_id: str,
        submitter: str,
        update_type: VersionUpdateType,
        log_message: str,
        version: str,
    ):
        """Update a registered dimension from a config file.

        ``version`` must match the currently registered version of ``dimension_id``.
        """
        with RegistrationContext(self.db, log_message, update_type, submitter) as context:
            config = load_dimension_config(config_file)
            self._check_update(context.connection, config, dimension_id, version)
            self.update_with_context(config, context)

    def update(
        self,
        config,
        update_type: VersionUpdateType,
        log_message: str,
        submitter: str | None = None,
    ) -> DimensionConfig:
        """Update a registered dimension and return the new config."""
        with RegistrationContext(self.db, log_message, update_type, submitter) as context:
            return self.update_with_context(config, context)

    def update_with_context(self, config, context: RegistrationContext) -> DimensionConfig:
        """Update a dimension within an existing registration context.

        Runs the update checker against the currently registered model, writes the
        update, and refreshes the cache (old version evicted, new version inserted).
        """
        old_config = self.get_by_id(config.model.dimension_id, conn=context.connection)
        checker = DimensionUpdateChecker(old_config.model, config.model)
        checker.run()
        cur_version = old_config.model.version
        old_key = ConfigKey(config.model.dimension_id, cur_version)
        model = self._update_config(config, context)
        new_key = ConfigKey(model.dimension_id, model.version)
        self._dimensions.pop(old_key, None)
        self._dimensions[new_key] = get_dimension_config(model)
        return self._dimensions[new_key]

    def finalize_registration(self, conn: Connection, config_ids: set[str], error_occurred: bool):
        """Finalize a registration; on error, evict the affected configs from the cache."""
        if error_occurred:
            for key in [x for x in self._dimensions if x.id in config_ids]:
                self._dimensions.pop(key)

    def remove(self, dimension_id, conn: Connection | None = None):
        """Delete all versions of ``dimension_id`` from the database and the cache."""
        self.db.delete_all(conn, dimension_id)
        for key in [x for x in self._dimensions if x.id == dimension_id]:
            self._dimensions.pop(key)

        logger.info("Removed %s from the registry.", dimension_id)
@@ -0,0 +1,16 @@
1
+ import logging
2
+
3
+ from .config_update_checker_base import ConfigUpdateCheckerBase
4
+
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
class DimensionUpdateChecker(ConfigUpdateCheckerBase):
    """Handles update checks for dimensions."""

    def check_preconditions(self):
        """Dimensions have no preconditions to verify before an update."""

    def handle_postconditions(self):
        """Dimensions require no postcondition handling after an update."""
@@ -0,0 +1,207 @@
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import Literal, Self
4
+
5
+ import duckdb
6
+ from duckdb import DuckDBPyConnection
7
+
8
+ import dsgrid
9
+ from dsgrid.common import BackendEngine
10
+ from dsgrid.exceptions import DSGInvalidOperation
11
+ from dsgrid.registry.data_store_interface import DataStoreInterface
12
+ from dsgrid.spark.functions import get_spark_session
13
+ from dsgrid.spark.types import DataFrame
14
+
15
+
16
# Name of the DuckDB database file created inside the store's base path.
DATABASE_FILENAME = "data.duckdb"
# Schemas that partition the store by logical table type.
SCHEMA_DATA = "dsgrid_data"
SCHEMA_LOOKUP_DATA = "dsgrid_lookup"
SCHEMA_MISSING_DIMENSION_ASSOCIATIONS = "dsgrid_missing_dimension_associations"
# Maps a logical table type to the schema in which those tables are stored.
TABLE_TYPE_TO_SCHEMA = {
    "data": SCHEMA_DATA,
    "lookup": SCHEMA_LOOKUP_DATA,
    "missing_dimension_associations": SCHEMA_MISSING_DIMENSION_ASSOCIATIONS,
}

logger = logging.getLogger(__name__)
27
+
28
+
29
class DuckDbDataStore(DataStoreInterface):
    """Data store that stores tables in a DuckDB database.

    Tables are partitioned into three schemas — data, lookup, and missing
    dimension associations — per TABLE_TYPE_TO_SCHEMA. Table names encode the
    dataset ID and version (see _make_table_short_name).
    """

    def __init__(self, base_path: Path):
        super().__init__(base_path)
        if dsgrid.runtime_config.backend_engine == BackendEngine.SPARK:
            # This currently doesn't work because we convert Spark DataFrames to Pandas
            # DataFrames and Pandas does not support null values. This causes it to convert
            # integer columns to floats and there isn't a great workaround as of now. This
            # is not important because we wouldn't ever want to use Spark backed by a
            # DuckDB database.
            msg = "Spark backend engine is not supported with DuckDbDataStore."
            raise DSGInvalidOperation(msg)

    @classmethod
    def create(cls, base_path: Path) -> Self:
        """Create a new DuckDB data store under ``base_path``.

        Raises FileExistsError if the database file already exists.
        """
        base_path.mkdir(exist_ok=True)
        store = cls(base_path)
        db_file = base_path / DATABASE_FILENAME
        if db_file.exists():
            msg = f"Database file {db_file} already exists. Cannot initialize DuckDB data store."
            raise FileExistsError(msg)
        con = duckdb.connect(db_file)
        try:
            con.sql(f"CREATE SCHEMA {SCHEMA_DATA}")
            con.sql(f"CREATE SCHEMA {SCHEMA_LOOKUP_DATA}")
            con.sql(f"CREATE SCHEMA {SCHEMA_MISSING_DIMENSION_ASSOCIATIONS}")
        finally:
            # Close explicitly so the lock on the new database file is released
            # before the caller starts issuing reads/writes through the store.
            con.close()
        return store

    @classmethod
    def load(cls, base_path: Path) -> Self:
        """Load an existing DuckDB data store from the given base path."""
        db_file = base_path / DATABASE_FILENAME
        if not db_file.exists():
            msg = f"Database file {db_file} does not exist."
            raise FileNotFoundError(msg)

        return cls(base_path)

    def read_table(self, dataset_id: str, version: str) -> DataFrame:
        """Return the data table for ``dataset_id``/``version`` as a Spark-API DataFrame."""
        con = self._get_connection()
        table_name = _make_table_full_name("data", dataset_id, version)
        df = con.sql(f"SELECT * FROM {table_name}").to_df()
        return get_spark_session().createDataFrame(df)

    def replace_table(self, df: DataFrame, dataset_id: str, version: str) -> None:
        """Replace the data table for ``dataset_id``/``version``, creating it if missing."""
        schema = TABLE_TYPE_TO_SCHEMA["data"]
        short_name = _make_table_short_name(dataset_id, version)
        self._replace_table(df, schema, short_name)

    def read_lookup_table(self, dataset_id: str, version: str) -> DataFrame:
        """Return the lookup table for ``dataset_id``/``version``."""
        con = self._get_connection()
        table_name = _make_table_full_name("lookup", dataset_id, version)
        df = con.sql(f"SELECT * FROM {table_name}").to_df()
        return get_spark_session().createDataFrame(df)

    def replace_lookup_table(self, df: DataFrame, dataset_id: str, version: str) -> None:
        """Replace the lookup table for ``dataset_id``/``version``, creating it if missing."""
        schema = TABLE_TYPE_TO_SCHEMA["lookup"]
        short_name = _make_table_short_name(dataset_id, version)
        self._replace_table(df, schema, short_name)

    def read_missing_associations_tables(
        self, dataset_id: str, version: str
    ) -> dict[str, DataFrame]:
        """Return the missing-dimension-association tables for ``dataset_id``/``version``.

        Returns a dict keyed by unqualified table name; empty if there are none.
        """
        con = self._get_connection()
        dfs: dict[str, DataFrame] = {}
        names = self._list_dim_associations_table_names(dataset_id, version)
        if not names:
            return dfs
        for name in names:
            full_name = f"{SCHEMA_MISSING_DIMENSION_ASSOCIATIONS}.{name}"
            df = con.sql(f"SELECT * FROM {full_name}").to_df()
            dfs[name] = get_spark_session().createDataFrame(df)
        return dfs

    def write_table(
        self, df: DataFrame, dataset_id: str, version: str, overwrite: bool = False
    ) -> None:
        """Write the data table; drop any existing table first when ``overwrite`` is True."""
        con = self._get_connection()
        table_name = _make_table_full_name("data", dataset_id, version)
        if overwrite:
            con.sql(f"DROP TABLE IF EXISTS {table_name}")
        _create_table_from_dataframe(con, df, table_name)

    def write_lookup_table(
        self, df: DataFrame, dataset_id: str, version: str, overwrite: bool = False
    ) -> None:
        """Write the lookup table; drop any existing table first when ``overwrite`` is True."""
        con = self._get_connection()
        table_name = _make_table_full_name("lookup", dataset_id, version)
        if overwrite:
            con.sql(f"DROP TABLE IF EXISTS {table_name}")
        _create_table_from_dataframe(con, df, table_name)

    def write_missing_associations_tables(
        self, dfs: dict[str, DataFrame], dataset_id: str, version: str, overwrite: bool = False
    ) -> None:
        """Write one missing-associations table per tag in ``dfs``.

        Each table name is the dataset/version short name suffixed with ``__<tag>``.
        """
        con = self._get_connection()
        for tag, df in dfs.items():
            table_name = _make_table_full_name(
                "missing_dimension_associations", dataset_id, version
            )
            table_name = f"{table_name}__{tag}"
            if overwrite:
                con.sql(f"DROP TABLE IF EXISTS {table_name}")
            _create_table_from_dataframe(con, df, table_name)

    def remove_tables(self, dataset_id: str, version: str) -> None:
        """Drop all tables (data, lookup, missing associations) for ``dataset_id``/``version``."""
        con = self._get_connection()
        for table_type in ("data", "lookup"):
            table_name = _make_table_full_name(table_type, dataset_id, version)
            con.sql(f"DROP TABLE IF EXISTS {table_name}")
        for name in self._list_dim_associations_table_names(dataset_id, version):
            full_name = f"{SCHEMA_MISSING_DIMENSION_ASSOCIATIONS}.{name}"
            con.sql(f"DROP TABLE IF EXISTS {full_name}")

    @property
    def _data_dir(self) -> Path:
        return self.base_path / "data"

    @property
    def _db_file(self) -> Path:
        return self.base_path / DATABASE_FILENAME

    def _get_connection(self) -> duckdb.DuckDBPyConnection:
        """Open a new connection to the store's database file."""
        return duckdb.connect(self._db_file)

    def _has_table(self, con: DuckDBPyConnection, schema: str, table_name: str) -> bool:
        """Return True if ``table_name`` exists in ``schema``."""
        # Use prepared-statement parameters rather than interpolating the names
        # into the SQL text.
        row = con.execute(
            """
            SELECT COUNT(*)
            FROM information_schema.tables
            WHERE table_schema = ? AND table_name = ?
            """,
            [schema, table_name],
        ).fetchone()
        return row is not None and row[0] > 0

    def _replace_table(self, df: DataFrame, schema: str, table_name: str) -> None:
        """Replace ``schema.table_name`` with the contents of ``df``.

        Creates the table when it does not exist; otherwise writes ``df`` to a
        temporary table, drops the old table, and renames the temporary table
        into place.
        """
        con = self._get_connection()
        full_name = f"{schema}.{table_name}"
        if not self._has_table(con, schema, table_name):
            # Bug fix: the name must be schema-qualified; an unqualified name would
            # create the table in DuckDB's default "main" schema, where the read
            # methods (which always query the dsgrid schemas) would never find it.
            _create_table_from_dataframe(con, df, full_name)
            return

        tmp_name = f"{full_name}_tmp"
        _create_table_from_dataframe(con, df, tmp_name)
        # Bug fix: DROP must also target the qualified name, otherwise it would
        # look in "main" and fail. ALTER TABLE ... RENAME TO takes an unqualified
        # new name per DuckDB's syntax.
        con.sql(f"DROP TABLE {full_name}")
        con.sql(f"ALTER TABLE {tmp_name} RENAME TO {table_name}")

    def _list_dim_associations_table_names(self, dataset_id: str, version: str) -> list[str]:
        """Return the names of all missing-association tables for ``dataset_id``/``version``."""
        con = self._get_connection()
        short_name = _make_table_short_name(dataset_id, version)
        rows = con.execute(
            """
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = ? AND table_name LIKE ?
            """,
            [TABLE_TYPE_TO_SCHEMA["missing_dimension_associations"], f"%{short_name}%"],
        ).fetchall()
        return [row[0] for row in rows]
185
+
186
+
187
def _create_table_from_dataframe(
    con: DuckDBPyConnection, df: DataFrame, full_table_name: str
) -> None:
    """Create ``full_table_name`` in the DuckDB database from a Spark-API DataFrame.

    The local variable ``pdf`` looks unused (hence the noqa), but it is
    load-bearing: DuckDB's replacement scan resolves the identifier ``pdf`` in
    the SQL text to the in-scope pandas DataFrame of the same name.
    """
    pdf = df.toPandas()  # noqa: F841
    con.sql(f"CREATE TABLE {full_table_name} AS SELECT * from pdf")
192
+
193
+
194
def _make_table_full_name(
    base_name: Literal["data", "lookup", "missing_dimension_associations"],
    dataset_id: str,
    version: str,
) -> str:
    """Return the schema-qualified table name for a dataset/version pair."""
    qualified = TABLE_TYPE_TO_SCHEMA[base_name]
    return f"{qualified}.{_make_table_short_name(dataset_id, version)}"
202
+
203
+
204
+ def _make_table_short_name(dataset_id: str, version: str) -> str:
205
+ # Replace dots so that manual SQL queries don't have to escape them.
206
+ ver = version.replace(".", "_")
207
+ return f"{dataset_id}__{ver}"