dsgrid_toolkit-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dsgrid-toolkit might be problematic.

Files changed (152)
  1. dsgrid/__init__.py +22 -0
  2. dsgrid/api/__init__.py +0 -0
  3. dsgrid/api/api_manager.py +179 -0
  4. dsgrid/api/app.py +420 -0
  5. dsgrid/api/models.py +60 -0
  6. dsgrid/api/response_models.py +116 -0
  7. dsgrid/apps/__init__.py +0 -0
  8. dsgrid/apps/project_viewer/app.py +216 -0
  9. dsgrid/apps/registration_gui.py +444 -0
  10. dsgrid/chronify.py +22 -0
  11. dsgrid/cli/__init__.py +0 -0
  12. dsgrid/cli/common.py +120 -0
  13. dsgrid/cli/config.py +177 -0
  14. dsgrid/cli/download.py +13 -0
  15. dsgrid/cli/dsgrid.py +142 -0
  16. dsgrid/cli/dsgrid_admin.py +349 -0
  17. dsgrid/cli/install_notebooks.py +62 -0
  18. dsgrid/cli/query.py +711 -0
  19. dsgrid/cli/registry.py +1773 -0
  20. dsgrid/cloud/__init__.py +0 -0
  21. dsgrid/cloud/cloud_storage_interface.py +140 -0
  22. dsgrid/cloud/factory.py +31 -0
  23. dsgrid/cloud/fake_storage_interface.py +37 -0
  24. dsgrid/cloud/s3_storage_interface.py +156 -0
  25. dsgrid/common.py +35 -0
  26. dsgrid/config/__init__.py +0 -0
  27. dsgrid/config/annual_time_dimension_config.py +187 -0
  28. dsgrid/config/common.py +131 -0
  29. dsgrid/config/config_base.py +148 -0
  30. dsgrid/config/dataset_config.py +684 -0
  31. dsgrid/config/dataset_schema_handler_factory.py +41 -0
  32. dsgrid/config/date_time_dimension_config.py +108 -0
  33. dsgrid/config/dimension_config.py +54 -0
  34. dsgrid/config/dimension_config_factory.py +65 -0
  35. dsgrid/config/dimension_mapping_base.py +349 -0
  36. dsgrid/config/dimension_mappings_config.py +48 -0
  37. dsgrid/config/dimensions.py +775 -0
  38. dsgrid/config/dimensions_config.py +71 -0
  39. dsgrid/config/index_time_dimension_config.py +76 -0
  40. dsgrid/config/input_dataset_requirements.py +31 -0
  41. dsgrid/config/mapping_tables.py +209 -0
  42. dsgrid/config/noop_time_dimension_config.py +42 -0
  43. dsgrid/config/project_config.py +1457 -0
  44. dsgrid/config/registration_models.py +199 -0
  45. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  46. dsgrid/config/simple_models.py +49 -0
  47. dsgrid/config/supplemental_dimension.py +29 -0
  48. dsgrid/config/time_dimension_base_config.py +200 -0
  49. dsgrid/data_models.py +155 -0
  50. dsgrid/dataset/__init__.py +0 -0
  51. dsgrid/dataset/dataset.py +123 -0
  52. dsgrid/dataset/dataset_expression_handler.py +86 -0
  53. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  54. dsgrid/dataset/dataset_schema_handler_base.py +899 -0
  55. dsgrid/dataset/dataset_schema_handler_one_table.py +196 -0
  56. dsgrid/dataset/dataset_schema_handler_standard.py +303 -0
  57. dsgrid/dataset/growth_rates.py +162 -0
  58. dsgrid/dataset/models.py +44 -0
  59. dsgrid/dataset/table_format_handler_base.py +257 -0
  60. dsgrid/dataset/table_format_handler_factory.py +17 -0
  61. dsgrid/dataset/unpivoted_table.py +121 -0
  62. dsgrid/dimension/__init__.py +0 -0
  63. dsgrid/dimension/base_models.py +218 -0
  64. dsgrid/dimension/dimension_filters.py +308 -0
  65. dsgrid/dimension/standard.py +213 -0
  66. dsgrid/dimension/time.py +531 -0
  67. dsgrid/dimension/time_utils.py +88 -0
  68. dsgrid/dsgrid_rc.py +88 -0
  69. dsgrid/exceptions.py +105 -0
  70. dsgrid/filesystem/__init__.py +0 -0
  71. dsgrid/filesystem/cloud_filesystem.py +32 -0
  72. dsgrid/filesystem/factory.py +32 -0
  73. dsgrid/filesystem/filesystem_interface.py +136 -0
  74. dsgrid/filesystem/local_filesystem.py +74 -0
  75. dsgrid/filesystem/s3_filesystem.py +118 -0
  76. dsgrid/loggers.py +132 -0
  77. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +950 -0
  78. dsgrid/notebooks/registration.ipynb +48 -0
  79. dsgrid/notebooks/start_notebook.sh +11 -0
  80. dsgrid/project.py +451 -0
  81. dsgrid/query/__init__.py +0 -0
  82. dsgrid/query/dataset_mapping_plan.py +142 -0
  83. dsgrid/query/derived_dataset.py +384 -0
  84. dsgrid/query/models.py +726 -0
  85. dsgrid/query/query_context.py +287 -0
  86. dsgrid/query/query_submitter.py +847 -0
  87. dsgrid/query/report_factory.py +19 -0
  88. dsgrid/query/report_peak_load.py +70 -0
  89. dsgrid/query/reports_base.py +20 -0
  90. dsgrid/registry/__init__.py +0 -0
  91. dsgrid/registry/bulk_register.py +161 -0
  92. dsgrid/registry/common.py +287 -0
  93. dsgrid/registry/config_update_checker_base.py +63 -0
  94. dsgrid/registry/data_store_factory.py +34 -0
  95. dsgrid/registry/data_store_interface.py +69 -0
  96. dsgrid/registry/dataset_config_generator.py +156 -0
  97. dsgrid/registry/dataset_registry_manager.py +734 -0
  98. dsgrid/registry/dataset_update_checker.py +16 -0
  99. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  100. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  101. dsgrid/registry/dimension_registry_manager.py +413 -0
  102. dsgrid/registry/dimension_update_checker.py +16 -0
  103. dsgrid/registry/duckdb_data_store.py +185 -0
  104. dsgrid/registry/filesystem_data_store.py +141 -0
  105. dsgrid/registry/filter_registry_manager.py +123 -0
  106. dsgrid/registry/project_config_generator.py +57 -0
  107. dsgrid/registry/project_registry_manager.py +1616 -0
  108. dsgrid/registry/project_update_checker.py +48 -0
  109. dsgrid/registry/registration_context.py +223 -0
  110. dsgrid/registry/registry_auto_updater.py +316 -0
  111. dsgrid/registry/registry_database.py +662 -0
  112. dsgrid/registry/registry_interface.py +446 -0
  113. dsgrid/registry/registry_manager.py +544 -0
  114. dsgrid/registry/registry_manager_base.py +367 -0
  115. dsgrid/registry/versioning.py +92 -0
  116. dsgrid/spark/__init__.py +0 -0
  117. dsgrid/spark/functions.py +545 -0
  118. dsgrid/spark/types.py +50 -0
  119. dsgrid/tests/__init__.py +0 -0
  120. dsgrid/tests/common.py +139 -0
  121. dsgrid/tests/make_us_data_registry.py +204 -0
  122. dsgrid/tests/register_derived_datasets.py +103 -0
  123. dsgrid/tests/utils.py +25 -0
  124. dsgrid/time/__init__.py +0 -0
  125. dsgrid/time/time_conversions.py +80 -0
  126. dsgrid/time/types.py +67 -0
  127. dsgrid/units/__init__.py +0 -0
  128. dsgrid/units/constants.py +113 -0
  129. dsgrid/units/convert.py +71 -0
  130. dsgrid/units/energy.py +145 -0
  131. dsgrid/units/power.py +87 -0
  132. dsgrid/utils/__init__.py +0 -0
  133. dsgrid/utils/dataset.py +612 -0
  134. dsgrid/utils/files.py +179 -0
  135. dsgrid/utils/filters.py +125 -0
  136. dsgrid/utils/id_remappings.py +100 -0
  137. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  138. dsgrid/utils/py_expression_eval/README.md +8 -0
  139. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  140. dsgrid/utils/py_expression_eval/tests.py +283 -0
  141. dsgrid/utils/run_command.py +70 -0
  142. dsgrid/utils/scratch_dir_context.py +64 -0
  143. dsgrid/utils/spark.py +918 -0
  144. dsgrid/utils/spark_partition.py +98 -0
  145. dsgrid/utils/timing.py +239 -0
  146. dsgrid/utils/utilities.py +184 -0
  147. dsgrid/utils/versioning.py +36 -0
  148. dsgrid_toolkit-0.2.0.dist-info/METADATA +216 -0
  149. dsgrid_toolkit-0.2.0.dist-info/RECORD +152 -0
  150. dsgrid_toolkit-0.2.0.dist-info/WHEEL +4 -0
  151. dsgrid_toolkit-0.2.0.dist-info/entry_points.txt +4 -0
  152. dsgrid_toolkit-0.2.0.dist-info/licenses/LICENSE +29 -0
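
The per-file counts above list everything added by this release. As a quick local cross-check (illustrative only; it assumes you have downloaded the wheel under the name shown in the header), the same file list can be printed with Python's standard zipfile module, since a wheel is a zip archive:

import zipfile

# Print every file packaged in the wheel (should match the 152 entries above).
with zipfile.ZipFile("dsgrid_toolkit-0.2.0-py3-none-any.whl") as whl:
    for name in whl.namelist():
        print(name)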
dsgrid/dataset/dataset_schema_handler_base.py
@@ -0,0 +1,899 @@
import abc
import logging
import os
import tempfile
from pathlib import Path
from typing import Iterable, Self

import chronify
from sqlalchemy import Connection

import dsgrid
from dsgrid.chronify import create_store
from dsgrid.config.annual_time_dimension_config import (
    AnnualTimeDimensionConfig,
    map_annual_time_to_date_time,
)
from dsgrid.config.dimension_config import (
    DimensionBaseConfig,
    DimensionBaseConfigWithFiles,
)
from dsgrid.config.noop_time_dimension_config import NoOpTimeDimensionConfig
from dsgrid.config.date_time_dimension_config import DateTimeDimensionConfig
from dsgrid.config.index_time_dimension_config import IndexTimeDimensionConfig
from dsgrid.config.project_config import ProjectConfig
from dsgrid.config.time_dimension_base_config import TimeDimensionBaseConfig
from dsgrid.dimension.time import TimeBasedDataAdjustmentModel
from dsgrid.dsgrid_rc import DsgridRuntimeConfig
from dsgrid.common import VALUE_COLUMN, BackendEngine
from dsgrid.config.dataset_config import (
    DatasetConfig,
    InputDatasetType,
)
from dsgrid.config.dimension_mapping_base import (
    DimensionMappingReferenceModel,
)
from dsgrid.config.simple_models import DimensionSimpleModel
from dsgrid.dataset.models import TableFormatType
from dsgrid.dataset.table_format_handler_factory import make_table_format_handler
from dsgrid.dimension.base_models import DimensionType
from dsgrid.exceptions import DSGInvalidDataset, DSGInvalidDimensionMapping
from dsgrid.dimension.time import (
    DaylightSavingAdjustmentModel,
)
from dsgrid.dataset.dataset_mapping_manager import DatasetMappingManager
from dsgrid.query.dataset_mapping_plan import DatasetMappingPlan, MapOperation
from dsgrid.query.query_context import QueryContext
from dsgrid.query.models import ColumnType
from dsgrid.spark.functions import (
    cache,
    cross_join,
    except_all,
    is_dataframe_empty,
    join,
    make_temp_view_name,
    unpersist,
)
from dsgrid.registry.data_store_interface import DataStoreInterface
from dsgrid.spark.types import DataFrame, F
from dsgrid.units.convert import convert_units_unpivoted
from dsgrid.utils.dataset import (
    check_historical_annual_time_model_year_consistency,
    filter_out_expected_missing_associations,
    handle_dimension_association_errors,
    is_noop_mapping,
    map_stacked_dimension,
    add_time_zone,
    map_time_dimension_with_chronify_duckdb,
    map_time_dimension_with_chronify_spark_hive,
    map_time_dimension_with_chronify_spark_path,
    ordered_subset_columns,
    repartition_if_needed_by_mapping,
)

from dsgrid.utils.scratch_dir_context import ScratchDirContext
from dsgrid.utils.spark import (
    check_for_nulls,
    create_dataframe_from_product,
    persist_intermediate_table,
    read_dataframe,
    save_to_warehouse,
    union,
    write_dataframe,
)
from dsgrid.utils.timing import timer_stats_collector, track_timing
from dsgrid.registry.dimension_registry_manager import DimensionRegistryManager
from dsgrid.registry.dimension_mapping_registry_manager import (
    DimensionMappingRegistryManager,
)

logger = logging.getLogger(__name__)

class DatasetSchemaHandlerBase(abc.ABC):
    """define interface/required behaviors per dataset schema"""

    def __init__(
        self,
        config: DatasetConfig,
        conn: Connection | None,
        dimension_mgr: DimensionRegistryManager,
        dimension_mapping_mgr: DimensionMappingRegistryManager,
        mapping_references: list[DimensionMappingReferenceModel] | None = None,
    ):
        self._conn = conn
        self._config = config
        self._dimension_mgr = dimension_mgr
        self._dimension_mapping_mgr = dimension_mapping_mgr
        self._mapping_references: list[DimensionMappingReferenceModel] = mapping_references or []

    @classmethod
    @abc.abstractmethod
    def load(cls, config: DatasetConfig, *args, store: DataStoreInterface | None = None) -> Self:
        """Create a dataset schema handler by loading the data tables from files.

        Parameters
        ----------
        config: DatasetConfig
        store: DataStoreInterface | None
            If provided, the dataset must already be registered.
            If not provided, the dataset must not be registered and the file path must be
            available via the DatasetConfig.

        Returns
        -------
        DatasetSchemaHandlerBase
        """

    @abc.abstractmethod
    def check_consistency(self, missing_dimension_associations: DataFrame | None) -> None:
        """
        Check all data consistencies, including data columns, dataset to dimension records, and time
        """

    @abc.abstractmethod
    def check_time_consistency(self):
        """Check the time consistency of the dataset."""

    @abc.abstractmethod
    def _get_load_data_table(self) -> DataFrame:
        """Return the full load data table."""

    def _make_actual_dimension_association_table_from_data(self) -> DataFrame:
        return self._remove_non_dimension_columns(self._get_load_data_table()).distinct()

    def _make_expected_dimension_association_table_from_records(
        self, dimension_types: Iterable[DimensionType], context: ScratchDirContext
    ) -> DataFrame:
        """Return a dataframe containing one row for each unique dimension combination except time.
        Use dimensions in the dataset's dimension records.
        """
        data: dict[str, list[str]] = {}
        for dim_type in dimension_types:
            dim = self._config.get_dimension_with_records(dim_type)
            if dim is not None:
                data[dim_type.value] = list(dim.get_unique_ids())

        if not data:
            msg = "Bug: did not find any dimension records"
            raise Exception(msg)
        return create_dataframe_from_product(data, context)

    @track_timing(timer_stats_collector)
    def _check_dimension_associations(
        self, missing_dimension_associations: DataFrame | None
    ) -> None:
        """Check that a cross-join of dimension records is present, unless explicitly excepted."""
        context = ScratchDirContext(Path(tempfile.gettempdir()))
        assoc_by_records = self._make_expected_dimension_association_table_from_records(
            [x for x in DimensionType if x != DimensionType.TIME], context
        )
        assoc_by_data = self._make_actual_dimension_association_table_from_data()
        if missing_dimension_associations is None:
            required_assoc = assoc_by_records
        else:
            required_assoc = filter_out_expected_missing_associations(
                assoc_by_records, missing_dimension_associations
            )
        cols = sorted(required_assoc.columns)
        diff = except_all(required_assoc.select(*cols), assoc_by_data.select(*cols))
        cache(diff)
        try:
            if not is_dataframe_empty(diff):
                handle_dimension_association_errors(diff, assoc_by_data, self.dataset_id)
        finally:
            unpersist(diff)

    def make_mapped_dimension_association_table(
        self, store: DataStoreInterface, context: ScratchDirContext
    ) -> DataFrame:
        """Return a dataframe containing one row for each unique dimension combination except time.
        Use mapped dimensions.
        """
        df = self._make_actual_dimension_association_table_from_data()
        missing_associations = store.read_missing_associations_table(
            self._config.model.dataset_id, self._config.model.version
        )
        if missing_associations is not None:
            missing_associations = self._union_not_covered_dimensions(
                missing_associations, context
            )
            assert sorted(df.columns) == sorted(missing_associations.columns)
            df = union([df, missing_associations.select(*df.columns)])
        mapping_plan = self.build_default_dataset_mapping_plan()
        with DatasetMappingManager(self.dataset_id, mapping_plan, context) as mapping_manager:
            df = self._remap_dimension_columns(df, mapping_manager).drop("fraction").distinct()
        check_for_nulls(df)
        return df

    def _union_not_covered_dimensions(
        self, df: DataFrame, context: ScratchDirContext
    ) -> DataFrame:
        columns = set(df.columns)
        not_covered_dims: list[DimensionType] = []
        for dim in DimensionType:
            if dim != DimensionType.TIME and dim.value not in columns:
                not_covered_dims.append(dim)

        if not not_covered_dims:
            return df

        expected_table = self._make_expected_dimension_association_table_from_records(
            not_covered_dims, context
        )
        return cross_join(
            df,
            expected_table,
        )

    @abc.abstractmethod
    def filter_data(self, dimensions: list[DimensionSimpleModel], store: DataStoreInterface):
        """Filter the load data by dimensions and rewrite the files.

        dimensions : list[DimensionSimpleModel]
        store : DataStoreInterface
            The data store to use for reading and writing the data.
        """

    @property
    def connection(self) -> Connection | None:
        """Return the active sqlalchemy connection to the registry database."""
        return self._conn

    @property
    def dataset_id(self):
        return self._config.config_id

    @property
    def config(self):
        """Returns the DatasetConfig.

        Returns
        -------
        DatasetConfig
        """
        return self._config

    @abc.abstractmethod
    def make_project_dataframe(self, context, project_config) -> DataFrame:
        """Return a load_data dataframe with dimensions mapped to the project's, with filters
        as specified by the QueryContext.

        Parameters
        ----------
        context : QueryContext
        project_config : ProjectConfig

        Returns
        -------
        pyspark.sql.DataFrame
        """

    @abc.abstractmethod
    def make_mapped_dataframe(
        self,
        context: QueryContext,
        time_dimension: TimeDimensionBaseConfig | None = None,
    ) -> DataFrame:
        """Return a load_data dataframe with dimensions mapped as stored in the handler.

        Parameters
        ----------
        context
        time_dimension
            Required if the time dimension is being mapped.
            This should be the destination time dimension.
        """

    @track_timing(timer_stats_collector)
    def _check_dataset_time_consistency(self, load_data_df: DataFrame):
        """Check dataset time consistency such that:
        1. time range(s) match time config record;
        2. all dimension combinations return the same set of time range(s).

        Callers must ensure that the dataset has a time dimension.
        """
        if os.environ.get("__DSGRID_SKIP_CHECK_DATASET_TIME_CONSISTENCY__"):
            logger.warning("Skip dataset time consistency checks.")
            return

        logger.info("Check dataset time consistency.")
        time_dim = self._config.get_time_dimension()
        assert time_dim is not None, "time cannot be checked if the dataset has no time dimension"
        time_cols = self._get_time_dimension_columns()
        time_dim.check_dataset_time_consistency(load_data_df, time_cols)
        if not isinstance(time_dim, NoOpTimeDimensionConfig):
            self._check_dataset_time_consistency_by_time_array(time_cols, load_data_df)
        self._check_model_year_time_consistency(load_data_df)

    @track_timing(timer_stats_collector)
    def _check_dataset_time_consistency_with_chronify(self):
        """Check dataset time consistency such that:
        1. time range(s) match time config record;
        2. all dimension combinations return the same set of time range(s).

        Callers must ensure that the dataset has a time dimension.
        """
        if os.environ.get("__DSGRID_SKIP_CHECK_DATASET_TIME_CONSISTENCY__"):
            logger.warning("Skip dataset time consistency checks.")
            return

        logger.info("Check dataset time consistency.")
        path = Path(self._config.load_data_path)
        assert path.exists()
        load_data_df = read_dataframe(path)
        schema = self._get_chronify_schema(load_data_df)
        scratch_dir = DsgridRuntimeConfig.load().get_scratch_dir()
        with ScratchDirContext(scratch_dir) as context:
            if path.suffix == ".parquet":
                src_path = path
            else:
                src_path = context.get_temp_filename(suffix=".parquet")
                write_dataframe(load_data_df, src_path)

            store_file = context.get_temp_filename(suffix=".db")
            with create_store(store_file) as store:
                # This performs all of the checks.
                store.create_view_from_parquet(src_path, schema)
                store.drop_view(schema.name)

        self._check_model_year_time_consistency(load_data_df)

    def _get_chronify_schema(self, df: DataFrame):
        time_dim = self._config.get_dimension(DimensionType.TIME)
        time_cols = time_dim.get_load_data_time_columns()
        time_array_id_columns = [
            x
            for x in df.columns
            # If there are multiple weather years:
            # - that are continuous, weather year needs to be excluded (one overall range).
            # - that are not continuous, weather year needs to be included and chronify
            #   needs additional support. TODO: issue #340
            if x != DimensionType.WEATHER_YEAR.value
            and x
            in set(df.columns).difference(time_cols).difference(self._config.get_value_columns())
        ]
        if self._config.get_table_format_type() == TableFormatType.PIVOTED:
            # We can ignore all pivoted columns but one for time checking.
            # Looking at the rest would be redundant.
            value_column = next(iter(self._config.get_pivoted_dimension_columns()))
        else:
            value_column = VALUE_COLUMN
        return chronify.TableSchema(
            name=make_temp_view_name(),
            time_config=time_dim.to_chronify(),
            time_array_id_columns=time_array_id_columns,
            value_column=value_column,
        )

    def _check_model_year_time_consistency(self, df: DataFrame):
        time_dim = self._config.get_dimension(DimensionType.TIME)
        if self._config.model.dataset_type == InputDatasetType.HISTORICAL and isinstance(
            time_dim, AnnualTimeDimensionConfig
        ):
            annual_cols = time_dim.get_load_data_time_columns()
            assert len(annual_cols) == 1
            annual_col = annual_cols[0]
            check_historical_annual_time_model_year_consistency(
                df, annual_col, DimensionType.MODEL_YEAR.value
            )

    @track_timing(timer_stats_collector)
    def _check_dataset_time_consistency_by_time_array(self, time_cols, load_data_df):
        """Check that each unique time array has the same timestamps."""
        logger.info("Check dataset time consistency by time array.")
        unique_array_cols = set(DimensionType.get_allowed_dimension_column_names()).intersection(
            load_data_df.columns
        )
        counts = load_data_df.groupBy(*time_cols).count().select("count")
        distinct_counts = counts.select("count").distinct().collect()
        if len(distinct_counts) != 1:
            msg = (
                "All time arrays must be repeated the same number of times: "
                f"unique timestamp repeats = {len(distinct_counts)}"
            )
            raise DSGInvalidDataset(msg)
        ta_counts = load_data_df.groupBy(*unique_array_cols).count().select("count")
        distinct_ta_counts = ta_counts.select("count").distinct().collect()
        if len(distinct_ta_counts) != 1:
            msg = (
                "All combinations of non-time dimensions must have the same time array length: "
                f"unique time array lengths = {len(distinct_ta_counts)}"
            )
            raise DSGInvalidDataset(msg)

    def _check_load_data_unpivoted_value_column(self, df):
        logger.info("Check load data unpivoted columns.")
        if VALUE_COLUMN not in df.columns:
            msg = f"value_column={VALUE_COLUMN} is not in columns={df.columns}"
            raise DSGInvalidDataset(msg)

    def _convert_units(
        self,
        df: DataFrame,
        project_metric_records: DataFrame,
        mapping_manager: DatasetMappingManager,
    ):
        if not self._config.model.enable_unit_conversion:
            return df

        op = mapping_manager.plan.convert_units_op
        if mapping_manager.has_completed_operation(op):
            return df

        # Note that a dataset could have the same dimension record IDs as the project,
        # no mappings, but then still have different units.
        mapping_records = None
        for ref in self._mapping_references:
            dim_type = ref.from_dimension_type
            if dim_type == DimensionType.METRIC:
                mapping_records = self._dimension_mapping_mgr.get_by_id(
                    ref.mapping_id, version=ref.version, conn=self.connection
                ).get_records_dataframe()
                break

        dataset_dim = self._config.get_dimension_with_records(DimensionType.METRIC)
        dataset_records = dataset_dim.get_records_dataframe()
        df = convert_units_unpivoted(
            df,
            DimensionType.METRIC.value,
            dataset_records,
            mapping_records,
            project_metric_records,
        )
        if op.persist:
            df = mapping_manager.persist_intermediate_table(df, op)
        return df

    def _finalize_table(self, context: QueryContext, df: DataFrame, project_config: ProjectConfig):
        # TODO: remove ProjectConfig so that dataset queries can use this.
        # Issue #370
        table_handler = make_table_format_handler(
            self._config.get_table_format_type(),
            project_config,
            dataset_id=self.dataset_id,
        )

        time_dim = project_config.get_base_dimension(DimensionType.TIME)
        context.set_dataset_metadata(
            self.dataset_id,
            context.model.result.column_type,
            project_config.get_load_data_time_columns(time_dim.model.name),
        )

        if context.model.result.column_type == ColumnType.DIMENSION_NAMES:
            df = table_handler.convert_columns_to_query_names(
                df, self._config.model.dataset_id, context
            )

        return df

    @staticmethod
    def _get_pivoted_column_name(
        context: QueryContext, pivoted_dimension_type: DimensionType, project_config
    ):
        match context.model.result.column_type:
            case ColumnType.DIMENSION_NAMES:
                pivoted_column_name = project_config.get_base_dimension(
                    pivoted_dimension_type
                ).model.name
            case ColumnType.DIMENSION_TYPES:
                pivoted_column_name = pivoted_dimension_type.value
            case _:
                msg = str(context.model.result.column_type)
                raise NotImplementedError(msg)

        return pivoted_column_name

    def _get_dataset_to_project_mapping_records(self, dimension_type: DimensionType):
        config = self._get_dataset_to_project_mapping_config(dimension_type)
        if config is None:
            return config
        return config.get_records_dataframe()

    def _get_dataset_to_project_mapping_config(self, dimension_type: DimensionType):
        ref = self._get_dataset_to_project_mapping_reference(dimension_type)
        if ref is None:
            return ref
        return self._dimension_mapping_mgr.get_by_id(
            ref.mapping_id, version=ref.version, conn=self.connection
        )

    def _get_dataset_to_project_mapping_reference(self, dimension_type: DimensionType):
        for ref in self._mapping_references:
            if ref.from_dimension_type == dimension_type:
                return ref
        return

    def _get_mapping_to_dimension(
        self, dimension_type: DimensionType
    ) -> DimensionBaseConfig | None:
        ref = self._get_dataset_to_project_mapping_reference(dimension_type)
        if ref is None:
            return None
        config = self._dimension_mapping_mgr.get_by_id(ref.mapping_id, conn=self._conn)
        return self._dimension_mgr.get_by_id(
            config.model.to_dimension.dimension_id, conn=self._conn
        )

    def _get_project_metric_records(self, project_config: ProjectConfig) -> DataFrame:
        metric_dim_query_name = getattr(
            project_config.get_dataset_base_dimension_names(self._config.model.dataset_id),
            DimensionType.METRIC.value,
        )
        if metric_dim_query_name is None:
            # This is a workaround for dsgrid projects created before the field
            # base_dimension_names was added to InputDatasetModel.
            metric_dims = project_config.list_base_dimensions(dimension_type=DimensionType.METRIC)
            if len(metric_dims) > 1:
                msg = (
                    "The dataset's base_dimension_names value is not set and "
                    "there are multiple metric dimensions in the project. Please re-register the "
                    f"dataset with dataset_id={self._config.model.dataset_id}."
                )
                raise DSGInvalidDataset(msg)
            metric_dim_query_name = metric_dims[0].model.name
        return project_config.get_dimension_records(metric_dim_query_name)

    def _get_time_dimension_columns(self):
        time_dim = self._config.get_dimension(DimensionType.TIME)
        time_cols = time_dim.get_load_data_time_columns()
        return time_cols

    def _iter_dataset_record_ids(self, context: QueryContext):
        for dim_type, project_record_ids in context.get_record_ids().items():
            dataset_mapping = self._get_dataset_to_project_mapping_records(dim_type)
            if dataset_mapping is None:
                dataset_record_ids = project_record_ids
            else:
                dataset_record_ids = (
                    join(
                        dataset_mapping.withColumnRenamed("from_id", "dataset_record_id"),
                        project_record_ids,
                        "to_id",
                        "id",
                    )
                    .select("dataset_record_id")
                    .withColumnRenamed("dataset_record_id", "id")
                    .distinct()
                )
            yield dim_type, dataset_record_ids

    @staticmethod
    def _list_dimension_columns(df: DataFrame) -> list[str]:
        columns = DimensionType.get_allowed_dimension_column_names()
        return [x for x in df.columns if x in columns]

    def _list_dimension_types_in_load_data(self, df: DataFrame) -> list[DimensionType]:
        dims = [DimensionType(x) for x in DatasetSchemaHandlerBase._list_dimension_columns(df)]
        if self._config.get_table_format_type() == TableFormatType.PIVOTED:
            pivoted_type = self._config.get_pivoted_dimension_type()
            assert pivoted_type is not None
            dims.append(pivoted_type)
        return dims

    def _prefilter_pivoted_dimensions(self, context: QueryContext, df):
        for dim_type, dataset_record_ids in self._iter_dataset_record_ids(context):
            if dim_type == self._config.get_pivoted_dimension_type():
                # Drop columns that don't match requested project record IDs.
                cols_to_keep = {x.id for x in dataset_record_ids.collect()}
                cols_to_drop = set(self._config.get_pivoted_dimension_columns()).difference(
                    cols_to_keep
                )
                if cols_to_drop:
                    df = df.drop(*cols_to_drop)

        return df

    def _prefilter_stacked_dimensions(self, context: QueryContext, df):
        for dim_type, dataset_record_ids in self._iter_dataset_record_ids(context):
            # Drop rows that don't match requested project record IDs.
            tmp = dataset_record_ids.withColumnRenamed("id", "dataset_record_id")
            if dim_type.value not in df.columns:
                # This dimension is stored in another table (e.g., lookup or load_data).
                continue
            df = join(df, tmp, dim_type.value, "dataset_record_id").drop("dataset_record_id")

        return df

    def _prefilter_time_dimension(self, context: QueryContext, df):
        # TODO #196:
        return df

    def build_default_dataset_mapping_plan(self) -> DatasetMappingPlan:
        """Build a default mapping order of dimensions to a project."""
        mappings: list[MapOperation] = []
        for ref in self._mapping_references:
            config = self._dimension_mapping_mgr.get_by_id(ref.mapping_id, conn=self.connection)
            dim = self._dimension_mgr.get_by_id(
                config.model.to_dimension.dimension_id, conn=self.connection
            )
            mappings.append(
                MapOperation(
                    name=dim.model.name,
                    mapping_reference=ref,
                )
            )

        return DatasetMappingPlan(dataset_id=self._config.model.dataset_id, mappings=mappings)

    def check_dataset_mapping_plan(
        self, mapping_plan: DatasetMappingPlan, project_config: ProjectConfig
    ) -> None:
        """Check that a user-defined mapping plan is valid."""
        req_dimensions: dict[DimensionType, DimensionMappingReferenceModel] = {}
        actual_mapping_dims: dict[DimensionType, str] = {}

        for ref in self._mapping_references:
            assert ref.to_dimension_type not in req_dimensions
            req_dimensions[ref.to_dimension_type] = ref

        dataset_id = mapping_plan.dataset_id
        indexes_to_remove: list[int] = []
        for i, mapping in enumerate(mapping_plan.mappings):
            to_dim = project_config.get_dimension(mapping.name)
            if to_dim.model.dimension_type == DimensionType.TIME:
                msg = (
                    f"DatasetMappingPlan for {dataset_id=} is invalid because specification "
                    f"of the time dimension is not supported: {mapping.name}"
                )
                raise DSGInvalidDimensionMapping(msg)
            if to_dim.model.dimension_type in actual_mapping_dims:
                msg = (
                    f"DatasetMappingPlan for {dataset_id=} is invalid because it can only "
                    f"support mapping one dimension for a given dimension type. "
                    f"type={to_dim.model.dimension_type} "
                    f"first={actual_mapping_dims[to_dim.model.dimension_type]} "
                    f"second={mapping.name}"
                )
                raise DSGInvalidDimensionMapping(msg)

            from_dim = self._config.get_dimension(to_dim.model.dimension_type)
            supp_dim_names = {
                x.model.name
                for x in project_config.list_supplemental_dimensions(to_dim.model.dimension_type)
            }
            if mapping.name in supp_dim_names:
                # This could be useful if we wanted to use DatasetMappingPlan for mapping
                # a single dataset to a project's dimensions without being concerned about
                # aggregations. As it stands, we are only using this within our
                # project query process. We need much more handling to make that work.
                msg = (
                    f"DatasetMappingPlan for {dataset_id=} is invalid because it specifies "
                    f"a supplemental dimension: {mapping.name}"
                )
                raise DSGInvalidDimensionMapping(msg)
            elif to_dim.model.dimension_type not in req_dimensions:
                msg = (
                    f"DatasetMappingPlan for {dataset_id=} is invalid because there is no "
                    f"dataset-to-project-base mapping defined for {to_dim.model.label}"
                )
                raise DSGInvalidDimensionMapping(msg)

            ref = req_dimensions[to_dim.model.dimension_type]
            mapping_config = self._dimension_mapping_mgr.get_by_id(
                ref.mapping_id, version=ref.version, conn=self.connection
            )
            if (
                from_dim.model.dimension_id == mapping_config.model.from_dimension.dimension_id
                and to_dim.model.dimension_id == mapping_config.model.to_dimension.dimension_id
            ):
                mapping.mapping_reference = ref
                actual_mapping_dims[to_dim.model.dimension_type] = mapping.name

        for index in indexes_to_remove:
            mapping_plan.mappings.pop(index)

        if diff_dims := set(req_dimensions.keys()).difference(actual_mapping_dims.keys()):
            req = sorted(x.value for x in req_dimensions)
            act = sorted(x.value for x in actual_mapping_dims)
            diff = sorted(x.value for x in diff_dims)
            msg = (
                "If a mapping order is specified for a dataset, it must include all "
                "dimension types that require mappings to the project base dimension.\n"
                f"Required dimension types: {req}\nActual dimension types: {act}\n"
                f"Difference: {diff}"
            )
            raise DSGInvalidDimensionMapping(msg)

    def _remap_dimension_columns(
        self,
        df: DataFrame,
        mapping_manager: DatasetMappingManager,
        filtered_records: dict[DimensionType, DataFrame] | None = None,
    ) -> DataFrame:
        """Map the table's dimensions according to the plan.

        Parameters
        ----------
        df
            The dataframe to map.
        mapping_manager
            Manages checkpointing and order of the mapping operations.
        filtered_records
            If not None, use these records to filter the table.
        """
        completed_operations = mapping_manager.get_completed_mapping_operations()
        for dim_mapping in mapping_manager.plan.mappings:
            if dim_mapping.name in completed_operations:
                logger.info(
                    "Skip mapping operation %s because the result exists in a checkpointed file.",
                    dim_mapping.name,
                )
                continue
            assert dim_mapping.mapping_reference is not None
            ref = dim_mapping.mapping_reference
            dim_type = ref.from_dimension_type
            column = dim_type.value
            mapping_config = self._dimension_mapping_mgr.get_by_id(
                ref.mapping_id, version=ref.version, conn=self.connection
            )
            logger.info(
                "Mapping dimension type %s mapping_type=%s",
                dim_type,
                mapping_config.model.mapping_type,
            )
            records = mapping_config.get_records_dataframe()
            if filtered_records is not None and dim_type in filtered_records:
                records = join(records, filtered_records[dim_type], "to_id", "id").drop("id")

            if is_noop_mapping(records):
                logger.info("Skip no-op mapping %s.", ref.mapping_id)
                continue
            if column in df.columns:
                persisted_file: Path | None = None
                df = map_stacked_dimension(df, records, column)
                df, persisted_file = repartition_if_needed_by_mapping(
                    df,
                    mapping_config.model.mapping_type,
                    mapping_manager.scratch_dir_context,
                    repartition=dim_mapping.handle_data_skew,
                )
                if dim_mapping.persist and persisted_file is None:
                    mapping_manager.persist_intermediate_table(df, dim_mapping)
                if persisted_file is not None:
                    mapping_manager.save_checkpoint(persisted_file, dim_mapping)

        return df

    def _apply_fraction(
        self,
        df,
        value_columns,
        mapping_manager: DatasetMappingManager,
        agg_func=None,
    ):
        op = mapping_manager.plan.apply_fraction_op
        if "fraction" not in df.columns:
            return df
        if mapping_manager.has_completed_operation(op):
            return df
        agg_func = agg_func or F.sum
        # Maintain column order.
        agg_ops = [
            agg_func(F.col(x) * F.col("fraction")).alias(x)
            for x in [y for y in df.columns if y in value_columns]
        ]
        gcols = set(df.columns) - value_columns - {"fraction"}
        df = df.groupBy(*ordered_subset_columns(df, gcols)).agg(*agg_ops)
        df = df.drop("fraction")
        if op.persist:
            df = mapping_manager.persist_intermediate_table(df, op)
        return df

    @track_timing(timer_stats_collector)
    def _convert_time_dimension(
        self,
        load_data_df: DataFrame,
        to_time_dim: TimeDimensionBaseConfig,
        value_column: str,
        mapping_manager: DatasetMappingManager,
        wrap_time_allowed: bool,
        time_based_data_adjustment: TimeBasedDataAdjustmentModel,
        to_geo_dim: DimensionBaseConfigWithFiles | None = None,
    ):
        op = mapping_manager.plan.map_time_op
        if mapping_manager.has_completed_operation(op):
            return load_data_df
        self._validate_daylight_saving_adjustment(time_based_data_adjustment)
        time_dim = self._config.get_time_dimension()
        assert time_dim is not None
        if time_dim.model.is_time_zone_required_in_geography():
            if self._config.model.use_project_geography_time_zone:
                if to_geo_dim is None:
                    msg = "Bug: to_geo_dim must be provided if time zone is required in geography."
                    raise Exception(msg)
                logger.info("Add time zone from project geography dimension.")
                geography_dim = to_geo_dim
            else:
                logger.info("Add time zone from dataset geography dimension.")
                geography_dim = self._config.get_dimension(DimensionType.GEOGRAPHY)
            load_data_df = add_time_zone(load_data_df, geography_dim)

        if isinstance(time_dim, AnnualTimeDimensionConfig):
            if not isinstance(to_time_dim, DateTimeDimensionConfig):
                msg = f"Annual time can only be mapped to DateTime: {to_time_dim.model.time_type}"
                raise NotImplementedError(msg)

            return map_annual_time_to_date_time(
                load_data_df,
                time_dim,
                to_time_dim,
                {value_column},
            )

        config = dsgrid.runtime_config
        if not time_dim.supports_chronify():
            # annual time is returned above
            # no mapping for no-op
            assert isinstance(
                time_dim, NoOpTimeDimensionConfig
            ), "Only NoOp and AnnualTimeDimensionConfig do not currently support Chronify"
            return load_data_df
        match (config.backend_engine, config.use_hive_metastore):
            case (BackendEngine.SPARK, True):
                table_name = make_temp_view_name()
                load_data_df = map_time_dimension_with_chronify_spark_hive(
                    df=save_to_warehouse(load_data_df, table_name),
                    table_name=table_name,
                    value_column=value_column,
                    from_time_dim=time_dim,
                    to_time_dim=to_time_dim,
                    scratch_dir_context=mapping_manager.scratch_dir_context,
                    time_based_data_adjustment=time_based_data_adjustment,
                    wrap_time_allowed=wrap_time_allowed,
                )

            case (BackendEngine.SPARK, False):
                filename = persist_intermediate_table(
                    load_data_df,
                    mapping_manager.scratch_dir_context,
                    tag="query before time mapping",
                )
                load_data_df = map_time_dimension_with_chronify_spark_path(
                    df=read_dataframe(filename),
                    filename=filename,
                    value_column=value_column,
                    from_time_dim=time_dim,
                    to_time_dim=to_time_dim,
                    scratch_dir_context=mapping_manager.scratch_dir_context,
                    time_based_data_adjustment=time_based_data_adjustment,
                    wrap_time_allowed=wrap_time_allowed,
                )
            case (BackendEngine.DUCKDB, _):
                load_data_df = map_time_dimension_with_chronify_duckdb(
                    df=load_data_df,
                    value_column=value_column,
                    from_time_dim=time_dim,
                    to_time_dim=to_time_dim,
                    scratch_dir_context=mapping_manager.scratch_dir_context,
                    time_based_data_adjustment=time_based_data_adjustment,
                    wrap_time_allowed=wrap_time_allowed,
                )

        if time_dim.model.is_time_zone_required_in_geography():
            load_data_df = load_data_df.drop("time_zone")

        if op.persist:
            load_data_df = mapping_manager.persist_intermediate_table(load_data_df, op)
        return load_data_df

    def _validate_daylight_saving_adjustment(self, time_based_data_adjustment):
        if (
            time_based_data_adjustment.daylight_saving_adjustment
            == DaylightSavingAdjustmentModel()
        ):
            return
        time_dim = self._config.get_time_dimension()
        if not isinstance(time_dim, IndexTimeDimensionConfig):
            assert time_dim is not None
            msg = (
                "time_based_data_adjustment.daylight_saving_adjustment does not apply to "
                f"{time_dim.model.time_type=}; it applies to the INDEX time type only."
            )
            logger.warning(msg)

    def _remove_non_dimension_columns(self, df: DataFrame) -> DataFrame:
        allowed_columns = self._list_dimension_columns(df)
        return df.select(*allowed_columns)
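
To make the abstract surface of DatasetSchemaHandlerBase concrete, here is a minimal, hypothetical subclass sketch. It is not part of the wheel; the real implementations are dataset_schema_handler_one_table.py and dataset_schema_handler_standard.py in the file list above, and the bodies below only illustrate which methods a handler must supply and how the base-class helpers can be reused.

from typing import Self

from dsgrid.config.dataset_config import DatasetConfig
from dsgrid.config.simple_models import DimensionSimpleModel
from dsgrid.dataset.dataset_schema_handler_base import DatasetSchemaHandlerBase
from dsgrid.query.query_context import QueryContext
from dsgrid.registry.data_store_interface import DataStoreInterface
from dsgrid.spark.types import DataFrame
from dsgrid.utils.spark import read_dataframe


class ExampleOneTableHandler(DatasetSchemaHandlerBase):
    """Toy handler for a dataset stored as a single load_data table (sketch only)."""

    @classmethod
    def load(cls, config: DatasetConfig, *args, store: DataStoreInterface | None = None) -> Self:
        # Assumes the positional args mirror __init__: conn, dimension_mgr, dimension_mapping_mgr.
        return cls(config, *args)

    def _get_load_data_table(self) -> DataFrame:
        # Assumes config.load_data_path points at a readable table, as in the chronify check above.
        return read_dataframe(self._config.load_data_path)

    def check_consistency(self, missing_dimension_associations: DataFrame | None) -> None:
        # Reuse the base-class association check, then the time check.
        self._check_dimension_associations(missing_dimension_associations)
        self.check_time_consistency()

    def check_time_consistency(self) -> None:
        self._check_dataset_time_consistency(self._get_load_data_table())

    def filter_data(self, dimensions: list[DimensionSimpleModel], store: DataStoreInterface):
        raise NotImplementedError("sketch only")

    def make_project_dataframe(self, context, project_config) -> DataFrame:
        raise NotImplementedError("sketch only")

    def make_mapped_dataframe(self, context: QueryContext, time_dimension=None) -> DataFrame:
        raise NotImplementedError("sketch only")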