dsgrid-toolkit 0.3.3__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. build_backend.py +93 -0
  2. dsgrid/__init__.py +22 -0
  3. dsgrid/api/__init__.py +0 -0
  4. dsgrid/api/api_manager.py +179 -0
  5. dsgrid/api/app.py +419 -0
  6. dsgrid/api/models.py +60 -0
  7. dsgrid/api/response_models.py +116 -0
  8. dsgrid/apps/__init__.py +0 -0
  9. dsgrid/apps/project_viewer/app.py +216 -0
  10. dsgrid/apps/registration_gui.py +444 -0
  11. dsgrid/chronify.py +32 -0
  12. dsgrid/cli/__init__.py +0 -0
  13. dsgrid/cli/common.py +120 -0
  14. dsgrid/cli/config.py +176 -0
  15. dsgrid/cli/download.py +13 -0
  16. dsgrid/cli/dsgrid.py +157 -0
  17. dsgrid/cli/dsgrid_admin.py +92 -0
  18. dsgrid/cli/install_notebooks.py +62 -0
  19. dsgrid/cli/query.py +729 -0
  20. dsgrid/cli/registry.py +1862 -0
  21. dsgrid/cloud/__init__.py +0 -0
  22. dsgrid/cloud/cloud_storage_interface.py +140 -0
  23. dsgrid/cloud/factory.py +31 -0
  24. dsgrid/cloud/fake_storage_interface.py +37 -0
  25. dsgrid/cloud/s3_storage_interface.py +156 -0
  26. dsgrid/common.py +36 -0
  27. dsgrid/config/__init__.py +0 -0
  28. dsgrid/config/annual_time_dimension_config.py +194 -0
  29. dsgrid/config/common.py +142 -0
  30. dsgrid/config/config_base.py +148 -0
  31. dsgrid/config/dataset_config.py +907 -0
  32. dsgrid/config/dataset_schema_handler_factory.py +46 -0
  33. dsgrid/config/date_time_dimension_config.py +136 -0
  34. dsgrid/config/dimension_config.py +54 -0
  35. dsgrid/config/dimension_config_factory.py +65 -0
  36. dsgrid/config/dimension_mapping_base.py +350 -0
  37. dsgrid/config/dimension_mappings_config.py +48 -0
  38. dsgrid/config/dimensions.py +1025 -0
  39. dsgrid/config/dimensions_config.py +71 -0
  40. dsgrid/config/file_schema.py +190 -0
  41. dsgrid/config/index_time_dimension_config.py +80 -0
  42. dsgrid/config/input_dataset_requirements.py +31 -0
  43. dsgrid/config/mapping_tables.py +209 -0
  44. dsgrid/config/noop_time_dimension_config.py +42 -0
  45. dsgrid/config/project_config.py +1462 -0
  46. dsgrid/config/registration_models.py +188 -0
  47. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  48. dsgrid/config/simple_models.py +49 -0
  49. dsgrid/config/supplemental_dimension.py +29 -0
  50. dsgrid/config/time_dimension_base_config.py +192 -0
  51. dsgrid/data_models.py +155 -0
  52. dsgrid/dataset/__init__.py +0 -0
  53. dsgrid/dataset/dataset.py +123 -0
  54. dsgrid/dataset/dataset_expression_handler.py +86 -0
  55. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  56. dsgrid/dataset/dataset_schema_handler_base.py +945 -0
  57. dsgrid/dataset/dataset_schema_handler_one_table.py +209 -0
  58. dsgrid/dataset/dataset_schema_handler_two_table.py +322 -0
  59. dsgrid/dataset/growth_rates.py +162 -0
  60. dsgrid/dataset/models.py +51 -0
  61. dsgrid/dataset/table_format_handler_base.py +257 -0
  62. dsgrid/dataset/table_format_handler_factory.py +17 -0
  63. dsgrid/dataset/unpivoted_table.py +121 -0
  64. dsgrid/dimension/__init__.py +0 -0
  65. dsgrid/dimension/base_models.py +230 -0
  66. dsgrid/dimension/dimension_filters.py +308 -0
  67. dsgrid/dimension/standard.py +252 -0
  68. dsgrid/dimension/time.py +352 -0
  69. dsgrid/dimension/time_utils.py +103 -0
  70. dsgrid/dsgrid_rc.py +88 -0
  71. dsgrid/exceptions.py +105 -0
  72. dsgrid/filesystem/__init__.py +0 -0
  73. dsgrid/filesystem/cloud_filesystem.py +32 -0
  74. dsgrid/filesystem/factory.py +32 -0
  75. dsgrid/filesystem/filesystem_interface.py +136 -0
  76. dsgrid/filesystem/local_filesystem.py +74 -0
  77. dsgrid/filesystem/s3_filesystem.py +118 -0
  78. dsgrid/loggers.py +132 -0
  79. dsgrid/minimal_patterns.cp313-win_amd64.pyd +0 -0
  80. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +949 -0
  81. dsgrid/notebooks/registration.ipynb +48 -0
  82. dsgrid/notebooks/start_notebook.sh +11 -0
  83. dsgrid/project.py +451 -0
  84. dsgrid/query/__init__.py +0 -0
  85. dsgrid/query/dataset_mapping_plan.py +142 -0
  86. dsgrid/query/derived_dataset.py +388 -0
  87. dsgrid/query/models.py +728 -0
  88. dsgrid/query/query_context.py +287 -0
  89. dsgrid/query/query_submitter.py +994 -0
  90. dsgrid/query/report_factory.py +19 -0
  91. dsgrid/query/report_peak_load.py +70 -0
  92. dsgrid/query/reports_base.py +20 -0
  93. dsgrid/registry/__init__.py +0 -0
  94. dsgrid/registry/bulk_register.py +165 -0
  95. dsgrid/registry/common.py +287 -0
  96. dsgrid/registry/config_update_checker_base.py +63 -0
  97. dsgrid/registry/data_store_factory.py +34 -0
  98. dsgrid/registry/data_store_interface.py +74 -0
  99. dsgrid/registry/dataset_config_generator.py +158 -0
  100. dsgrid/registry/dataset_registry_manager.py +950 -0
  101. dsgrid/registry/dataset_update_checker.py +16 -0
  102. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  103. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  104. dsgrid/registry/dimension_registry_manager.py +413 -0
  105. dsgrid/registry/dimension_update_checker.py +16 -0
  106. dsgrid/registry/duckdb_data_store.py +207 -0
  107. dsgrid/registry/filesystem_data_store.py +150 -0
  108. dsgrid/registry/filter_registry_manager.py +123 -0
  109. dsgrid/registry/project_config_generator.py +57 -0
  110. dsgrid/registry/project_registry_manager.py +1623 -0
  111. dsgrid/registry/project_update_checker.py +48 -0
  112. dsgrid/registry/registration_context.py +223 -0
  113. dsgrid/registry/registry_auto_updater.py +316 -0
  114. dsgrid/registry/registry_database.py +667 -0
  115. dsgrid/registry/registry_interface.py +446 -0
  116. dsgrid/registry/registry_manager.py +558 -0
  117. dsgrid/registry/registry_manager_base.py +367 -0
  118. dsgrid/registry/versioning.py +92 -0
  119. dsgrid/rust_ext/__init__.py +14 -0
  120. dsgrid/rust_ext/find_minimal_patterns.py +129 -0
  121. dsgrid/spark/__init__.py +0 -0
  122. dsgrid/spark/functions.py +589 -0
  123. dsgrid/spark/types.py +110 -0
  124. dsgrid/tests/__init__.py +0 -0
  125. dsgrid/tests/common.py +140 -0
  126. dsgrid/tests/make_us_data_registry.py +265 -0
  127. dsgrid/tests/register_derived_datasets.py +103 -0
  128. dsgrid/tests/utils.py +25 -0
  129. dsgrid/time/__init__.py +0 -0
  130. dsgrid/time/time_conversions.py +80 -0
  131. dsgrid/time/types.py +67 -0
  132. dsgrid/units/__init__.py +0 -0
  133. dsgrid/units/constants.py +113 -0
  134. dsgrid/units/convert.py +71 -0
  135. dsgrid/units/energy.py +145 -0
  136. dsgrid/units/power.py +87 -0
  137. dsgrid/utils/__init__.py +0 -0
  138. dsgrid/utils/dataset.py +830 -0
  139. dsgrid/utils/files.py +179 -0
  140. dsgrid/utils/filters.py +125 -0
  141. dsgrid/utils/id_remappings.py +100 -0
  142. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  143. dsgrid/utils/py_expression_eval/README.md +8 -0
  144. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  145. dsgrid/utils/py_expression_eval/tests.py +283 -0
  146. dsgrid/utils/run_command.py +70 -0
  147. dsgrid/utils/scratch_dir_context.py +65 -0
  148. dsgrid/utils/spark.py +918 -0
  149. dsgrid/utils/spark_partition.py +98 -0
  150. dsgrid/utils/timing.py +239 -0
  151. dsgrid/utils/utilities.py +221 -0
  152. dsgrid/utils/versioning.py +36 -0
  153. dsgrid_toolkit-0.3.3.dist-info/METADATA +193 -0
  154. dsgrid_toolkit-0.3.3.dist-info/RECORD +157 -0
  155. dsgrid_toolkit-0.3.3.dist-info/WHEEL +4 -0
  156. dsgrid_toolkit-0.3.3.dist-info/entry_points.txt +4 -0
  157. dsgrid_toolkit-0.3.3.dist-info/licenses/LICENSE +29 -0
dsgrid/dataset/models.py
@@ -0,0 +1,51 @@
+ from enum import StrEnum
+ from typing import Annotated, Literal, Union
+
+ from pydantic import Field
+
+ from dsgrid.data_models import DSGBaseModel
+ from dsgrid.dimension.base_models import DimensionType
+
+
+ class ValueFormat(StrEnum):
+     """Defines the format of value columns in a dataset."""
+
+     PIVOTED = "pivoted"
+     STACKED = "stacked"
+
+
+ class TableFormat(StrEnum):
+     """Defines the table structure of a dataset."""
+
+     ONE_TABLE = "one_table"
+     TWO_TABLE = "two_table"
+
+
+ # Keep old name as alias for backward compatibility during migration
+ TableFormatType = ValueFormat
+
+
+ class PivotedTableFormatModel(DSGBaseModel):
+     """Defines a pivoted table format where one dimension's records are columns."""
+
+     format_type: Literal[ValueFormat.PIVOTED] = ValueFormat.PIVOTED
+     pivoted_dimension_type: DimensionType = Field(
+         title="pivoted_dimension_type",
+         description="The dimension type whose records are columns that contain data values.",
+     )
+
+
+ class StackedTableFormatModel(DSGBaseModel):
+     """Defines a stacked (unpivoted) table format with a single value column."""
+
+     format_type: Literal[ValueFormat.STACKED] = ValueFormat.STACKED
+
+
+ # Alias for backward compatibility
+ UnpivotedTableFormatModel = StackedTableFormatModel
+
+
+ TableFormatModel = Annotated[
+     Union[PivotedTableFormatModel, StackedTableFormatModel],
+     Field(discriminator="format_type"),
+ ]
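Note: TableFormatModel is a pydantic discriminated union keyed on format_type. A minimal usage sketch, not part of the package, assuming DSGBaseModel is a pydantic v2 BaseModel (TypeAdapter is the standard pydantic v2 API for validating against an Annotated union):

from pydantic import TypeAdapter
from dsgrid.dataset.models import PivotedTableFormatModel, TableFormatModel

# The discriminator routes the payload to the matching model class.
adapter = TypeAdapter(TableFormatModel)
model = adapter.validate_python(
    {"format_type": "pivoted", "pivoted_dimension_type": "geography"}
)
assert isinstance(model, PivotedTableFormatModel)
assert model.pivoted_dimension_type.value == "geography"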
dsgrid/dataset/table_format_handler_base.py
@@ -0,0 +1,257 @@
+ import abc
+ import logging
+ from typing import Iterable
+
+ from dsgrid.config.project_config import ProjectConfig
+ from dsgrid.dimension.base_models import DimensionCategory, DimensionType
+ from dsgrid.query.query_context import QueryContext
+ from dsgrid.query.models import (
+     AggregationModel,
+     ColumnModel,
+     ColumnType,
+     DatasetDimensionsMetadataModel,
+     DimensionMetadataModel,
+ )
+ from dsgrid.spark.types import DataFrame
+ from dsgrid.utils.dataset import map_stacked_dimension, remove_invalid_null_timestamps
+ from dsgrid.utils.spark import persist_intermediate_query
+ from dsgrid.utils.timing import track_timing, timer_stats_collector
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class TableFormatHandlerBase(abc.ABC):
+     """Base class for table format handlers"""
+
+     def __init__(self, project_config: ProjectConfig, dataset_id: str | None = None):
+         self._project_config = project_config
+         self._dataset_id = dataset_id
+
+     def add_columns(
+         self,
+         df: DataFrame,
+         column_models: list[ColumnModel],
+         context: QueryContext,
+         value_columns: Iterable[str],
+     ) -> DataFrame:
+         """Add columns to the dataframe. For example, suppose the geography dimension is at
+         county resolution and the user wants to add a column for state.
+
+         Parameters
+         ----------
+         df : pyspark.sql.DataFrame
+         column_models : list[ColumnModel]
+         context : QueryContext
+         value_columns : Iterable[str]
+             Columns in the dataframe that contain load values.
+         """
+         columns = set(df.columns)
+         all_base_names = self.project_config.list_dimension_names(category=DimensionCategory.BASE)
+         for column in column_models:
+             name = column.dimension_name
+             if name in all_base_names or name in columns:
+                 continue
+             supp_dim = self._project_config.get_dimension_with_records(name)
+             existing_metadata = context.get_dimension_metadata(
+                 supp_dim.model.dimension_type, dataset_id=self._dataset_id
+             )
+             existing_base_metadata = [
+                 x for x in existing_metadata if x.dimension_name in all_base_names
+             ]
+             if len(existing_base_metadata) != 1:
+                 msg = (
+                     f"Bug: expected one base metadata object for {supp_dim.model.dimension_type}: "
+                     f"{existing_base_metadata}"
+                 )
+                 raise Exception(msg)
+             base_dim_name = existing_base_metadata[0].dimension_name
+             if base_dim_name not in all_base_names:
+                 msg = f"Bug: Expected {base_dim_name} to be a base dimension."
+                 raise Exception(msg)
+             base_dim = self._project_config.get_dimension_with_records(base_dim_name)
+             records = self._project_config.get_base_to_supplemental_mapping_records(
+                 base_dim, supp_dim
+             )
+
+             if column.function is not None:
+                 # TODO #200: Do we want to allow this?
+                 msg = f"Applying a SQL function to added column={name} is not supported yet"
+                 raise NotImplementedError(msg)
+             expected_base_dim_cols = context.get_dimension_column_names_by_name(
+                 supp_dim.model.dimension_type,
+                 base_dim.model.name,
+                 dataset_id=self._dataset_id,
+             )
+             if len(expected_base_dim_cols) > 1:
+                 msg = "Bug: Non-time dimensions cannot have more than one base dimension column"
+                 raise Exception(msg)
+             expected_base_dim_col = expected_base_dim_cols[0]
+             df = map_stacked_dimension(
+                 df,
+                 records,
+                 expected_base_dim_col,
+                 drop_column=False,
+                 to_column=name,
+             )
+             if context.model.result.column_type == ColumnType.DIMENSION_NAMES:
+                 assert supp_dim.model.dimension_type != DimensionType.TIME
+                 column_names = [name]
+             else:
+                 column_names = [expected_base_dim_col]
+             context.add_dimension_metadata(
+                 supp_dim.model.dimension_type,
+                 DimensionMetadataModel(dimension_name=name, column_names=column_names),
+                 dataset_id=self.dataset_id,
+             )
+
+         if "fraction" in df.columns:
+             for col in value_columns:
+                 df = df.withColumn(col, df[col] * df["fraction"])
+             df = df.drop("fraction")
+
+         return df
+
+     @abc.abstractmethod
+     def process_aggregations(
+         self, df: DataFrame, aggregations: list[AggregationModel], context: QueryContext
+     ) -> DataFrame:
+         """Aggregate the dimensional data as specified by aggregations.
+
+         Parameters
+         ----------
+         df : pyspark.sql.DataFrame
+         aggregations : list[AggregationModel]
+         context : QueryContext
+
+         Returns
+         -------
+         pyspark.sql.DataFrame
+
+         """
+
+     @property
+     def project_config(self) -> ProjectConfig:
+         """Return the project config of the dataset being processed."""
+         return self._project_config
+
+     @property
+     def dataset_id(self) -> str | None:
+         """Return the ID of the dataset being processed."""
+         return self._dataset_id
+
+     def convert_columns_to_query_names(
+         self, df: DataFrame, dataset_id: str, context: QueryContext
+     ) -> DataFrame:
+         """Convert columns from dimension types to dimension query names."""
+         columns = set(df.columns)
+         for dim_type in DimensionType:
+             if dim_type == DimensionType.TIME:
+                 time_dim = self._project_config.get_base_time_dimension()
+                 df = time_dim.map_timestamp_load_data_columns_for_query_name(df)
+             elif dim_type.value in columns:
+                 existing_col = dim_type.value
+                 new_cols = context.get_dimension_column_names(dim_type, dataset_id=dataset_id)
+                 assert len(new_cols) == 1, f"{dim_type=} {new_cols=}"
+                 new_col = next(iter(new_cols))
+                 if existing_col != new_col:
+                     df = df.withColumnRenamed(existing_col, new_col)
+                     logger.debug("Converted column from %s to %s", existing_col, new_col)
+
+         return df
+
+     def replace_ids_with_names(self, df: DataFrame) -> DataFrame:
+         """Replace dimension record IDs with names."""
+         assert not {"id", "name"}.intersection(df.columns), df.columns
+         orig = df
+         all_query_names = self._project_config.get_dimension_names_mapped_to_type()
+         for name in set(df.columns).intersection(all_query_names.keys()):
+             if all_query_names[name] != DimensionType.TIME:
+                 # Time doesn't have records.
+                 dim_config = self._project_config.get_dimension_with_records(name)
+                 records = dim_config.get_records_dataframe().select("id", "name")
+                 df = (
+                     df.join(records, on=df[name] == records["id"])
+                     .drop("id", name)
+                     .withColumnRenamed("name", name)
+                 )
+         assert df.count() == orig.count(), f"counts changed {df.count()} {orig.count()}"
+         return df
+
+     @staticmethod
+     def _add_column_to_dim_type(
+         column: ColumnModel, dim_type: DimensionType, column_to_dim_type: dict[str, DimensionType]
+     ) -> None:
+         name = column.get_column_name()
+         if name in column_to_dim_type:
+             assert dim_type == column_to_dim_type[name], f"{name=} {column_to_dim_type}"
+         column_to_dim_type[name] = dim_type
+
+     def _build_group_by_columns(
+         self,
+         columns: list[ColumnModel],
+         context: QueryContext,
+         final_metadata: DatasetDimensionsMetadataModel,
+     ):
+         group_by_cols: list[str] = []
+         for column in columns:
+             dim = self._project_config.get_dimension(column.dimension_name)
+             dim_type = dim.model.dimension_type
+             match context.model.result.column_type:
+                 case ColumnType.DIMENSION_TYPES:
+                     column_names = context.get_dimension_column_names_by_name(
+                         dim_type, column.dimension_name, dataset_id=self._dataset_id
+                     )
+                     if dim_type == DimensionType.TIME:
+                         group_by_cols += column_names
+                     else:
+                         group_by_cols.append(dim_type.value)
+                 case ColumnType.DIMENSION_NAMES:
+                     column_names = [column.get_column_name()]
+                     expr = self._make_group_by_column_expr(column)
+                     group_by_cols.append(expr)
+                     if not isinstance(expr, str) or expr != column.dimension_name:
+                         # In this case we are replacing any existing query name with an expression
+                         # or alias, and so the old name must be removed.
+                         final_metadata.remove_metadata(dim_type, column.dimension_name)
+                 case _:
+                     msg = f"Bug: unhandled: {context.model.result.column_type}"
+                     raise NotImplementedError(msg)
+             final_metadata.add_metadata(
+                 dim_type,
+                 DimensionMetadataModel(
+                     dimension_name=column.dimension_name, column_names=column_names
+                 ),
+             )
+         return group_by_cols
+
+     @staticmethod
+     def _make_group_by_column_expr(column):
+         if column.function is None:
+             expr = column.dimension_name
+         else:
+             expr = column.function(column.dimension_name)
+             if column.alias is not None:
+                 expr = expr.alias(column.alias)
+         return expr
+
+     @track_timing(timer_stats_collector)
+     def _remove_invalid_null_timestamps(self, df: DataFrame, orig_id, context: QueryContext):
+         if id(df) != orig_id:
+             # The table could have NULL timestamps that designate expected-missing data.
+             # Those rows could be obsolete after aggregating stacked dimensions.
+             # This is an expensive operation, so only do it if the dataframe changed.
+             value_columns = context.get_value_columns()
+             if not value_columns:
+                 msg = "Bug: value_columns cannot be empty"
+                 raise Exception(msg)
+             time_columns = context.get_dimension_column_names(
+                 DimensionType.TIME, dataset_id=self._dataset_id
+             )
+             if time_columns:
+                 # Persist the query up to this point to avoid multiple evaluations.
+                 df = persist_intermediate_query(df, context.scratch_dir_context)
+                 stacked_columns = set(df.columns) - value_columns.union(time_columns)
+                 df = remove_invalid_null_timestamps(df, time_columns, stacked_columns)
+                 logger.debug("Removed any rows with invalid null timestamps")
+         return df
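Note: the fraction-weighting step at the end of add_columns assumes the base-to-supplemental mapping join added a "fraction" column. A toy standalone pyspark sketch of that step (the column names here are invented, not dsgrid's):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
# Two county rows allocated into one state with fractional weights.
df = spark.createDataFrame(
    [("county_1", "state_a", 10.0, 0.25), ("county_2", "state_a", 10.0, 0.75)],
    ["geography", "state", "value", "fraction"],
)
value_columns = ["value"]
if "fraction" in df.columns:
    for col in value_columns:
        df = df.withColumn(col, df[col] * df["fraction"])
    df = df.drop("fraction")
df.show()  # values become 2.5 and 7.5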
dsgrid/dataset/table_format_handler_factory.py
@@ -0,0 +1,17 @@
+ from dsgrid.dataset.models import ValueFormat
+
+ from .table_format_handler_base import TableFormatHandlerBase
+ from .unpivoted_table import UnpivotedTableHandler
+
+
+ def make_table_format_handler(
+     value_format: ValueFormat, project_config, dataset_id=None
+ ) -> TableFormatHandlerBase:
+     """Return a format handler for the passed value format."""
+     match value_format:
+         case ValueFormat.STACKED:
+             handler = UnpivotedTableHandler(project_config, dataset_id=dataset_id)
+         case _:
+             raise NotImplementedError(str(value_format))
+
+     return handler
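Note: a hypothetical call sequence for this factory; project_config, df, aggregations, context, and the dataset id below are placeholders standing in for real objects, not names from the package:

from dsgrid.dataset.models import ValueFormat
from dsgrid.dataset.table_format_handler_factory import make_table_format_handler

# Only the stacked value format is handled; anything else raises
# NotImplementedError, as shown above.
handler = make_table_format_handler(
    ValueFormat.STACKED, project_config, dataset_id="my_dataset"
)
df = handler.process_aggregations(df, aggregations, context)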
dsgrid/dataset/unpivoted_table.py
@@ -0,0 +1,121 @@
+ import logging
+
+ from dsgrid.common import VALUE_COLUMN
+ from dsgrid.dimension.base_models import DimensionType
+ from dsgrid.query.models import (
+     AggregationModel,
+     ColumnModel,
+     ColumnType,
+     DatasetDimensionsMetadataModel,
+ )
+ from dsgrid.query.query_context import QueryContext
+ from dsgrid.spark.types import DataFrame
+ from dsgrid.units.convert import convert_units_unpivoted
+ from dsgrid.dataset.table_format_handler_base import TableFormatHandlerBase
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class UnpivotedTableHandler(TableFormatHandlerBase):
+     """Implements behavior for tables stored in unpivoted format."""
+
+     def process_aggregations(
+         self, df: DataFrame, aggregations: list[AggregationModel], context: QueryContext
+     ):
+         orig_id = id(df)
+         df = self.process_stacked_aggregations(df, aggregations, context)
+         df = self._remove_invalid_null_timestamps(df, orig_id, context)
+         return df
+
+     def process_stacked_aggregations(
+         self, df, aggregations: list[AggregationModel], context: QueryContext
+     ):
+         """Aggregate the stacked dimensional data as specified by aggregations.
+
+         Parameters
+         ----------
+         df : pyspark.sql.DataFrame
+         aggregations : list[AggregationModel]
+         context : QueryContext
+
+         Returns
+         -------
+         pyspark.sql.DataFrame
+
+         """
+         if not aggregations:
+             return df
+
+         final_metadata = DatasetDimensionsMetadataModel()
+         dim_type_to_base_query_name = self.project_config.get_dimension_type_to_base_name_mapping()
+         column_to_dim_type: dict[str, DimensionType] = {}
+         dropped_dimensions = set()
+         for agg in aggregations:
+             metric_query_name = None
+             columns: list[ColumnModel] = []
+             for dim_type, column in agg.iter_dimensions_to_keep():
+                 assert dim_type not in dropped_dimensions, dim_type
+                 columns.append(column)
+                 self._add_column_to_dim_type(column, dim_type, column_to_dim_type)
+                 if dim_type == DimensionType.METRIC:
+                     metric_query_name = column.dimension_name
+
+             if metric_query_name is None:
+                 msg = f"Bug: A metric dimension is not included in {agg}"
+                 raise Exception(msg)
+
+             dropped_dimensions.update(set(agg.list_dropped_dimensions()))
+             if not columns:
+                 continue
+
+             df = self.add_columns(df, columns, context, [VALUE_COLUMN])
+             group_by_cols = self._build_group_by_columns(columns, context, final_metadata)
+             op = agg.aggregation_function
+             df = df.groupBy(*group_by_cols).agg(op(VALUE_COLUMN).alias(VALUE_COLUMN))
+
+             if metric_query_name not in dim_type_to_base_query_name[DimensionType.METRIC]:
+                 to_dim = self.project_config.get_dimension_with_records(metric_query_name)
+                 assert context.base_dimension_names.metric is not None
+                 mapping = self.project_config.get_base_to_supplemental_config(
+                     self.project_config.get_dimension_with_records(
+                         context.base_dimension_names.metric
+                     ),
+                     to_dim,
+                 )
+                 from_dim_id = mapping.model.from_dimension.dimension_id
+                 from_records = self.project_config.get_base_dimension_records_by_id(from_dim_id)
+                 mapping_records = mapping.get_records_dataframe()
+                 to_unit_records = to_dim.get_records_dataframe()
+                 df = convert_units_unpivoted(
+                     df,
+                     _get_metric_column_name(context, metric_query_name),
+                     from_records,
+                     mapping_records,
+                     to_unit_records,
+                 )
+
+             logger.debug(
+                 "Aggregated dimensions with groupBy %s and operation %s",
+                 group_by_cols,
+                 op.__name__,
+             )
+
+         for dim_type in DimensionType:
+             metadata = final_metadata.get_metadata(dim_type)
+             if dim_type in dropped_dimensions and metadata:
+                 metadata.clear()
+             context.replace_dimension_metadata(dim_type, metadata, dataset_id=self.dataset_id)
+         return df
+
+
+ def _get_metric_column_name(context: QueryContext, metric_query_name):
+     match context.model.result.column_type:
+         case ColumnType.DIMENSION_TYPES:
+             metric_column = DimensionType.METRIC.value
+         case ColumnType.DIMENSION_NAMES:
+             metric_column = metric_query_name
+         case _:
+             msg = f"Bug: unhandled: {context.model.result.column_type}"
+             raise NotImplementedError(msg)
+     return metric_column
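Note: the core of process_stacked_aggregations is a single groupBy over the kept dimension columns with one aggregation function applied to VALUE_COLUMN. A self-contained pyspark sketch of that step; F.sum stands in for agg.aggregation_function and the column names are invented:

import pyspark.sql.functions as F
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [("co", "com", 1.0), ("co", "res", 2.0), ("wa", "com", 3.0)],
    ["geography", "subsector", "value"],
)
# Dropping subsector: group by the remaining dimension and sum the value column.
result = df.groupBy("geography").agg(F.sum("value").alias("value"))
result.show()  # co -> 3.0, wa -> 3.0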
dsgrid/dimension/base_models.py
@@ -0,0 +1,230 @@
+ """Dimension types for dsgrid"""
+
+ from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
+
+ from pydantic import Field
+
+ from dsgrid.exceptions import DSGInvalidDimension
+ from dsgrid.data_models import DSGBaseModel, DSGEnum
+ from dsgrid.utils.utilities import check_uniqueness
+
+
+ class DimensionType(DSGEnum):
+     """Dimension types"""
+
+     METRIC = "metric"
+     GEOGRAPHY = "geography"
+     SECTOR = "sector"
+     SUBSECTOR = "subsector"
+     TIME = "time"
+     WEATHER_YEAR = "weather_year"
+     MODEL_YEAR = "model_year"
+     SCENARIO = "scenario"
+
+     def __lt__(self, other):
+         return self.value < other.value
+
+     @classmethod
+     def from_column(cls, column: str) -> "DimensionType":
+         try:
+             return cls(column)
+         except ValueError:
+             msg = f"column={column} does not match a known dimension type."
+             raise DSGInvalidDimension(msg)
+
+     @staticmethod
+     def get_dimension_types_allowed_as_columns() -> set["DimensionType"]:
+         """Return the dimension types that may exist in the data table as columns."""
+         return {x for x in DimensionType if x != DimensionType.TIME}
+
+     @staticmethod
+     def get_allowed_dimension_column_names() -> set[str]:
+         """Return the names of the dimension types that may exist in the data table as columns."""
+         return {x.value for x in DimensionType.get_dimension_types_allowed_as_columns()}
+
+
+ class DimensionCategory(DSGEnum):
+     """Types of dimension categories in a project"""
+
+     BASE = "base"
+     SUBSET = "subset"
+     SUPPLEMENTAL = "supplemental"
+
+
+ class DimensionRecordBaseModel(DSGBaseModel):
+     """Base class for all dsgrid dimension models"""
+
+     # TODO: add support/links for docs
+     id: str = Field(
+         title="ID",
+         description="Unique identifier within a dimension",
+     )
+     name: str = Field(
+         title="name",
+         description="User-defined name",
+     )
+
+
+ class MetricDimensionBaseModel(DimensionRecordBaseModel):
+     """Base class for all metric dimensions (e.g. EnergyEndUse)"""
+
+
+ class GeographyDimensionBaseModel(DimensionRecordBaseModel):
+     """Base class for all geography dimensions"""
+
+     time_zone: str | None = Field(
+         default=None,
+         title="Local Prevailing Time Zone",
+         description="""
+         This time zone information is used in reference to the project time_zone
+         to convert between project time and local times as necessary.
+         All Prevailing time zones account for daylight saving time.
+         If a location does not observe daylight saving time, use Standard time zones.
+         """,
+     )
+
+
+ class ModelYearDimensionBaseModel(DimensionRecordBaseModel):
+     """Base class for all model year dimensions"""
+
+
+ class ScenarioDimensionBaseModel(DimensionRecordBaseModel):
+     """Base class for all scenario dimensions"""
+
+
+ class SectorDimensionBaseModel(DimensionRecordBaseModel):
+     """Base class for all sector dimensions"""
+
+
+ class SubsectorDimensionBaseModel(DimensionRecordBaseModel):
+     """Base class for all subsector dimensions"""
+
+
+ class WeatherYearDimensionBaseModel(DimensionRecordBaseModel):
+     """Base class for weather year dimensions"""
+
+
+ _DIMENSION_TO_MODEL = {
+     DimensionType.METRIC: MetricDimensionBaseModel,
+     DimensionType.GEOGRAPHY: GeographyDimensionBaseModel,
+     DimensionType.SECTOR: SectorDimensionBaseModel,
+     DimensionType.SUBSECTOR: SubsectorDimensionBaseModel,
+     DimensionType.WEATHER_YEAR: WeatherYearDimensionBaseModel,
+     DimensionType.MODEL_YEAR: ModelYearDimensionBaseModel,
+     DimensionType.SCENARIO: ScenarioDimensionBaseModel,
+ }
+
+
+ class DatasetDimensionRequirements(DSGBaseModel):
+     """Defines the requirements for checking a dataset prior to registration."""
+
+     check_time_consistency: bool = True
+     check_dimension_associations: bool = True
+     require_all_dimension_types: bool = True
+
+
+ def get_record_base_model(type_enum):
+     """Return the dimension model class for a DimensionType."""
+     dim_model = _DIMENSION_TO_MODEL.get(type_enum)
+     if dim_model is None:
+         msg = f"no mapping for {type_enum}"
+         raise DSGInvalidDimension(msg)
+     return dim_model
+
+
+ def check_required_dimensions(dimensions, tag):
+     """Check that a project or dataset config contains all required dimensions.
+
+     Parameters
+     ----------
+     dimensions : list
+         list of DimensionReferenceModel
+     tag : str
+         User-defined string to include in exception messages
+
+     Raises
+     ------
+     ValueError
+         Raised if a required dimension is not provided.
+
+     """
+     dimension_types = {x.dimension_type for x in dimensions}
+     required_dim_types = set(DimensionType)
+     missing = required_dim_types.difference(dimension_types)
+     if missing:
+         msg = f"Required dimension(s) {missing} are not in {tag}."
+         raise ValueError(msg)
+
+     check_uniqueness((x.dimension_type for x in dimensions), tag)
+
+
+ def check_required_dataset_dimensions(dimensions, requirements: DatasetDimensionRequirements, tag):
+     """Check that a dataset config contains all required dimensions.
+
+     Parameters
+     ----------
+     dimensions : list
+         list of DimensionReferenceModel
+     requirements : DatasetDimensionRequirements
+         Defines the dimension requirements in the dataset that must be checked.
+     tag : str
+         User-defined string to include in exception messages
+
+     Raises
+     ------
+     ValueError
+         Raised if a required dimension is not provided.
+
+     """
+     dimension_types = {x.dimension_type for x in dimensions}
+     if requirements.require_all_dimension_types:
+         required_dim_types = set(DimensionType)
+         missing = required_dim_types.difference(dimension_types)
+         if missing:
+             msg = f"Required dimension(s) {missing} are not in {tag}."
+             raise ValueError(msg)
+
+     check_uniqueness((x.dimension_type for x in dimensions), tag)
+
+
+ def check_timezone_in_geography(dimension, err_msg=None):
+     """Check that a geography dimension contains valid time zones in its records.
+
+     Parameters
+     ----------
+     dimension : DimensionModel
+     err_msg : str | None
+         Optional error message
+
+     Raises
+     ------
+     DSGInvalidDimension
+         Raised if the dimension records contain invalid time zones.
+
+     """
+     if dimension.dimension_type != DimensionType.GEOGRAPHY:
+         msg = (
+             f"Dimension has type {dimension.dimension_type}; "
+             "time_zone can only be checked for geography dimensions."
+         )
+         raise DSGInvalidDimension(msg)
+
+     if not hasattr(dimension.records[0], "time_zone"):
+         if err_msg is None:
+             err_msg = "These geography dimension records must include a time_zone column."
+         raise ValueError(err_msg)
+
+     record_tzs = {rec.time_zone for rec in dimension.records if rec.time_zone is not None}
+     invalid_tzs = []
+     for tz in record_tzs:
+         try:
+             ZoneInfo(tz)
+         except ZoneInfoNotFoundError:
+             invalid_tzs.append(tz)
+
+     if invalid_tzs:
+         msg = (
+             f"Geography dimension {dimension.dimension_id} has invalid time zone(s) in records "
+             f"{dimension.filename}: {invalid_tzs}. Use IANA time zone names only."
+         )
+         raise DSGInvalidDimension(msg)
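Note: the IANA validation performed in check_timezone_in_geography can be reproduced standalone with the stdlib zoneinfo module; the zone names below are chosen only for illustration:

from zoneinfo import ZoneInfo, ZoneInfoNotFoundError

invalid_tzs = []
for tz in ["America/Denver", "EST", "Pacific Time"]:
    try:
        ZoneInfo(tz)  # raises ZoneInfoNotFoundError for non-IANA names
    except ZoneInfoNotFoundError:
        invalid_tzs.append(tz)
print(invalid_tzs)  # ["Pacific Time"]; "EST" is a valid legacy IANA key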