dsgrid-toolkit 0.3.3__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. build_backend.py +93 -0
  2. dsgrid/__init__.py +22 -0
  3. dsgrid/api/__init__.py +0 -0
  4. dsgrid/api/api_manager.py +179 -0
  5. dsgrid/api/app.py +419 -0
  6. dsgrid/api/models.py +60 -0
  7. dsgrid/api/response_models.py +116 -0
  8. dsgrid/apps/__init__.py +0 -0
  9. dsgrid/apps/project_viewer/app.py +216 -0
  10. dsgrid/apps/registration_gui.py +444 -0
  11. dsgrid/chronify.py +32 -0
  12. dsgrid/cli/__init__.py +0 -0
  13. dsgrid/cli/common.py +120 -0
  14. dsgrid/cli/config.py +176 -0
  15. dsgrid/cli/download.py +13 -0
  16. dsgrid/cli/dsgrid.py +157 -0
  17. dsgrid/cli/dsgrid_admin.py +92 -0
  18. dsgrid/cli/install_notebooks.py +62 -0
  19. dsgrid/cli/query.py +729 -0
  20. dsgrid/cli/registry.py +1862 -0
  21. dsgrid/cloud/__init__.py +0 -0
  22. dsgrid/cloud/cloud_storage_interface.py +140 -0
  23. dsgrid/cloud/factory.py +31 -0
  24. dsgrid/cloud/fake_storage_interface.py +37 -0
  25. dsgrid/cloud/s3_storage_interface.py +156 -0
  26. dsgrid/common.py +36 -0
  27. dsgrid/config/__init__.py +0 -0
  28. dsgrid/config/annual_time_dimension_config.py +194 -0
  29. dsgrid/config/common.py +142 -0
  30. dsgrid/config/config_base.py +148 -0
  31. dsgrid/config/dataset_config.py +907 -0
  32. dsgrid/config/dataset_schema_handler_factory.py +46 -0
  33. dsgrid/config/date_time_dimension_config.py +136 -0
  34. dsgrid/config/dimension_config.py +54 -0
  35. dsgrid/config/dimension_config_factory.py +65 -0
  36. dsgrid/config/dimension_mapping_base.py +350 -0
  37. dsgrid/config/dimension_mappings_config.py +48 -0
  38. dsgrid/config/dimensions.py +1025 -0
  39. dsgrid/config/dimensions_config.py +71 -0
  40. dsgrid/config/file_schema.py +190 -0
  41. dsgrid/config/index_time_dimension_config.py +80 -0
  42. dsgrid/config/input_dataset_requirements.py +31 -0
  43. dsgrid/config/mapping_tables.py +209 -0
  44. dsgrid/config/noop_time_dimension_config.py +42 -0
  45. dsgrid/config/project_config.py +1462 -0
  46. dsgrid/config/registration_models.py +188 -0
  47. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  48. dsgrid/config/simple_models.py +49 -0
  49. dsgrid/config/supplemental_dimension.py +29 -0
  50. dsgrid/config/time_dimension_base_config.py +192 -0
  51. dsgrid/data_models.py +155 -0
  52. dsgrid/dataset/__init__.py +0 -0
  53. dsgrid/dataset/dataset.py +123 -0
  54. dsgrid/dataset/dataset_expression_handler.py +86 -0
  55. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  56. dsgrid/dataset/dataset_schema_handler_base.py +945 -0
  57. dsgrid/dataset/dataset_schema_handler_one_table.py +209 -0
  58. dsgrid/dataset/dataset_schema_handler_two_table.py +322 -0
  59. dsgrid/dataset/growth_rates.py +162 -0
  60. dsgrid/dataset/models.py +51 -0
  61. dsgrid/dataset/table_format_handler_base.py +257 -0
  62. dsgrid/dataset/table_format_handler_factory.py +17 -0
  63. dsgrid/dataset/unpivoted_table.py +121 -0
  64. dsgrid/dimension/__init__.py +0 -0
  65. dsgrid/dimension/base_models.py +230 -0
  66. dsgrid/dimension/dimension_filters.py +308 -0
  67. dsgrid/dimension/standard.py +252 -0
  68. dsgrid/dimension/time.py +352 -0
  69. dsgrid/dimension/time_utils.py +103 -0
  70. dsgrid/dsgrid_rc.py +88 -0
  71. dsgrid/exceptions.py +105 -0
  72. dsgrid/filesystem/__init__.py +0 -0
  73. dsgrid/filesystem/cloud_filesystem.py +32 -0
  74. dsgrid/filesystem/factory.py +32 -0
  75. dsgrid/filesystem/filesystem_interface.py +136 -0
  76. dsgrid/filesystem/local_filesystem.py +74 -0
  77. dsgrid/filesystem/s3_filesystem.py +118 -0
  78. dsgrid/loggers.py +132 -0
  79. dsgrid/minimal_patterns.cp313-win_amd64.pyd +0 -0
  80. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +949 -0
  81. dsgrid/notebooks/registration.ipynb +48 -0
  82. dsgrid/notebooks/start_notebook.sh +11 -0
  83. dsgrid/project.py +451 -0
  84. dsgrid/query/__init__.py +0 -0
  85. dsgrid/query/dataset_mapping_plan.py +142 -0
  86. dsgrid/query/derived_dataset.py +388 -0
  87. dsgrid/query/models.py +728 -0
  88. dsgrid/query/query_context.py +287 -0
  89. dsgrid/query/query_submitter.py +994 -0
  90. dsgrid/query/report_factory.py +19 -0
  91. dsgrid/query/report_peak_load.py +70 -0
  92. dsgrid/query/reports_base.py +20 -0
  93. dsgrid/registry/__init__.py +0 -0
  94. dsgrid/registry/bulk_register.py +165 -0
  95. dsgrid/registry/common.py +287 -0
  96. dsgrid/registry/config_update_checker_base.py +63 -0
  97. dsgrid/registry/data_store_factory.py +34 -0
  98. dsgrid/registry/data_store_interface.py +74 -0
  99. dsgrid/registry/dataset_config_generator.py +158 -0
  100. dsgrid/registry/dataset_registry_manager.py +950 -0
  101. dsgrid/registry/dataset_update_checker.py +16 -0
  102. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  103. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  104. dsgrid/registry/dimension_registry_manager.py +413 -0
  105. dsgrid/registry/dimension_update_checker.py +16 -0
  106. dsgrid/registry/duckdb_data_store.py +207 -0
  107. dsgrid/registry/filesystem_data_store.py +150 -0
  108. dsgrid/registry/filter_registry_manager.py +123 -0
  109. dsgrid/registry/project_config_generator.py +57 -0
  110. dsgrid/registry/project_registry_manager.py +1623 -0
  111. dsgrid/registry/project_update_checker.py +48 -0
  112. dsgrid/registry/registration_context.py +223 -0
  113. dsgrid/registry/registry_auto_updater.py +316 -0
  114. dsgrid/registry/registry_database.py +667 -0
  115. dsgrid/registry/registry_interface.py +446 -0
  116. dsgrid/registry/registry_manager.py +558 -0
  117. dsgrid/registry/registry_manager_base.py +367 -0
  118. dsgrid/registry/versioning.py +92 -0
  119. dsgrid/rust_ext/__init__.py +14 -0
  120. dsgrid/rust_ext/find_minimal_patterns.py +129 -0
  121. dsgrid/spark/__init__.py +0 -0
  122. dsgrid/spark/functions.py +589 -0
  123. dsgrid/spark/types.py +110 -0
  124. dsgrid/tests/__init__.py +0 -0
  125. dsgrid/tests/common.py +140 -0
  126. dsgrid/tests/make_us_data_registry.py +265 -0
  127. dsgrid/tests/register_derived_datasets.py +103 -0
  128. dsgrid/tests/utils.py +25 -0
  129. dsgrid/time/__init__.py +0 -0
  130. dsgrid/time/time_conversions.py +80 -0
  131. dsgrid/time/types.py +67 -0
  132. dsgrid/units/__init__.py +0 -0
  133. dsgrid/units/constants.py +113 -0
  134. dsgrid/units/convert.py +71 -0
  135. dsgrid/units/energy.py +145 -0
  136. dsgrid/units/power.py +87 -0
  137. dsgrid/utils/__init__.py +0 -0
  138. dsgrid/utils/dataset.py +830 -0
  139. dsgrid/utils/files.py +179 -0
  140. dsgrid/utils/filters.py +125 -0
  141. dsgrid/utils/id_remappings.py +100 -0
  142. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  143. dsgrid/utils/py_expression_eval/README.md +8 -0
  144. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  145. dsgrid/utils/py_expression_eval/tests.py +283 -0
  146. dsgrid/utils/run_command.py +70 -0
  147. dsgrid/utils/scratch_dir_context.py +65 -0
  148. dsgrid/utils/spark.py +918 -0
  149. dsgrid/utils/spark_partition.py +98 -0
  150. dsgrid/utils/timing.py +239 -0
  151. dsgrid/utils/utilities.py +221 -0
  152. dsgrid/utils/versioning.py +36 -0
  153. dsgrid_toolkit-0.3.3.dist-info/METADATA +193 -0
  154. dsgrid_toolkit-0.3.3.dist-info/RECORD +157 -0
  155. dsgrid_toolkit-0.3.3.dist-info/WHEEL +4 -0
  156. dsgrid_toolkit-0.3.3.dist-info/entry_points.txt +4 -0
  157. dsgrid_toolkit-0.3.3.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,950 @@
1
+ """Manages the registry for dimension datasets"""
2
+
3
+ import logging
4
+ import os
5
+ from datetime import datetime, timedelta
6
+ from pathlib import Path
7
+ from typing import Any, Self, Type, Union
8
+ from zoneinfo import ZoneInfo
9
+
10
+ import pandas as pd
11
+ from prettytable import PrettyTable
12
+ from sqlalchemy import Connection
13
+
14
+ from dsgrid.common import SCALING_FACTOR_COLUMN, SYNC_EXCLUDE_LIST
15
+ from dsgrid.config.dataset_config import (
16
+ DatasetConfig,
17
+ DatasetConfigModel,
18
+ user_layout_to_registry_layout,
19
+ )
20
+ from dsgrid.config.dimensions import (
21
+ DateTimeDimensionModel,
22
+ TimeFormatDateTimeNTZModel,
23
+ TimeFormatDateTimeTZModel,
24
+ TimeFormatInPartsModel,
25
+ )
26
+ from dsgrid.dataset.dataset_schema_handler_base import DatasetSchemaHandlerBase
27
+ from dsgrid.dataset.models import TableFormat, ValueFormat
28
+ from dsgrid.config.file_schema import Column, read_data_file
29
+ from dsgrid.config.dataset_schema_handler_factory import make_dataset_schema_handler
30
+ from dsgrid.config.dimensions_config import DimensionsConfig, DimensionsConfigModel
31
+ from dsgrid.dimension.base_models import (
32
+ DatasetDimensionRequirements,
33
+ DimensionType,
34
+ check_required_dataset_dimensions,
35
+ )
36
+ from dsgrid.dsgrid_rc import DsgridRuntimeConfig
37
+ from dsgrid.exceptions import DSGInvalidDataset
38
+ from dsgrid.registry.dimension_registry_manager import DimensionRegistryManager
39
+ from dsgrid.registry.dimension_mapping_registry_manager import (
40
+ DimensionMappingRegistryManager,
41
+ )
42
+ from dsgrid.registry.data_store_interface import DataStoreInterface
43
+ from dsgrid.registry.registry_interface import DatasetRegistryInterface
44
+ from dsgrid.spark.functions import (
45
+ get_spark_session,
46
+ is_dataframe_empty,
47
+ select_expr,
48
+ )
49
+ from dsgrid.spark.types import get_str_type, use_duckdb
50
+ from dsgrid.spark.types import DataFrame, F, StringType
51
+ from dsgrid.utils.dataset import add_time_zone, split_expected_missing_rows, unpivot_dataframe
52
+ from dsgrid.utils.scratch_dir_context import ScratchDirContext
53
+ from dsgrid.utils.spark import (
54
+ read_dataframe,
55
+ write_dataframe,
56
+ )
57
+ from dsgrid.utils.timing import timer_stats_collector, track_timing
58
+ from dsgrid.utils.filters import transform_and_validate_filters, matches_filters
59
+ from dsgrid.utils.utilities import check_uniqueness, display_table, make_unique_key
60
+ from .common import (
61
+ VersionUpdateType,
62
+ ConfigKey,
63
+ RegistryType,
64
+ )
65
+ from .registration_context import RegistrationContext
66
+ from .registry_manager_base import RegistryManagerBase
67
+
68
+ logger = logging.getLogger(__name__)
69
+
70
+
71
+ class DatasetRegistryManager(RegistryManagerBase):
72
+ """Manages registered dimension datasets."""
73
+
74
+ def __init__(
75
+ self,
76
+ path,
77
+ fs_interface,
78
+ dimension_manager: DimensionRegistryManager,
79
+ dimension_mapping_manager: DimensionMappingRegistryManager,
80
+ db: DatasetRegistryInterface,
81
+ store: DataStoreInterface,
82
+ ):
83
+ super().__init__(path, fs_interface)
84
+ self._datasets: dict[ConfigKey, DatasetConfig] = {}
85
+ self._dimension_mgr = dimension_manager
86
+ self._dimension_mapping_mgr = dimension_mapping_manager
87
+ self._db = db
88
+ self._store = store
89
+
90
+ @classmethod
91
+ def load(
92
+ cls,
93
+ path: Path,
94
+ params,
95
+ dimension_manager: DimensionRegistryManager,
96
+ dimension_mapping_manager: DimensionMappingRegistryManager,
97
+ db: DatasetRegistryInterface,
98
+ store: DataStoreInterface,
99
+ ) -> Self:
100
+ return cls._load(path, params, dimension_manager, dimension_mapping_manager, db, store)
101
+
102
+ @staticmethod
103
+ def config_class() -> Type:
104
+ return DatasetConfig
105
+
106
+ @property
107
+ def db(self) -> DatasetRegistryInterface:
108
+ return self._db
109
+
110
+ @property
111
+ def store(self) -> DataStoreInterface:
112
+ return self._store
113
+
114
+ @staticmethod
115
+ def name() -> str:
116
+ return "Datasets"
117
+
118
+ def _get_registry_data_path(self):
119
+ if self._params.use_remote_data:
120
+ dataset_path = self._params.remote_path
121
+ else:
122
+ dataset_path = str(self._params.base_path)
123
+ return dataset_path
124
+
125
+ @track_timing(timer_stats_collector)
126
+ def _run_checks(
127
+ self,
128
+ conn: Connection,
129
+ config: DatasetConfig,
130
+ missing_dimension_associations: dict[str, DataFrame],
131
+ scratch_dir_context: ScratchDirContext,
132
+ requirements: DatasetDimensionRequirements,
133
+ ) -> None:
134
+ logger.info("Run dataset registration checks.")
135
+ check_required_dataset_dimensions(
136
+ config.model.dimension_references, requirements, "dataset dimensions"
137
+ )
138
+ check_uniqueness((x.model.name for x in config.model.dimensions), "dimension name")
139
+ if not os.environ.get("__DSGRID_SKIP_CHECK_DATASET_CONSISTENCY__"):
140
+ self._check_dataset_consistency(
141
+ conn,
142
+ config,
143
+ missing_dimension_associations,
144
+ scratch_dir_context,
145
+ requirements,
146
+ )
147
+
148
+ def _check_dataset_consistency(
149
+ self,
150
+ conn: Connection,
151
+ config: DatasetConfig,
152
+ missing_dimension_associations: dict[str, DataFrame],
153
+ scratch_dir_context: ScratchDirContext,
154
+ requirements: DatasetDimensionRequirements,
155
+ ) -> None:
156
+ schema_handler = make_dataset_schema_handler(
157
+ conn,
158
+ config,
159
+ self._dimension_mgr,
160
+ self._dimension_mapping_mgr,
161
+ store=self._store,
162
+ )
163
+ schema_handler.check_consistency(
164
+ missing_dimension_associations, scratch_dir_context, requirements
165
+ )
166
+
167
+ @property
168
+ def dimension_manager(self) -> DimensionRegistryManager:
169
+ return self._dimension_mgr
170
+
171
+ @property
172
+ def dimension_mapping_manager(self) -> DimensionMappingRegistryManager:
173
+ return self._dimension_mapping_mgr
174
+
175
+ def finalize_registration(self, conn: Connection, config_ids: set[str], error_occurred: bool):
176
+ assert len(config_ids) == 1, config_ids
177
+ if error_occurred:
178
+ for dataset_id in config_ids:
179
+ logger.info("Remove intermediate dataset after error")
180
+ self.remove_data(dataset_id, "1.0.0")
181
+ for key in [x for x in self._datasets if x.id in config_ids]:
182
+ self._datasets.pop(key)
183
+
184
+ if not self.offline_mode:
185
+ for dataset_id in config_ids:
186
+ lock_file = self.get_registry_lock_file(dataset_id)
187
+ self.cloud_interface.check_lock_file(lock_file)
188
+ if not error_occurred:
189
+ self.sync_push(self.get_registry_data_directory(dataset_id))
190
+ self.cloud_interface.remove_lock_file(lock_file)
191
+
192
+ def get_by_id(
193
+ self, config_id: str, version: str | None = None, conn: Connection | None = None
194
+ ) -> DatasetConfig:
195
+ if version is None:
196
+ version = self._db.get_latest_version(conn, config_id)
197
+
198
+ key = ConfigKey(config_id, version)
199
+ dataset = self._datasets.get(key)
200
+ if dataset is not None:
201
+ return dataset
202
+
203
+ if version is None:
204
+ model = self.db.get_latest(conn, config_id)
205
+ else:
206
+ model = self.db.get_by_version(conn, config_id, version)
207
+
208
+ config = DatasetConfig(model)
209
+ if config.model.data_layout is not None:
210
+ msg = f"Dataset {config_id} loaded from registry has data_layout set; expected None"
211
+ raise DSGInvalidDataset(msg)
212
+ if config.model.registry_data_layout is None:
213
+ msg = f"Dataset {config_id} loaded from registry has registry_data_layout=None; expected a value"
214
+ raise DSGInvalidDataset(msg)
215
+ self._update_dimensions(conn, config)
216
+ self._datasets[key] = config
217
+ return config
218
+
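+ # Usage sketch for get_by_id() above (editorial annotation; the dataset_id is hypothetical):
+ #     config = manager.get_by_id("comstock_conus_2022", version="1.0.0")
+ # Omitting version resolves the latest registered version; loaded configs are cached in
+ # self._datasets keyed by ConfigKey(dataset_id, version).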
219
+ def acquire_registry_locks(self, config_ids: list[str]):
220
+ """Acquire lock(s) on the registry for all config_ids.
221
+
222
+ Parameters
223
+ ----------
224
+ config_ids : list[str]
225
+
226
+ Raises
227
+ ------
228
+ DSGRegistryLockError
229
+ Raised if a lock cannot be acquired.
230
+
231
+ """
232
+ for dataset_id in config_ids:
233
+ lock_file = self.get_registry_lock_file(dataset_id)
234
+ self.cloud_interface.make_lock_file(lock_file)
235
+
236
+ def get_registry_lock_file(self, config_id: str):
237
+ """Return registry lock file path.
238
+
239
+ Parameters
240
+ ----------
241
+ config_id : str
242
+ Config ID
243
+
244
+ Returns
245
+ -------
246
+ str
247
+ Lock file path
248
+ """
249
+ return f"configs/.locks/{config_id}.lock"
250
+
251
+ def _update_dimensions(self, conn: Connection | None, config: DatasetConfig):
252
+ dimensions = self._dimension_mgr.load_dimensions(
253
+ config.model.dimension_references, conn=conn
254
+ )
255
+ config.update_dimensions(dimensions)
256
+
257
+ def register(
258
+ self,
259
+ config_file: Path,
260
+ submitter: str | None = None,
261
+ log_message: str | None = None,
262
+ context: RegistrationContext | None = None,
263
+ data_base_dir: Path | None = None,
264
+ missing_associations_base_dir: Path | None = None,
265
+ requirements: DatasetDimensionRequirements | None = None,
266
+ ):
267
+ config = DatasetConfig.load_from_user_path(
268
+ config_file,
269
+ data_base_dir=data_base_dir,
270
+ missing_associations_base_dir=missing_associations_base_dir,
271
+ )
272
+ if context is None:
273
+ assert submitter is not None
274
+ assert log_message is not None
275
+ with RegistrationContext(
276
+ self.db, log_message, VersionUpdateType.MAJOR, submitter
277
+ ) as context:
278
+ return self.register_from_config(
279
+ config,
280
+ context,
281
+ requirements=requirements,
282
+ )
283
+ else:
284
+ return self.register_from_config(
285
+ config,
286
+ context,
287
+ )
288
+
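+ # Usage sketch for register() above (editorial annotation; `manager` and the path are
+ # hypothetical):
+ #     manager.register(
+ #         Path("datasets/comstock/dataset.json5"),
+ #         submitter="someone",
+ #         log_message="Initial registration",
+ #     )
+ # When no RegistrationContext is supplied, submitter and log_message are required and a
+ # MAJOR-version context is created internally.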
289
+ @track_timing(timer_stats_collector)
290
+ def register_from_config(
291
+ self,
292
+ config: DatasetConfig,
293
+ context: RegistrationContext,
294
+ requirements: DatasetDimensionRequirements | None = None,
295
+ ):
296
+ self._update_dimensions(context.connection, config)
297
+ self._register_dataset_and_dimensions(
298
+ config,
299
+ context,
300
+ requirements,
301
+ )
302
+
303
+ def _register_dataset_and_dimensions(
304
+ self,
305
+ config: DatasetConfig,
306
+ context: RegistrationContext,
307
+ requirements: DatasetDimensionRequirements | None = None,
308
+ ):
309
+ logger.info("Start registration of dataset %s", config.model.dataset_id)
310
+
311
+ conn = context.connection
312
+ self._check_if_already_registered(conn, config.model.dataset_id)
313
+
314
+ if config.model.dimensions:
315
+ dim_model = DimensionsConfigModel(dimensions=config.model.dimensions)
316
+ dims_config = DimensionsConfig.load_from_model(dim_model)
317
+ dimension_ids = self._dimension_mgr.register_from_config(dims_config, context=context)
318
+ config.model.dimension_references += self._dimension_mgr.make_dimension_references(
319
+ conn, dimension_ids
320
+ )
321
+ config.model.dimensions.clear()
322
+
323
+ self._update_dimensions(conn, config)
324
+ self._register(
325
+ config,
326
+ context,
327
+ requirements=requirements,
328
+ )
329
+ context.add_id(RegistryType.DATASET, config.model.dataset_id, self)
330
+
331
+ def _register(
332
+ self,
333
+ config: DatasetConfig,
334
+ context: RegistrationContext,
335
+ requirements: DatasetDimensionRequirements | None = None,
336
+ ):
337
+ reqs = requirements or DatasetDimensionRequirements()
338
+ config.model.version = "1.0.0"
339
+ # Explanation for this order of operations:
340
+ # 1. Check time consistency in the original dataset format.
341
+ # Many datasets are stored in pivoted format and have many value columns. If we
342
+ # check timestamps after unpivoting the dataset, we will multiply the required work
343
+ # by the number of columns.
344
+ # 2. Write to the registry in unpivoted format before running the other checks.
345
+ # The final data is always stored in unpivoted format. We can reduce code if we
346
+ # transform pivoted tables first.
347
+ # In the nominal case where the dataset is valid, there is no difference in performance.
348
+ # In the failure case where the dataset is invalid, it will take longer to detect the
349
+ # errors.
350
+ dsgrid_config = DsgridRuntimeConfig.load()
351
+ scratch_dir = dsgrid_config.get_scratch_dir()
352
+ scratch_dir.mkdir(parents=True, exist_ok=True)
353
+ with ScratchDirContext(scratch_dir) as scratch_dir_context:
354
+ schema_handler = make_dataset_schema_handler(
355
+ context.connection,
356
+ config,
357
+ self._dimension_mgr,
358
+ self._dimension_mapping_mgr,
359
+ store=None,
360
+ scratch_dir_context=scratch_dir_context,
361
+ )
362
+ self._convert_time_format_if_necessary(config, schema_handler, scratch_dir_context)
363
+ if reqs.check_time_consistency:
364
+ schema_handler.check_time_consistency()
365
+ else:
366
+ logger.info("Skip dataset time checks for %s", config.model.dataset_id)
367
+ self._write_to_registry(config, scratch_dir_context=scratch_dir_context)
368
+
369
+ assoc_dfs = self._store.read_missing_associations_tables(
370
+ config.model.dataset_id, config.model.version
371
+ )
372
+ try:
373
+ self._run_checks(
374
+ context.connection,
375
+ config,
376
+ assoc_dfs,
377
+ scratch_dir_context,
378
+ reqs,
379
+ )
380
+ except Exception:
381
+ self._store.remove_tables(config.model.dataset_id, config.model.version)
382
+ raise
383
+
384
+ if config.model.data_layout is not None:
385
+ registry_layout = user_layout_to_registry_layout(config.model.data_layout)
386
+ config.model.data_layout = None
387
+ config.model.registry_data_layout = registry_layout
388
+ self._db.insert(context.connection, config.model, context.registration)
389
+ logger.info(
390
+ "%s Registered dataset %s with version=%s",
391
+ self._log_offline_mode_prefix(),
392
+ config.model.dataset_id,
393
+ config.model.version,
394
+ )
395
+
396
+ def _convert_time_format_if_necessary(
397
+ self,
398
+ config: DatasetConfig,
399
+ handler: DatasetSchemaHandlerBase,
400
+ scratch_dir_context: ScratchDirContext,
401
+ ) -> None:
402
+ """Convert time-in-parts format to timestamp format if necessary."""
403
+ time_dim = config.get_dimension(DimensionType.TIME)
404
+ if time_dim is None:
405
+ return
406
+
407
+ if not isinstance(time_dim.model, DateTimeDimensionModel) or not isinstance(
408
+ time_dim.model.column_format, TimeFormatInPartsModel
409
+ ):
410
+ return
411
+
412
+ # TEMPORARY
413
+ # This code only exists because we lack full support for time zone naive timestamps.
414
+ # Refactor when the existing chronify work is complete.
415
+ df = handler.get_base_load_data_table()
416
+ col_format = time_dim.model.column_format
417
+
418
+ timestamp_str_expr = self._build_timestamp_string_expr(col_format)
419
+ df, fixed_tz = self._resolve_timezone(df, config, col_format)
420
+ timestamp_sql, new_col_format = self._build_timestamp_sql(timestamp_str_expr, fixed_tz)
421
+
422
+ cols_to_drop = self._get_time_columns_to_drop(col_format)
423
+ reformatted_df = self._apply_timestamp_transformation(df, cols_to_drop, timestamp_sql)
424
+
425
+ self._update_config_for_timestamp(
426
+ config, reformatted_df, scratch_dir_context, cols_to_drop, new_col_format
427
+ )
428
+ self._update_time_dimension(time_dim, new_col_format, col_format.hour_column)
429
+ logger.info("Replaced time columns %s with %s", cols_to_drop, new_col_format.time_column)
430
+
431
+ @staticmethod
432
+ def _build_timestamp_string_expr(col_format: TimeFormatInPartsModel) -> str:
433
+ """Build SQL expression that creates a timestamp string from time-in-parts columns."""
434
+ str_type = get_str_type()
435
+ hour_col = col_format.hour_column
436
+ hour_expr = f"lpad(cast({hour_col} as {str_type}), 2, '0')" if hour_col else "'00'"
437
+
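+ # Editorial illustration (assuming columns named "year", "month", "day", and "hour", and
+ # that get_str_type() returns VARCHAR): the generated expression is
+ #     cast(year as VARCHAR) || '-' || lpad(cast(month as VARCHAR), 2, '0') || '-' ||
+ #     lpad(cast(day as VARCHAR), 2, '0') || ' ' || lpad(cast(hour as VARCHAR), 2, '0') || ':00:00'
+ # which produces strings such as '2018-07-04 13:00:00'.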
438
+ return (
439
+ f"cast({col_format.year_column} as {str_type}) || '-' || "
440
+ f"lpad(cast({col_format.month_column} as {str_type}), 2, '0') || '-' || "
441
+ f"lpad(cast({col_format.day_column} as {str_type}), 2, '0') || ' ' || "
442
+ f"{hour_expr} || ':00:00'"
443
+ )
444
+
445
+ @staticmethod
446
+ def _resolve_timezone(
447
+ df: DataFrame, config: DatasetConfig, col_format: TimeFormatInPartsModel
448
+ ) -> tuple[DataFrame, str | None]:
449
+ """Resolve timezone from config or geography dimension.
450
+
451
+ Returns the (possibly modified) dataframe and the fixed timezone if single-tz,
452
+ or None if multi-timezone.
453
+ """
454
+ if col_format.time_zone is not None:
455
+ return df, col_format.time_zone
456
+
457
+ geo_dim = config.get_dimension(DimensionType.GEOGRAPHY)
458
+ assert geo_dim is not None
459
+ return add_time_zone(df, geo_dim), None
460
+
461
+ @staticmethod
462
+ def _build_timestamp_sql(
463
+ timestamp_str_expr: str, fixed_tz: str | None
464
+ ) -> tuple[str, TimeFormatDateTimeTZModel | TimeFormatDateTimeNTZModel]:
465
+ """Build the final timestamp SQL expression and determine the column format."""
466
+ if fixed_tz is not None:
467
+ if use_duckdb():
468
+ sql = f"cast({timestamp_str_expr} || ' {fixed_tz}' as timestamptz) as timestamp"
469
+ else:
470
+ tz = ZoneInfo(fixed_tz)
471
+ offset = datetime(2012, 1, 1, tzinfo=tz).strftime("%z")
472
+ offset_formatted = f"{offset[:3]}:{offset[3:]}"
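+ # Editorial example: for fixed_tz "America/Denver", offset is "-0700" and
+ # offset_formatted is "-07:00"; a single standard-time offset (evaluated at 2012-01-01)
+ # is appended to every timestamp string in this branch.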
473
+ sql = (
474
+ f"cast({timestamp_str_expr} || '{offset_formatted}' as timestamp) as timestamp"
475
+ )
476
+ return sql, TimeFormatDateTimeTZModel(time_column="timestamp")
477
+
478
+ # Multi-timezone: create naive timestamp, keep time_zone column
479
+ sql = f"cast({timestamp_str_expr} as timestamp) as timestamp"
480
+ return sql, TimeFormatDateTimeNTZModel(time_column="timestamp")
481
+
482
+ @staticmethod
483
+ def _get_time_columns_to_drop(col_format: TimeFormatInPartsModel) -> set[str]:
484
+ """Get the set of time-in-parts columns to drop, excluding dimension columns."""
485
+ cols_to_drop = {col_format.year_column, col_format.month_column, col_format.day_column}
486
+ if col_format.hour_column:
487
+ cols_to_drop.add(col_format.hour_column)
488
+ return cols_to_drop - DimensionType.get_allowed_dimension_column_names()
489
+
490
+ @staticmethod
491
+ def _apply_timestamp_transformation(
492
+ df: DataFrame, cols_to_drop: set[str], timestamp_sql: str
493
+ ) -> DataFrame:
494
+ """Apply the timestamp transformation to the dataframe."""
495
+ existing_cols = [c for c in df.columns if c not in cols_to_drop]
496
+ return select_expr(df, existing_cols + [timestamp_sql])
497
+
498
+ def _update_config_for_timestamp(
499
+ self,
500
+ config: DatasetConfig,
501
+ df: DataFrame,
502
+ scratch_dir_context: ScratchDirContext,
503
+ cols_to_drop: set[str],
504
+ new_col_format: TimeFormatDateTimeTZModel | TimeFormatDateTimeNTZModel,
505
+ ) -> None:
506
+ """Write the transformed dataframe and update config paths and columns."""
507
+ path = scratch_dir_context.get_temp_filename(suffix=".parquet")
508
+ write_dataframe(df, path)
509
+ config.model.data_layout.data_file.path = str(path)
510
+
511
+ if config.model.data_layout.data_file.columns is not None:
512
+ updated_columns = [
513
+ c for c in config.model.data_layout.data_file.columns if c.name not in cols_to_drop
514
+ ]
515
+ timestamp_data_type = (
516
+ "TIMESTAMP_TZ"
517
+ if isinstance(new_col_format, TimeFormatDateTimeTZModel)
518
+ else "TIMESTAMP_NTZ"
519
+ )
520
+ updated_columns.append(Column(name="timestamp", data_type=timestamp_data_type))
521
+ config.model.data_layout.data_file.columns = updated_columns
522
+
523
+ @staticmethod
524
+ def _update_time_dimension(
525
+ time_dim,
526
+ new_col_format: TimeFormatDateTimeTZModel | TimeFormatDateTimeNTZModel,
527
+ hour_col: str | None,
528
+ ) -> None:
529
+ """Update the time dimension model with the new format."""
530
+ time_dim.model.column_format = new_col_format
531
+ time_dim.model.time_column = "timestamp"
532
+ if hour_col is None:
533
+ for time_range in time_dim.model.ranges:
534
+ time_range.frequency = timedelta(days=1)
535
+
536
+ def _read_lookup_table_from_user_path(
537
+ self, config: DatasetConfig, scratch_dir_context: ScratchDirContext | None = None
538
+ ) -> tuple[DataFrame, DataFrame | None]:
539
+ if config.lookup_file_schema is None:
540
+ msg = "Cannot read lookup table without lookup file schema"
541
+ raise DSGInvalidDataset(msg)
542
+
543
+ df = read_data_file(config.lookup_file_schema, scratch_dir_context=scratch_dir_context)
544
+ if "id" not in df.columns:
545
+ msg = "load_data_lookup does not include an 'id' column"
546
+ raise DSGInvalidDataset(msg)
547
+ missing = df.filter("id IS NULL").drop("id")
548
+ if is_dataframe_empty(missing):
549
+ missing_df = None
550
+ else:
551
+ missing_df = missing
552
+ if SCALING_FACTOR_COLUMN in missing_df.columns:
553
+ missing_df = missing_df.drop(SCALING_FACTOR_COLUMN)
554
+ return df, missing_df
555
+
556
+ def _read_missing_associations_tables_from_user_path(
557
+ self, config: DatasetConfig
558
+ ) -> dict[str, DataFrame]:
559
+ """Return all missing association tables keyed by the file path stem.
560
+ Tables can be all-dimension-types-in-one or split by groups of dimension types.
561
+ """
562
+ dfs: dict[str, DataFrame] = {}
563
+ missing_paths = config.missing_associations_paths
564
+ if not missing_paths:
565
+ return dfs
566
+
567
+ def add_df(path):
568
+ df = self._read_associations_file(path)
569
+ key = make_unique_key(path.stem, dfs)
570
+ dfs[key] = df
571
+
572
+ for path in missing_paths:
573
+ if path.suffix.lower() == ".parquet":
574
+ add_df(path)
575
+ elif path.is_dir():
576
+ for file_path in path.iterdir():
577
+ if file_path.suffix.lower() in (".csv", ".parquet"):
578
+ add_df(file_path)
579
+ elif path.suffix.lower() in (".csv", ".parquet"):
580
+ add_df(path)
581
+ return dfs
582
+
583
+ @staticmethod
584
+ def _read_associations_file(path: Path) -> DataFrame:
585
+ if path.suffix.lower() == ".csv":
586
+ df = get_spark_session().createDataFrame(pd.read_csv(path, dtype="string"))
587
+ else:
588
+ df = read_dataframe(path)
589
+ for field in df.schema.fields:
590
+ if field.dataType != StringType():
591
+ df = df.withColumn(field.name, F.col(field.name).cast(StringType()))
592
+ return df
593
+
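+ # Editorial note on _read_associations_file above: every column is cast to StringType
+ # (CSV files are read with dtype="string"), presumably so that association tables from
+ # either format compare consistently against dimension record IDs.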
594
+ def _read_table_from_user_path(
595
+ self, config: DatasetConfig, scratch_dir_context: ScratchDirContext | None = None
596
+ ) -> tuple[DataFrame, DataFrame | None]:
597
+ """Read a table from a user-provided path. Split expected-missing rows into a separate
598
+ DataFrame.
599
+
600
+ Parameters
601
+ ----------
602
+ config : DatasetConfig
603
+ The dataset configuration.
604
+ scratch_dir_context : ScratchDirContext | None
605
+ Optional location to store temporary files
606
+
607
+ Returns
608
+ -------
609
+ tuple[DataFrame, DataFrame | None]
610
+ The first DataFrame contains the expected rows, and the second DataFrame contains the
611
+ missing rows or will be None if there are no missing rows.
612
+ """
613
+ if config.data_file_schema is None:
614
+ msg = "Cannot read table without data file schema"
615
+ raise DSGInvalidDataset(msg)
616
+ df = read_data_file(config.data_file_schema, scratch_dir_context=scratch_dir_context)
617
+
618
+ if config.get_value_format() == ValueFormat.PIVOTED:
619
+ logger.info("Convert dataset %s from pivoted to stacked.", config.model.dataset_id)
620
+ needs_unpivot = True
621
+ pivoted_columns = config.get_pivoted_dimension_columns()
622
+ pivoted_dimension_type = config.get_pivoted_dimension_type()
623
+ # Update both fields together to avoid validation errors from the model validator
624
+ config.model.data_layout = config.model.data_layout.model_copy(
625
+ update={"value_format": ValueFormat.STACKED, "pivoted_dimension_type": None}
626
+ )
627
+ else:
628
+ needs_unpivot = False
629
+ pivoted_columns = None
630
+ pivoted_dimension_type = None
631
+
632
+ time_columns: list[str] = []
633
+ time_dim = config.get_time_dimension()
634
+ if time_dim is not None:
635
+ time_columns.extend(time_dim.get_load_data_time_columns())
636
+
637
+ if needs_unpivot:
638
+ assert pivoted_columns is not None
639
+ assert pivoted_dimension_type is not None
640
+ existing_columns = set(df.columns)
641
+ if diff := set(time_columns) - existing_columns:
642
+ msg = f"Expected time columns are not present in the table: {diff=}"
643
+ raise DSGInvalidDataset(msg)
644
+ if diff := set(pivoted_columns) - existing_columns:
645
+ msg = f"Expected pivoted_columns are not present in the table: {diff=}"
646
+ raise DSGInvalidDataset(msg)
647
+ df = unpivot_dataframe(df, pivoted_columns, pivoted_dimension_type.value, time_columns)
648
+
649
+ return split_expected_missing_rows(df, time_columns)
650
+
651
+ def _write_to_registry(
652
+ self,
653
+ config: DatasetConfig,
654
+ orig_version: str | None = None,
655
+ scratch_dir_context: ScratchDirContext | None = None,
656
+ ) -> None:
657
+ lk_df: DataFrame | None = None
658
+ missing_dfs: dict[str, DataFrame] = {}
659
+ match config.get_table_format():
660
+ case TableFormat.ONE_TABLE:
661
+ if not config.has_user_layout:
662
+ assert (
663
+ orig_version is not None
664
+ ), "orig_version must be set if config came from the registry"
665
+ missing_dfs.update(
666
+ self._store.read_missing_associations_tables(
667
+ config.model.dataset_id, orig_version
668
+ )
669
+ )
670
+ ld_df = self._store.read_table(config.model.dataset_id, orig_version)
671
+ else:
672
+ # Note: config will be updated if this is a pivoted table.
673
+ ld_df, missing_df1 = self._read_table_from_user_path(
674
+ config, scratch_dir_context=scratch_dir_context
675
+ )
676
+ missing_dfs2 = self._read_missing_associations_tables_from_user_path(config)
677
+ missing_dfs.update(
678
+ self._check_duplicate_missing_associations(missing_df1, missing_dfs2)
679
+ )
680
+
681
+ case TableFormat.TWO_TABLE:
682
+ if not config.has_user_layout:
683
+ assert (
684
+ orig_version is not None
685
+ ), "orig_version must be set if config came from the registry"
686
+ lk_df = self._store.read_lookup_table(config.model.dataset_id, orig_version)
687
+ ld_df = self._store.read_table(config.model.dataset_id, orig_version)
688
+ missing_dfs = self._store.read_missing_associations_tables(
689
+ config.model.dataset_id, orig_version
690
+ )
691
+ else:
692
+ # Note: config will be updated if this is a pivoted table.
693
+ ld_df, tmp = self._read_table_from_user_path(
694
+ config, scratch_dir_context=scratch_dir_context
695
+ )
696
+ if tmp is not None:
697
+ msg = (
698
+ "NULL rows cannot be present in the load_data table in standard format. "
699
+ "They must be provided in the load_data_lookup table."
700
+ )
701
+ raise DSGInvalidDataset(msg)
702
+ lk_df, missing_df1 = self._read_lookup_table_from_user_path(
703
+ config, scratch_dir_context=scratch_dir_context
704
+ )
705
+ missing_dfs2 = self._read_missing_associations_tables_from_user_path(config)
706
+ missing_dfs.update(
707
+ self._check_duplicate_missing_associations(missing_df1, missing_dfs2)
708
+ )
709
+ case _:
710
+ msg = f"Unsupported table format: {config.get_table_format()}"
711
+ raise Exception(msg)
712
+
713
+ self._store.write_table(ld_df, config.model.dataset_id, config.model.version)
714
+ if lk_df is not None:
715
+ self._store.write_lookup_table(lk_df, config.model.dataset_id, config.model.version)
716
+ if missing_dfs:
717
+ self._store.write_missing_associations_tables(
718
+ missing_dfs, config.model.dataset_id, config.model.version
719
+ )
720
+
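+ # Editorial summary of _write_to_registry above: configs without a user layout copy the
+ # existing tables forward from orig_version, while user-supplied layouts are read from
+ # disk, unpivoted if needed, split into load_data / lookup / missing-association tables,
+ # and written to the store under the new version.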
721
+ @staticmethod
722
+ def _check_duplicate_missing_associations(
723
+ df1: DataFrame | None, dfs2: dict[str, DataFrame]
724
+ ) -> dict[str, DataFrame]:
725
+ if df1 is not None and dfs2:
726
+ msg = "A dataset cannot have expected missing rows in the data and "
727
+ "provide a missing_associations file. Provide one or the other."
728
+ raise DSGInvalidDataset(msg)
729
+
730
+ if df1 is not None:
731
+ return {"missing_associations": df1}
732
+ elif dfs2:
733
+ return dfs2
734
+ return {}
735
+
736
+ def _copy_dataset_config(self, conn: Connection, config: DatasetConfig) -> DatasetConfig:
737
+ new_config = DatasetConfig(config.model)
738
+ self._update_dimensions(conn, new_config)
739
+ return new_config
740
+
741
+ def update_from_file(
742
+ self,
743
+ config_file: Path,
744
+ dataset_id: str,
745
+ submitter: str,
746
+ update_type: VersionUpdateType,
747
+ log_message: str,
748
+ version: str,
749
+ ):
750
+ with RegistrationContext(self.db, log_message, update_type, submitter) as context:
751
+ conn = context.connection
752
+ # If config has UserDataLayout, load with validation; otherwise just load
753
+ config = DatasetConfig.load(config_file)
754
+ if config.has_user_layout:
755
+ config = DatasetConfig.load_from_user_path(config_file)
756
+ self._update_dimensions(conn, config)
757
+ self._check_update(conn, config, dataset_id, version)
758
+ self.update_with_context(
759
+ config,
760
+ context,
761
+ )
762
+
763
+ @track_timing(timer_stats_collector)
764
+ def update(
765
+ self,
766
+ config: DatasetConfig,
767
+ update_type: VersionUpdateType,
768
+ log_message: str,
769
+ submitter: str | None = None,
770
+ ) -> DatasetConfig:
771
+ lock_file_path = self.get_registry_lock_file(config.model.dataset_id)
772
+ with self.cloud_interface.make_lock_file_managed(lock_file_path):
773
+ # Note that projects will not pick up these changes until submit-dataset
774
+ # is called again.
775
+ with RegistrationContext(self.db, log_message, update_type, submitter) as context:
776
+ return self.update_with_context(
777
+ config,
778
+ context,
779
+ )
780
+
781
+ def update_with_context(
782
+ self,
783
+ config: DatasetConfig,
784
+ context: RegistrationContext,
785
+ requirements: DatasetDimensionRequirements | None = None,
786
+ ) -> DatasetConfig:
787
+ reqs = requirements or DatasetDimensionRequirements()
788
+ conn = context.connection
789
+ dataset_id = config.model.dataset_id
790
+ cur_config = self.get_by_id(dataset_id, conn=conn)
791
+ updated_model = self._update_config(config, context)
792
+ updated_config = DatasetConfig(updated_model)
793
+ self._update_dimensions(conn, updated_config)
794
+
795
+ dsgrid_config = DsgridRuntimeConfig.load()
796
+ scratch_dir = dsgrid_config.get_scratch_dir()
797
+ scratch_dir.mkdir(parents=True, exist_ok=True)
798
+ with ScratchDirContext(scratch_dir) as scratch_dir_context:
799
+ # Note: this method mutates updated_config.
800
+ self._write_to_registry(
801
+ updated_config,
802
+ orig_version=cur_config.model.version,
803
+ scratch_dir_context=scratch_dir_context,
804
+ )
805
+
806
+ assoc_df = self._store.read_missing_associations_tables(
807
+ updated_config.model.dataset_id, updated_config.model.version
808
+ )
809
+ try:
810
+ self._run_checks(conn, updated_config, assoc_df, scratch_dir_context, reqs)
811
+ except Exception:
812
+ self._store.remove_tables(
813
+ updated_config.model.dataset_id, updated_config.model.version
814
+ )
815
+ raise
816
+
817
+ old_key = ConfigKey(dataset_id, cur_config.model.version)
818
+ new_key = ConfigKey(dataset_id, updated_config.model.version)
819
+ self._datasets.pop(old_key, None)
820
+ self._datasets[new_key] = updated_config
821
+
822
+ if not self.offline_mode:
823
+ self.sync_push(self.get_registry_data_directory(dataset_id))
824
+
825
+ return updated_config
826
+
827
+ def remove(self, config_id: str, conn: Connection | None = None):
828
+ for key in [x for x in self._datasets if x.id == config_id]:
829
+ self.remove_data(config_id, key.version)
830
+ self._datasets.pop(key)
831
+
832
+ self.db.delete_all(conn, config_id)
833
+ logger.info("Removed %s from the registry.", config_id)
834
+
835
+ def remove_data(self, dataset_id: str, version: str):
836
+ self._store.remove_tables(dataset_id, version)
837
+ logger.info("Removed data for %s version=%s from the registry.", dataset_id, version)
838
+
839
+ def show(
840
+ self,
841
+ conn: Connection | None = None,
842
+ filters: list[str] | None = None,
843
+ max_width: Union[int, dict] | None = None,
844
+ drop_fields: list[str] | None = None,
845
+ return_table: bool = False,
846
+ **kwargs: Any,
847
+ ):
848
+ """Show registry in PrettyTable
849
+
850
+ Parameters
851
+ ----------
852
+ filters : list or tuple
853
+ List of filter expressions for registry content (e.g., filters=["Submitter==USER", "Description contains comstock"])
854
+ max_width
855
+ Max column width in PrettyTable, specify as a single value or as a dict of values by field name
856
+ drop_fields
857
+ List of field names not to show
858
+
859
+ """
860
+
861
+ if filters:
862
+ logger.info("List registry for: %s", filters)
863
+
864
+ table = PrettyTable(title=self.name())
865
+ all_field_names = (
866
+ "ID",
867
+ "Version",
868
+ "Date",
869
+ "Submitter",
870
+ "Description",
871
+ )
872
+ if drop_fields is None:
873
+ table.field_names = all_field_names
874
+ else:
875
+ table.field_names = tuple(x for x in all_field_names if x not in drop_fields)
876
+
877
+ if max_width is None:
878
+ table._max_width = {
879
+ "ID": 50,
880
+ "Date": 10,
881
+ "Description": 50,
882
+ }
883
+ if isinstance(max_width, int):
884
+ table.max_width = max_width
885
+ elif isinstance(max_width, dict):
886
+ table._max_width = max_width
887
+
888
+ if filters:
889
+ transformed_filters = transform_and_validate_filters(filters)
890
+ field_to_index = {x: i for i, x in enumerate(table.field_names)}
891
+ rows = []
892
+ for model in self.db.iter_models(conn, all_versions=True):
893
+ assert isinstance(model, DatasetConfigModel)
894
+ registration = self.db.get_registration(conn, model)
895
+ all_fields = (
896
+ model.dataset_id,
897
+ model.version,
898
+ registration.timestamp.strftime("%Y-%m-%d %H:%M:%S"),
899
+ registration.submitter,
900
+ registration.log_message,
901
+ )
902
+ if drop_fields is None:
903
+ row = all_fields
904
+ else:
905
+ row = tuple(
906
+ y for (x, y) in zip(all_field_names, all_fields) if x not in drop_fields
907
+ )
908
+
909
+ if not filters or matches_filters(row, field_to_index, transformed_filters):
910
+ rows.append(row)
911
+
912
+ rows.sort(key=lambda x: x[0])
913
+ table.add_rows(rows)
914
+ table.align = "l"
915
+ if return_table:
916
+ return table
917
+ display_table(table)
918
+
919
+ def sync_pull(self, path):
920
+ """Synchronizes files from the remote registry to local.
921
+ Deletes any files locally that do not exist on remote.
922
+
923
+ path : Path
924
+ Local path
925
+
926
+ """
927
+ remote_path = self.relative_remote_path(path)
928
+ self.cloud_interface.sync_pull(
929
+ remote_path, path, exclude=SYNC_EXCLUDE_LIST, delete_local=True
930
+ )
931
+
932
+ def sync_push(self, path):
933
+ """Synchronizes files from the local path to the remote registry.
934
+
935
+ path : Path
936
+ Local path
937
+
938
+ """
939
+ remote_path = self.relative_remote_path(path)
940
+ lock_file_path = self.get_registry_lock_file(path.name)
941
+ self.cloud_interface.check_lock_file(lock_file_path)
942
+ try:
943
+ self.cloud_interface.sync_push(
944
+ remote_path=remote_path, local_path=path, exclude=SYNC_EXCLUDE_LIST
945
+ )
946
+ except Exception:
947
+ logger.exception(
948
+ "Please report this error to the dsgrid team. The registry may need recovery."
949
+ )
950
+ raise