dsgrid-toolkit 0.2.0 (dsgrid_toolkit-0.2.0-py3-none-any.whl)

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.

Potentially problematic release: this version of dsgrid-toolkit might be problematic.

Files changed (152)
  1. dsgrid/__init__.py +22 -0
  2. dsgrid/api/__init__.py +0 -0
  3. dsgrid/api/api_manager.py +179 -0
  4. dsgrid/api/app.py +420 -0
  5. dsgrid/api/models.py +60 -0
  6. dsgrid/api/response_models.py +116 -0
  7. dsgrid/apps/__init__.py +0 -0
  8. dsgrid/apps/project_viewer/app.py +216 -0
  9. dsgrid/apps/registration_gui.py +444 -0
  10. dsgrid/chronify.py +22 -0
  11. dsgrid/cli/__init__.py +0 -0
  12. dsgrid/cli/common.py +120 -0
  13. dsgrid/cli/config.py +177 -0
  14. dsgrid/cli/download.py +13 -0
  15. dsgrid/cli/dsgrid.py +142 -0
  16. dsgrid/cli/dsgrid_admin.py +349 -0
  17. dsgrid/cli/install_notebooks.py +62 -0
  18. dsgrid/cli/query.py +711 -0
  19. dsgrid/cli/registry.py +1773 -0
  20. dsgrid/cloud/__init__.py +0 -0
  21. dsgrid/cloud/cloud_storage_interface.py +140 -0
  22. dsgrid/cloud/factory.py +31 -0
  23. dsgrid/cloud/fake_storage_interface.py +37 -0
  24. dsgrid/cloud/s3_storage_interface.py +156 -0
  25. dsgrid/common.py +35 -0
  26. dsgrid/config/__init__.py +0 -0
  27. dsgrid/config/annual_time_dimension_config.py +187 -0
  28. dsgrid/config/common.py +131 -0
  29. dsgrid/config/config_base.py +148 -0
  30. dsgrid/config/dataset_config.py +684 -0
  31. dsgrid/config/dataset_schema_handler_factory.py +41 -0
  32. dsgrid/config/date_time_dimension_config.py +108 -0
  33. dsgrid/config/dimension_config.py +54 -0
  34. dsgrid/config/dimension_config_factory.py +65 -0
  35. dsgrid/config/dimension_mapping_base.py +349 -0
  36. dsgrid/config/dimension_mappings_config.py +48 -0
  37. dsgrid/config/dimensions.py +775 -0
  38. dsgrid/config/dimensions_config.py +71 -0
  39. dsgrid/config/index_time_dimension_config.py +76 -0
  40. dsgrid/config/input_dataset_requirements.py +31 -0
  41. dsgrid/config/mapping_tables.py +209 -0
  42. dsgrid/config/noop_time_dimension_config.py +42 -0
  43. dsgrid/config/project_config.py +1457 -0
  44. dsgrid/config/registration_models.py +199 -0
  45. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  46. dsgrid/config/simple_models.py +49 -0
  47. dsgrid/config/supplemental_dimension.py +29 -0
  48. dsgrid/config/time_dimension_base_config.py +200 -0
  49. dsgrid/data_models.py +155 -0
  50. dsgrid/dataset/__init__.py +0 -0
  51. dsgrid/dataset/dataset.py +123 -0
  52. dsgrid/dataset/dataset_expression_handler.py +86 -0
  53. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  54. dsgrid/dataset/dataset_schema_handler_base.py +899 -0
  55. dsgrid/dataset/dataset_schema_handler_one_table.py +196 -0
  56. dsgrid/dataset/dataset_schema_handler_standard.py +303 -0
  57. dsgrid/dataset/growth_rates.py +162 -0
  58. dsgrid/dataset/models.py +44 -0
  59. dsgrid/dataset/table_format_handler_base.py +257 -0
  60. dsgrid/dataset/table_format_handler_factory.py +17 -0
  61. dsgrid/dataset/unpivoted_table.py +121 -0
  62. dsgrid/dimension/__init__.py +0 -0
  63. dsgrid/dimension/base_models.py +218 -0
  64. dsgrid/dimension/dimension_filters.py +308 -0
  65. dsgrid/dimension/standard.py +213 -0
  66. dsgrid/dimension/time.py +531 -0
  67. dsgrid/dimension/time_utils.py +88 -0
  68. dsgrid/dsgrid_rc.py +88 -0
  69. dsgrid/exceptions.py +105 -0
  70. dsgrid/filesystem/__init__.py +0 -0
  71. dsgrid/filesystem/cloud_filesystem.py +32 -0
  72. dsgrid/filesystem/factory.py +32 -0
  73. dsgrid/filesystem/filesystem_interface.py +136 -0
  74. dsgrid/filesystem/local_filesystem.py +74 -0
  75. dsgrid/filesystem/s3_filesystem.py +118 -0
  76. dsgrid/loggers.py +132 -0
  77. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +950 -0
  78. dsgrid/notebooks/registration.ipynb +48 -0
  79. dsgrid/notebooks/start_notebook.sh +11 -0
  80. dsgrid/project.py +451 -0
  81. dsgrid/query/__init__.py +0 -0
  82. dsgrid/query/dataset_mapping_plan.py +142 -0
  83. dsgrid/query/derived_dataset.py +384 -0
  84. dsgrid/query/models.py +726 -0
  85. dsgrid/query/query_context.py +287 -0
  86. dsgrid/query/query_submitter.py +847 -0
  87. dsgrid/query/report_factory.py +19 -0
  88. dsgrid/query/report_peak_load.py +70 -0
  89. dsgrid/query/reports_base.py +20 -0
  90. dsgrid/registry/__init__.py +0 -0
  91. dsgrid/registry/bulk_register.py +161 -0
  92. dsgrid/registry/common.py +287 -0
  93. dsgrid/registry/config_update_checker_base.py +63 -0
  94. dsgrid/registry/data_store_factory.py +34 -0
  95. dsgrid/registry/data_store_interface.py +69 -0
  96. dsgrid/registry/dataset_config_generator.py +156 -0
  97. dsgrid/registry/dataset_registry_manager.py +734 -0
  98. dsgrid/registry/dataset_update_checker.py +16 -0
  99. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  100. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  101. dsgrid/registry/dimension_registry_manager.py +413 -0
  102. dsgrid/registry/dimension_update_checker.py +16 -0
  103. dsgrid/registry/duckdb_data_store.py +185 -0
  104. dsgrid/registry/filesystem_data_store.py +141 -0
  105. dsgrid/registry/filter_registry_manager.py +123 -0
  106. dsgrid/registry/project_config_generator.py +57 -0
  107. dsgrid/registry/project_registry_manager.py +1616 -0
  108. dsgrid/registry/project_update_checker.py +48 -0
  109. dsgrid/registry/registration_context.py +223 -0
  110. dsgrid/registry/registry_auto_updater.py +316 -0
  111. dsgrid/registry/registry_database.py +662 -0
  112. dsgrid/registry/registry_interface.py +446 -0
  113. dsgrid/registry/registry_manager.py +544 -0
  114. dsgrid/registry/registry_manager_base.py +367 -0
  115. dsgrid/registry/versioning.py +92 -0
  116. dsgrid/spark/__init__.py +0 -0
  117. dsgrid/spark/functions.py +545 -0
  118. dsgrid/spark/types.py +50 -0
  119. dsgrid/tests/__init__.py +0 -0
  120. dsgrid/tests/common.py +139 -0
  121. dsgrid/tests/make_us_data_registry.py +204 -0
  122. dsgrid/tests/register_derived_datasets.py +103 -0
  123. dsgrid/tests/utils.py +25 -0
  124. dsgrid/time/__init__.py +0 -0
  125. dsgrid/time/time_conversions.py +80 -0
  126. dsgrid/time/types.py +67 -0
  127. dsgrid/units/__init__.py +0 -0
  128. dsgrid/units/constants.py +113 -0
  129. dsgrid/units/convert.py +71 -0
  130. dsgrid/units/energy.py +145 -0
  131. dsgrid/units/power.py +87 -0
  132. dsgrid/utils/__init__.py +0 -0
  133. dsgrid/utils/dataset.py +612 -0
  134. dsgrid/utils/files.py +179 -0
  135. dsgrid/utils/filters.py +125 -0
  136. dsgrid/utils/id_remappings.py +100 -0
  137. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  138. dsgrid/utils/py_expression_eval/README.md +8 -0
  139. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  140. dsgrid/utils/py_expression_eval/tests.py +283 -0
  141. dsgrid/utils/run_command.py +70 -0
  142. dsgrid/utils/scratch_dir_context.py +64 -0
  143. dsgrid/utils/spark.py +918 -0
  144. dsgrid/utils/spark_partition.py +98 -0
  145. dsgrid/utils/timing.py +239 -0
  146. dsgrid/utils/utilities.py +184 -0
  147. dsgrid/utils/versioning.py +36 -0
  148. dsgrid_toolkit-0.2.0.dist-info/METADATA +216 -0
  149. dsgrid_toolkit-0.2.0.dist-info/RECORD +152 -0
  150. dsgrid_toolkit-0.2.0.dist-info/WHEEL +4 -0
  151. dsgrid_toolkit-0.2.0.dist-info/entry_points.txt +4 -0
  152. dsgrid_toolkit-0.2.0.dist-info/licenses/LICENSE +29 -0
dsgrid/registry/dataset_registry_manager.py (new file)
@@ -0,0 +1,734 @@
"""Manages the registry for dimension datasets"""

import logging
import os
from pathlib import Path
from typing import Any, Self, Type, Union

import pandas as pd
from prettytable import PrettyTable
from sqlalchemy import Connection

from dsgrid.common import SCALING_FACTOR_COLUMN, SYNC_EXCLUDE_LIST
from dsgrid.config.dataset_config import (
    DatasetConfig,
    ALLOWED_LOAD_DATA_FILENAMES,
    ALLOWED_LOAD_DATA_LOOKUP_FILENAMES,
    ALLOWED_MISSING_DIMENSION_ASSOCATIONS_FILENAMES,
    DatasetConfigModel,
)
from dsgrid.config.dataset_config import DataSchemaType
from dsgrid.config.dataset_schema_handler_factory import make_dataset_schema_handler
from dsgrid.config.dimensions_config import DimensionsConfig, DimensionsConfigModel
from dsgrid.dataset.models import TableFormatType, UnpivotedTableFormatModel
from dsgrid.dimension.base_models import (
    check_required_dataset_dimensions,
)
from dsgrid.exceptions import DSGInvalidDataset
from dsgrid.registry.dimension_registry_manager import DimensionRegistryManager
from dsgrid.registry.dimension_mapping_registry_manager import (
    DimensionMappingRegistryManager,
)
from dsgrid.registry.data_store_interface import DataStoreInterface
from dsgrid.registry.registry_interface import DatasetRegistryInterface
from dsgrid.spark.functions import (
    get_spark_session,
    is_dataframe_empty,
)
from dsgrid.spark.types import DataFrame, F, StringType
from dsgrid.utils.dataset import split_expected_missing_rows, unpivot_dataframe
from dsgrid.utils.spark import (
    read_dataframe,
)
from dsgrid.utils.timing import timer_stats_collector, track_timing
from dsgrid.utils.filters import transform_and_validate_filters, matches_filters
from dsgrid.utils.utilities import check_uniqueness, display_table
from .common import (
    VersionUpdateType,
    ConfigKey,
    RegistryType,
)
from .registration_context import RegistrationContext
from .registry_manager_base import RegistryManagerBase

logger = logging.getLogger(__name__)

class DatasetRegistryManager(RegistryManagerBase):
    """Manages registered dimension datasets."""

    def __init__(
        self,
        path,
        fs_interface,
        dimension_manager: DimensionRegistryManager,
        dimension_mapping_manager: DimensionMappingRegistryManager,
        db: DatasetRegistryInterface,
        store: DataStoreInterface,
    ):
        super().__init__(path, fs_interface)
        self._datasets: dict[ConfigKey, DatasetConfig] = {}
        self._dimension_mgr = dimension_manager
        self._dimension_mapping_mgr = dimension_mapping_manager
        self._db = db
        self._store = store

    @classmethod
    def load(
        cls,
        path: Path,
        params,
        dimension_manager: DimensionRegistryManager,
        dimension_mapping_manager: DimensionMappingRegistryManager,
        db: DatasetRegistryInterface,
        store: DataStoreInterface,
    ) -> Self:
        return cls._load(path, params, dimension_manager, dimension_mapping_manager, db, store)

    @staticmethod
    def config_class() -> Type:
        return DatasetConfig

    @property
    def db(self) -> DatasetRegistryInterface:
        return self._db

    @property
    def store(self) -> DataStoreInterface:
        return self._store

    @staticmethod
    def name() -> str:
        return "Datasets"

    def _get_registry_data_path(self):
        if self._params.use_remote_data:
            dataset_path = self._params.remote_path
        else:
            dataset_path = str(self._params.base_path)
        return dataset_path

    @track_timing(timer_stats_collector)
    def _run_checks(
        self,
        conn: Connection,
        config: DatasetConfig,
        missing_dimension_associations: DataFrame | None,
    ) -> None:
        logger.info("Run dataset registration checks.")
        check_required_dataset_dimensions(config.model.dimension_references, "dataset dimensions")
        check_uniqueness((x.model.name for x in config.model.dimensions), "dimension name")
        if not os.environ.get("__DSGRID_SKIP_CHECK_DATASET_CONSISTENCY__"):
            self._check_dataset_consistency(
                conn,
                config,
                missing_dimension_associations,
            )

    def _check_dataset_consistency(
        self,
        conn: Connection,
        config: DatasetConfig,
        missing_dimension_associations: DataFrame | None,
    ) -> None:
        schema_handler = make_dataset_schema_handler(
            conn,
            config,
            self._dimension_mgr,
            self._dimension_mapping_mgr,
            store=self._store,
        )
        schema_handler.check_consistency(missing_dimension_associations)

    @property
    def dimension_manager(self) -> DimensionRegistryManager:
        return self._dimension_mgr

    @property
    def dimension_mapping_manager(self) -> DimensionMappingRegistryManager:
        return self._dimension_mapping_mgr

    def finalize_registration(self, conn: Connection, config_ids: set[str], error_occurred: bool):
        assert len(config_ids) == 1, config_ids
        if error_occurred:
            for dataset_id in config_ids:
                logger.info("Remove intermediate dataset after error")
                self.remove_data(dataset_id, "1.0.0")
            for key in [x for x in self._datasets if x.id in config_ids]:
                self._datasets.pop(key)

        if not self.offline_mode:
            for dataset_id in config_ids:
                lock_file = self.get_registry_lock_file(dataset_id)
                self.cloud_interface.check_lock_file(lock_file)
                if not error_occurred:
                    self.sync_push(self.get_registry_data_directory(dataset_id))
                self.cloud_interface.remove_lock_file(lock_file)

    def get_by_id(
        self, config_id: str, version: str | None = None, conn: Connection | None = None
    ) -> DatasetConfig:
        if version is None:
            version = self._db.get_latest_version(conn, config_id)

        key = ConfigKey(config_id, version)
        dataset = self._datasets.get(key)
        if dataset is not None:
            return dataset

        if version is None:
            model = self.db.get_latest(conn, config_id)
        else:
            model = self.db.get_by_version(conn, config_id, version)

        config = DatasetConfig(model)
        self._update_dimensions(conn, config)
        self._datasets[key] = config
        return config

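For orientation, a minimal usage sketch of the lookup above; `manager` is assumed to be an already-loaded DatasetRegistryManager, and the dataset ID and version are hypothetical:

# Hypothetical IDs; get_by_id caches the loaded config under a ConfigKey(id, version).
config = manager.get_by_id("comstock_reference")                   # latest version
pinned = manager.get_by_id("comstock_reference", version="1.0.0")  # specific version
print(config.model.dataset_id, config.model.version)
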
    def acquire_registry_locks(self, config_ids: list[str]):
        """Acquire lock(s) on the registry for all config_ids.

        Parameters
        ----------
        config_ids : list[str]

        Raises
        ------
        DSGRegistryLockError
            Raised if a lock cannot be acquired.

        """
        for dataset_id in config_ids:
            lock_file = self.get_registry_lock_file(dataset_id)
            self.cloud_interface.make_lock_file(lock_file)

    def get_registry_lock_file(self, config_id: str):
        """Return registry lock file path.

        Parameters
        ----------
        config_id : str
            Config ID

        Returns
        -------
        str
            Lock file path
        """
        return f"configs/.locks/{config_id}.lock"

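A small sketch of the locking convention implied by the two methods above; the dataset ID is hypothetical, and `cloud_interface` is provided by the base class:

# Lock files live under configs/.locks/ and are keyed by config ID.
lock_path = manager.get_registry_lock_file("comstock_reference")
# -> "configs/.locks/comstock_reference.lock"
manager.acquire_registry_locks(["comstock_reference"])  # creates the lock file via cloud_interface
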
    def _update_dimensions(self, conn: Connection | None, config: DatasetConfig):
        dimensions = self._dimension_mgr.load_dimensions(
            config.model.dimension_references, conn=conn
        )
        config.update_dimensions(dimensions)

    def register(
        self,
        config_file: Path,
        dataset_path: Path,
        submitter: str | None = None,
        log_message: str | None = None,
        context: RegistrationContext | None = None,
    ):
        config = DatasetConfig.load_from_user_path(config_file, dataset_path)
        if context is None:
            assert submitter is not None
            assert log_message is not None
            with RegistrationContext(
                self.db, log_message, VersionUpdateType.MAJOR, submitter
            ) as context:
                return self.register_from_config(
                    config,
                    dataset_path,
                    context,
                )
        else:
            return self.register_from_config(
                config,
                dataset_path,
                context,
            )

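A hedged sketch of registering a dataset through `register()`; the paths, submitter, and log message below are placeholders, not values from this package:

from pathlib import Path

manager.register(
    config_file=Path("my_dataset/dataset.json5"),  # hypothetical config file
    dataset_path=Path("my_dataset"),               # hypothetical data directory
    submitter="username",
    log_message="Initial registration of my_dataset",
)
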
    @track_timing(timer_stats_collector)
    def register_from_config(
        self,
        config: DatasetConfig,
        dataset_path: Path,
        context: RegistrationContext,
    ):
        self._update_dimensions(context.connection, config)
        self._register_dataset_and_dimensions(
            config,
            dataset_path,
            context,
        )

    def _register_dataset_and_dimensions(
        self,
        config: DatasetConfig,
        dataset_path: Path,
        context: RegistrationContext,
    ):
        logger.info("Start registration of dataset %s", config.model.dataset_id)
        # TODO S3: This requires downloading data to the local system.
        # Can we perform all validation on S3 with an EC2 instance?
        if str(dataset_path).startswith("s3://"):
            msg = f"Loading a dataset from S3 is not currently supported: {dataset_path}"
            raise DSGInvalidDataset(msg)

        conn = context.connection
        self._check_if_already_registered(conn, config.model.dataset_id)

        if config.model.dimensions:
            dim_model = DimensionsConfigModel(dimensions=config.model.dimensions)
            dims_config = DimensionsConfig.load_from_model(dim_model)
            dimension_ids = self._dimension_mgr.register_from_config(dims_config, context=context)
            config.model.dimension_references += self._dimension_mgr.make_dimension_references(
                conn, dimension_ids
            )
            config.model.dimensions.clear()

        self._update_dimensions(conn, config)
        self._register(
            config,
            dataset_path,
            context,
        )
        context.add_id(RegistryType.DATASET, config.model.dataset_id, self)

    def _register(
        self,
        config: DatasetConfig,
        dataset_path: Path,
        context: RegistrationContext,
    ):
        config.model.version = "1.0.0"
        # Explanation for this order of operations:
        # 1. Check time consistency in the original dataset format.
        #    Many datasets are stored in pivoted format and have many value columns. If we
        #    check timestamps after unpivoting the dataset, we will multiply the required work
        #    by the number of columns.
        # 2. Write to the registry in unpivoted format before running the other checks.
        #    The final data is always stored in unpivoted format. We can reduce code if we
        #    transform pivoted tables first.
        # In the nominal case where the dataset is valid, there is no difference in performance.
        # In the failure case where the dataset is invalid, it will take longer to detect the
        # errors.
        self._check_time_consistency(config, context)
        self._write_to_registry(config)

        assoc_df = self._store.read_missing_associations_table(
            config.model.dataset_id, config.model.version
        )
        try:
            self._run_checks(
                context.connection,
                config,
                assoc_df,
            )
        except Exception:
            self._store.remove_tables(config.model.dataset_id, config.model.version)
            raise

        self._db.insert(context.connection, config.model, context.registration)
        logger.info(
            "%s Registered dataset %s with version=%s",
            self._log_offline_mode_prefix(),
            config.model.dataset_id,
            config.model.version,
        )

    def _check_time_consistency(
        self,
        config: DatasetConfig,
        context: RegistrationContext,
    ) -> None:
        schema_handler = make_dataset_schema_handler(
            context.connection,
            config,
            self._dimension_mgr,
            self._dimension_mapping_mgr,
            store=None,
        )
        schema_handler.check_time_consistency()

    def _read_lookup_table_from_user_path(self, path: Path) -> tuple[DataFrame, DataFrame | None]:
        for filename in ALLOWED_LOAD_DATA_LOOKUP_FILENAMES:
            lk_path = path / filename
            if lk_path.exists():
                df = read_dataframe(lk_path)
                if "id" not in df.columns:
                    msg = "load_data_lookup does not include an 'id' column"
                    raise DSGInvalidDataset(msg)
                missing = df.filter("id IS NULL").drop("id")
                if is_dataframe_empty(missing):
                    missing_df = None
                else:
                    missing_df = missing
                    if SCALING_FACTOR_COLUMN in missing_df.columns:
                        missing_df = missing_df.drop(SCALING_FACTOR_COLUMN)
                return df, missing_df

        msg = (
            f"Did not find any lookup data files in {path}. "
            f"Expected one of {ALLOWED_LOAD_DATA_LOOKUP_FILENAMES}"
        )
        raise DSGInvalidDataset(msg)

    def _read_missing_associations_table_from_user_path(
        self, dataset_path: Path
    ) -> DataFrame | None:
        for filename in ALLOWED_MISSING_DIMENSION_ASSOCATIONS_FILENAMES:
            path = dataset_path / filename
            if path.exists():
                if path.suffix.lower() == ".csv":
                    df = get_spark_session().createDataFrame(pd.read_csv(path, dtype="string"))
                else:
                    df = read_dataframe(path)
                for field in df.schema.fields:
                    if field.dataType != StringType():
                        df = df.withColumn(field.name, F.col(field.name).cast(StringType()))
                return df
        return None

    def _read_table_from_user_path(
        self, config: DatasetConfig, path: Path
    ) -> tuple[DataFrame, DataFrame | None]:
        ld_path: Path | None = None
        for filename in ALLOWED_LOAD_DATA_FILENAMES:
            tmp = path / filename
            if tmp.exists():
                ld_path = tmp
                break
        if ld_path is None:
            msg = f"Did not find any load data files in {path}. Expected one of {ALLOWED_LOAD_DATA_FILENAMES}"
            raise DSGInvalidDataset(msg)

        if config.get_table_format_type() == TableFormatType.PIVOTED:
            logger.info("Convert dataset %s from pivoted to unpivoted.", config.model.dataset_id)
            needs_unpivot = True
            pivoted_columns = config.get_pivoted_dimension_columns()
            pivoted_dimension_type = config.get_pivoted_dimension_type()
            config.model.data_schema.table_format = UnpivotedTableFormatModel()
        else:
            needs_unpivot = False
            pivoted_columns = None
            pivoted_dimension_type = None

        df = read_dataframe(ld_path)

        time_dim = config.get_time_dimension()
        time_columns: list[str] = []
        if time_dim is not None:
            time_columns.extend(time_dim.get_load_data_time_columns())
            df = time_dim.convert_time_format(df, update_model=True)

        if needs_unpivot:
            assert pivoted_columns is not None
            assert pivoted_dimension_type is not None
            existing_columns = set(df.columns)
            if diff := set(time_columns) - existing_columns:
                msg = f"Expected time columns are not present in the table: {diff=}"
                raise DSGInvalidDataset(msg)
            if diff := set(pivoted_columns) - existing_columns:
                msg = f"Expected pivoted_columns are not present in the table: {diff=}"
                raise DSGInvalidDataset(msg)
            df = unpivot_dataframe(df, pivoted_columns, pivoted_dimension_type.value, time_columns)

        return split_expected_missing_rows(df, time_columns)

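For readers unfamiliar with the pivoted-to-unpivoted conversion performed by `_read_table_from_user_path`, here is a generic illustration of the same idea using pandas.melt; the column names are hypothetical and this does not reproduce dsgrid's own `unpivot_dataframe` helper:

import pandas as pd

# Pivoted layout: one value column per end use (hypothetical columns).
pivoted = pd.DataFrame(
    {
        "timestamp": ["2020-01-01 00:00", "2020-01-01 01:00"],
        "cooling": [1.0, 1.2],
        "heating": [0.5, 0.4],
    }
)
# Unpivoted layout: one row per (timestamp, end_use) pair with a single value column.
unpivoted = pivoted.melt(id_vars=["timestamp"], var_name="end_use", value_name="value")
print(unpivoted)
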
    def _write_to_registry(
        self,
        config: DatasetConfig,
        orig_version: str | None = None,
    ) -> None:
        lk_df: DataFrame | None = None
        missing_df: DataFrame | None = None
        match config.get_data_schema_type():
            case DataSchemaType.ONE_TABLE:
                if config.dataset_path is None:
                    assert (
                        orig_version is not None
                    ), "orig_version must be set if dataset_path is None"
                    missing_df = self._store.read_missing_associations_table(
                        config.model.dataset_id, orig_version
                    )
                    ld_df = self._store.read_table(config.model.dataset_id, orig_version)
                else:
                    # Note: config will be updated if this is a pivoted table.
                    ld_df, missing_df1 = self._read_table_from_user_path(
                        config, config.dataset_path
                    )
                    missing_df2 = self._read_missing_associations_table_from_user_path(
                        config.dataset_path
                    )
                    missing_df = self._check_duplicate_missing_associations(
                        missing_df1, missing_df2
                    )

            case DataSchemaType.STANDARD:
                if config.dataset_path is None:
                    assert (
                        orig_version is not None
                    ), "orig_version must be set if dataset_path is None"
                    lk_df = self._store.read_lookup_table(config.model.dataset_id, orig_version)
                    ld_df = self._store.read_table(config.model.dataset_id, orig_version)
                    missing_df = self._store.read_missing_associations_table(
                        config.model.dataset_id, orig_version
                    )
                else:
                    # Note: config will be updated if this is a pivoted table.
                    ld_df, tmp = self._read_table_from_user_path(config, config.dataset_path)
                    if tmp is not None:
                        msg = (
                            "NULL rows cannot be present in the load_data table in standard format. "
                            "They must be provided in the load_data_lookup table."
                        )
                        raise DSGInvalidDataset(msg)
                    lk_df, missing_df1 = self._read_lookup_table_from_user_path(
                        Path(config.dataset_path)
                    )
                    missing_df2 = self._read_missing_associations_table_from_user_path(
                        config.dataset_path
                    )
                    missing_df = self._check_duplicate_missing_associations(
                        missing_df1, missing_df2
                    )
            case _:
                msg = f"Unsupported data schema type: {config.get_data_schema_type()}"
                raise Exception(msg)

        self._store.write_table(ld_df, config.model.dataset_id, config.model.version)
        if lk_df is not None:
            self._store.write_lookup_table(lk_df, config.model.dataset_id, config.model.version)
        if missing_df is not None:
            self._store.write_missing_associations_table(
                missing_df, config.model.dataset_id, config.model.version
            )

    @staticmethod
    def _check_duplicate_missing_associations(
        df1: DataFrame | None, df2: DataFrame | None
    ) -> DataFrame | None:
        if df1 is not None and df2 is not None:
            msg = (
                "A dataset cannot have expected missing rows in the data and "
                "provide a missing_associations file. Provide one or the other."
            )
            raise DSGInvalidDataset(msg)
        return df1 or df2

    def _copy_dataset_config(self, conn: Connection, config: DatasetConfig) -> DatasetConfig:
        new_config = DatasetConfig(config.model)
        assert config.dataset_path is not None
        new_config.dataset_path = config.dataset_path
        self._update_dimensions(conn, new_config)
        return new_config

    def update_from_file(
        self,
        config_file: Path,
        dataset_id: str,
        submitter: str,
        update_type: VersionUpdateType,
        log_message: str,
        version: str,
        dataset_path: Path | None = None,
    ):
        with RegistrationContext(self.db, log_message, update_type, submitter) as context:
            conn = context.connection
            path = (
                self._params.base_path / "data" / dataset_id / version
                if dataset_path is None
                else dataset_path
            )
            if dataset_path is None:
                config = DatasetConfig.load(config_file)
            else:
                config = DatasetConfig.load_from_user_path(config_file, path)
            self._update_dimensions(conn, config)
            self._check_update(conn, config, dataset_id, version)
            self.update_with_context(
                config,
                context,
            )

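A sketch of submitting an update through `update_from_file()`; every argument below is a placeholder, and VersionUpdateType.MAJOR is the only member referenced elsewhere in this file:

from pathlib import Path

from dsgrid.registry.common import VersionUpdateType

manager.update_from_file(
    config_file=Path("my_dataset/dataset.json5"),  # hypothetical config file
    dataset_id="my_dataset",
    submitter="username",
    update_type=VersionUpdateType.MAJOR,
    log_message="Update dimension records",
    version="1.0.0",                               # currently registered version
)
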
    @track_timing(timer_stats_collector)
    def update(
        self,
        config: DatasetConfig,
        update_type: VersionUpdateType,
        log_message: str,
        submitter: str | None = None,
    ) -> DatasetConfig:
        lock_file_path = self.get_registry_lock_file(config.model.dataset_id)
        with self.cloud_interface.make_lock_file_managed(lock_file_path):
            # Note that projects will not pick up these changes until submit-dataset
            # is called again.
            with RegistrationContext(self.db, log_message, update_type, submitter) as context:
                return self.update_with_context(
                    config,
                    context,
                )

    def update_with_context(
        self,
        config: DatasetConfig,
        context: RegistrationContext,
    ) -> DatasetConfig:
        conn = context.connection
        dataset_id = config.model.dataset_id
        cur_config = self.get_by_id(dataset_id, conn=conn)
        updated_model = self._update_config(config, context)
        updated_config = DatasetConfig(updated_model)
        updated_config.dataset_path = config.dataset_path
        self._update_dimensions(conn, updated_config)

        # Note: this method mutates updated_config.
        self._write_to_registry(updated_config, orig_version=cur_config.model.version)

        assoc_df = self._store.read_missing_associations_table(
            updated_config.model.dataset_id, updated_config.model.version
        )
        try:
            self._run_checks(conn, updated_config, assoc_df)
        except Exception:
            self._store.remove_tables(
                updated_config.model.dataset_id, updated_config.model.version
            )
            raise

        old_key = ConfigKey(dataset_id, cur_config.model.version)
        new_key = ConfigKey(dataset_id, updated_config.model.version)
        self._datasets.pop(old_key, None)
        self._datasets[new_key] = updated_config

        if not self.offline_mode:
            self.sync_push(self.get_registry_data_directory(dataset_id))

        return updated_config

    def remove(self, config_id: str, conn: Connection | None = None):
        for key in [x for x in self._datasets if x.id == config_id]:
            self.remove_data(config_id, key.version)
            self._datasets.pop(key)

        self.db.delete_all(conn, config_id)
        logger.info("Removed %s from the registry.", config_id)

    def remove_data(self, dataset_id: str, version: str):
        self._store.remove_tables(dataset_id, version)
        logger.info("Removed data for %s version=%s from the registry.", dataset_id, version)

    def show(
        self,
        conn: Connection | None = None,
        filters: list[str] | None = None,
        max_width: Union[int, dict] | None = None,
        drop_fields: list[str] | None = None,
        return_table: bool = False,
        **kwargs: Any,
    ):
        """Show registry in PrettyTable

        Parameters
        ----------
        filters : list or tuple
            List of filter expressions for registry content (e.g., filters=["Submitter==USER", "Description contains comstock"])
        max_width
            Max column width in PrettyTable, specify as a single value or as a dict of values by field name
        drop_fields
            List of field names not to show

        """

        if filters:
            logger.info("List registry for: %s", filters)

        table = PrettyTable(title=self.name())
        all_field_names = (
            "ID",
            "Version",
            "Date",
            "Submitter",
            "Description",
        )
        if drop_fields is None:
            table.field_names = all_field_names
        else:
            table.field_names = tuple(x for x in all_field_names if x not in drop_fields)

        if max_width is None:
            table._max_width = {
                "ID": 50,
                "Date": 10,
                "Description": 50,
            }
        if isinstance(max_width, int):
            table.max_width = max_width
        elif isinstance(max_width, dict):
            table._max_width = max_width

        if filters:
            transformed_filters = transform_and_validate_filters(filters)
        field_to_index = {x: i for i, x in enumerate(table.field_names)}
        rows = []
        for model in self.db.iter_models(conn, all_versions=True):
            assert isinstance(model, DatasetConfigModel)
            registration = self.db.get_registration(conn, model)
            all_fields = (
                model.dataset_id,
                model.version,
                registration.timestamp.strftime("%Y-%m-%d %H:%M:%S"),
                registration.submitter,
                registration.log_message,
            )
            if drop_fields is None:
                row = all_fields
            else:
                row = tuple(
                    y for (x, y) in zip(all_field_names, all_fields) if x not in drop_fields
                )

            if not filters or matches_filters(row, field_to_index, transformed_filters):
                rows.append(row)

        rows.sort(key=lambda x: x[0])
        table.add_rows(rows)
        table.align = "l"
        if return_table:
            return table
        display_table(table)

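The filter expressions accepted by `show()` follow the format given in its docstring; a short sketch:

# Print the full dataset table, then a filtered view (filter values are illustrative).
manager.show()
manager.show(filters=["Submitter==USER", "Description contains comstock"])
table = manager.show(return_table=True)  # return the PrettyTable instead of printing it
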
    def sync_pull(self, path):
        """Synchronizes files from the remote registry to local.
        Deletes any files locally that do not exist on remote.

        path : Path
            Local path

        """
        remote_path = self.relative_remote_path(path)
        self.cloud_interface.sync_pull(
            remote_path, path, exclude=SYNC_EXCLUDE_LIST, delete_local=True
        )

    def sync_push(self, path):
        """Synchronizes files from the local path to the remote registry.

        path : Path
            Local path

        """
        remote_path = self.relative_remote_path(path)
        lock_file_path = self.get_registry_lock_file(path.name)
        self.cloud_interface.check_lock_file(lock_file_path)
        try:
            self.cloud_interface.sync_push(
                remote_path=remote_path, local_path=path, exclude=SYNC_EXCLUDE_LIST
            )
        except Exception:
            logger.exception(
                "Please report this error to the dsgrid team. The registry may need recovery."
            )
            raise