dsgrid-toolkit 0.3.3__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. build_backend.py +93 -0
  2. dsgrid/__init__.py +22 -0
  3. dsgrid/api/__init__.py +0 -0
  4. dsgrid/api/api_manager.py +179 -0
  5. dsgrid/api/app.py +419 -0
  6. dsgrid/api/models.py +60 -0
  7. dsgrid/api/response_models.py +116 -0
  8. dsgrid/apps/__init__.py +0 -0
  9. dsgrid/apps/project_viewer/app.py +216 -0
  10. dsgrid/apps/registration_gui.py +444 -0
  11. dsgrid/chronify.py +32 -0
  12. dsgrid/cli/__init__.py +0 -0
  13. dsgrid/cli/common.py +120 -0
  14. dsgrid/cli/config.py +176 -0
  15. dsgrid/cli/download.py +13 -0
  16. dsgrid/cli/dsgrid.py +157 -0
  17. dsgrid/cli/dsgrid_admin.py +92 -0
  18. dsgrid/cli/install_notebooks.py +62 -0
  19. dsgrid/cli/query.py +729 -0
  20. dsgrid/cli/registry.py +1862 -0
  21. dsgrid/cloud/__init__.py +0 -0
  22. dsgrid/cloud/cloud_storage_interface.py +140 -0
  23. dsgrid/cloud/factory.py +31 -0
  24. dsgrid/cloud/fake_storage_interface.py +37 -0
  25. dsgrid/cloud/s3_storage_interface.py +156 -0
  26. dsgrid/common.py +36 -0
  27. dsgrid/config/__init__.py +0 -0
  28. dsgrid/config/annual_time_dimension_config.py +194 -0
  29. dsgrid/config/common.py +142 -0
  30. dsgrid/config/config_base.py +148 -0
  31. dsgrid/config/dataset_config.py +907 -0
  32. dsgrid/config/dataset_schema_handler_factory.py +46 -0
  33. dsgrid/config/date_time_dimension_config.py +136 -0
  34. dsgrid/config/dimension_config.py +54 -0
  35. dsgrid/config/dimension_config_factory.py +65 -0
  36. dsgrid/config/dimension_mapping_base.py +350 -0
  37. dsgrid/config/dimension_mappings_config.py +48 -0
  38. dsgrid/config/dimensions.py +1025 -0
  39. dsgrid/config/dimensions_config.py +71 -0
  40. dsgrid/config/file_schema.py +190 -0
  41. dsgrid/config/index_time_dimension_config.py +80 -0
  42. dsgrid/config/input_dataset_requirements.py +31 -0
  43. dsgrid/config/mapping_tables.py +209 -0
  44. dsgrid/config/noop_time_dimension_config.py +42 -0
  45. dsgrid/config/project_config.py +1462 -0
  46. dsgrid/config/registration_models.py +188 -0
  47. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  48. dsgrid/config/simple_models.py +49 -0
  49. dsgrid/config/supplemental_dimension.py +29 -0
  50. dsgrid/config/time_dimension_base_config.py +192 -0
  51. dsgrid/data_models.py +155 -0
  52. dsgrid/dataset/__init__.py +0 -0
  53. dsgrid/dataset/dataset.py +123 -0
  54. dsgrid/dataset/dataset_expression_handler.py +86 -0
  55. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  56. dsgrid/dataset/dataset_schema_handler_base.py +945 -0
  57. dsgrid/dataset/dataset_schema_handler_one_table.py +209 -0
  58. dsgrid/dataset/dataset_schema_handler_two_table.py +322 -0
  59. dsgrid/dataset/growth_rates.py +162 -0
  60. dsgrid/dataset/models.py +51 -0
  61. dsgrid/dataset/table_format_handler_base.py +257 -0
  62. dsgrid/dataset/table_format_handler_factory.py +17 -0
  63. dsgrid/dataset/unpivoted_table.py +121 -0
  64. dsgrid/dimension/__init__.py +0 -0
  65. dsgrid/dimension/base_models.py +230 -0
  66. dsgrid/dimension/dimension_filters.py +308 -0
  67. dsgrid/dimension/standard.py +252 -0
  68. dsgrid/dimension/time.py +352 -0
  69. dsgrid/dimension/time_utils.py +103 -0
  70. dsgrid/dsgrid_rc.py +88 -0
  71. dsgrid/exceptions.py +105 -0
  72. dsgrid/filesystem/__init__.py +0 -0
  73. dsgrid/filesystem/cloud_filesystem.py +32 -0
  74. dsgrid/filesystem/factory.py +32 -0
  75. dsgrid/filesystem/filesystem_interface.py +136 -0
  76. dsgrid/filesystem/local_filesystem.py +74 -0
  77. dsgrid/filesystem/s3_filesystem.py +118 -0
  78. dsgrid/loggers.py +132 -0
  79. dsgrid/minimal_patterns.cp313-win_amd64.pyd +0 -0
  80. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +949 -0
  81. dsgrid/notebooks/registration.ipynb +48 -0
  82. dsgrid/notebooks/start_notebook.sh +11 -0
  83. dsgrid/project.py +451 -0
  84. dsgrid/query/__init__.py +0 -0
  85. dsgrid/query/dataset_mapping_plan.py +142 -0
  86. dsgrid/query/derived_dataset.py +388 -0
  87. dsgrid/query/models.py +728 -0
  88. dsgrid/query/query_context.py +287 -0
  89. dsgrid/query/query_submitter.py +994 -0
  90. dsgrid/query/report_factory.py +19 -0
  91. dsgrid/query/report_peak_load.py +70 -0
  92. dsgrid/query/reports_base.py +20 -0
  93. dsgrid/registry/__init__.py +0 -0
  94. dsgrid/registry/bulk_register.py +165 -0
  95. dsgrid/registry/common.py +287 -0
  96. dsgrid/registry/config_update_checker_base.py +63 -0
  97. dsgrid/registry/data_store_factory.py +34 -0
  98. dsgrid/registry/data_store_interface.py +74 -0
  99. dsgrid/registry/dataset_config_generator.py +158 -0
  100. dsgrid/registry/dataset_registry_manager.py +950 -0
  101. dsgrid/registry/dataset_update_checker.py +16 -0
  102. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  103. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  104. dsgrid/registry/dimension_registry_manager.py +413 -0
  105. dsgrid/registry/dimension_update_checker.py +16 -0
  106. dsgrid/registry/duckdb_data_store.py +207 -0
  107. dsgrid/registry/filesystem_data_store.py +150 -0
  108. dsgrid/registry/filter_registry_manager.py +123 -0
  109. dsgrid/registry/project_config_generator.py +57 -0
  110. dsgrid/registry/project_registry_manager.py +1623 -0
  111. dsgrid/registry/project_update_checker.py +48 -0
  112. dsgrid/registry/registration_context.py +223 -0
  113. dsgrid/registry/registry_auto_updater.py +316 -0
  114. dsgrid/registry/registry_database.py +667 -0
  115. dsgrid/registry/registry_interface.py +446 -0
  116. dsgrid/registry/registry_manager.py +558 -0
  117. dsgrid/registry/registry_manager_base.py +367 -0
  118. dsgrid/registry/versioning.py +92 -0
  119. dsgrid/rust_ext/__init__.py +14 -0
  120. dsgrid/rust_ext/find_minimal_patterns.py +129 -0
  121. dsgrid/spark/__init__.py +0 -0
  122. dsgrid/spark/functions.py +589 -0
  123. dsgrid/spark/types.py +110 -0
  124. dsgrid/tests/__init__.py +0 -0
  125. dsgrid/tests/common.py +140 -0
  126. dsgrid/tests/make_us_data_registry.py +265 -0
  127. dsgrid/tests/register_derived_datasets.py +103 -0
  128. dsgrid/tests/utils.py +25 -0
  129. dsgrid/time/__init__.py +0 -0
  130. dsgrid/time/time_conversions.py +80 -0
  131. dsgrid/time/types.py +67 -0
  132. dsgrid/units/__init__.py +0 -0
  133. dsgrid/units/constants.py +113 -0
  134. dsgrid/units/convert.py +71 -0
  135. dsgrid/units/energy.py +145 -0
  136. dsgrid/units/power.py +87 -0
  137. dsgrid/utils/__init__.py +0 -0
  138. dsgrid/utils/dataset.py +830 -0
  139. dsgrid/utils/files.py +179 -0
  140. dsgrid/utils/filters.py +125 -0
  141. dsgrid/utils/id_remappings.py +100 -0
  142. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  143. dsgrid/utils/py_expression_eval/README.md +8 -0
  144. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  145. dsgrid/utils/py_expression_eval/tests.py +283 -0
  146. dsgrid/utils/run_command.py +70 -0
  147. dsgrid/utils/scratch_dir_context.py +65 -0
  148. dsgrid/utils/spark.py +918 -0
  149. dsgrid/utils/spark_partition.py +98 -0
  150. dsgrid/utils/timing.py +239 -0
  151. dsgrid/utils/utilities.py +221 -0
  152. dsgrid/utils/versioning.py +36 -0
  153. dsgrid_toolkit-0.3.3.dist-info/METADATA +193 -0
  154. dsgrid_toolkit-0.3.3.dist-info/RECORD +157 -0
  155. dsgrid_toolkit-0.3.3.dist-info/WHEEL +4 -0
  156. dsgrid_toolkit-0.3.3.dist-info/entry_points.txt +4 -0
  157. dsgrid_toolkit-0.3.3.dist-info/licenses/LICENSE +29 -0
dsgrid/query/models.py ADDED
@@ -0,0 +1,728 @@
import abc
import itertools
from enum import StrEnum
from typing import Any, Generator, Union, Literal, Self, TypeAlias

from pydantic import field_validator, model_validator, Field, field_serializer, ValidationInfo
from semver import VersionInfo
from typing_extensions import Annotated

from dsgrid.config.dimensions import DimensionReferenceModel
from dsgrid.config.project_config import DatasetBaseDimensionNamesModel
from dsgrid.data_models import DSGBaseModel, make_model_config
from dsgrid.dataset.models import (
    TableFormatModel,
    StackedTableFormatModel,
    ValueFormat,
)
from dsgrid.dimension.base_models import DimensionType
from dsgrid.dimension.dimension_filters import (
    DimensionFilterExpressionModel,
    DimensionFilterExpressionRawModel,
    DimensionFilterColumnOperatorModel,
    DimensionFilterBetweenColumnOperatorModel,
    SubsetDimensionFilterModel,
    SupplementalDimensionFilterColumnOperatorModel,
)
from dsgrid.dimension.time import TimeBasedDataAdjustmentModel
from dsgrid.query.dataset_mapping_plan import (
    DatasetMappingPlan,
)
from dsgrid.spark.types import F
from dsgrid.utils.files import compute_hash


DimensionFilters: TypeAlias = Annotated[
    Union[
        DimensionFilterExpressionModel,
        DimensionFilterExpressionRawModel,
        DimensionFilterColumnOperatorModel,
        DimensionFilterBetweenColumnOperatorModel,
        SubsetDimensionFilterModel,
        SupplementalDimensionFilterColumnOperatorModel,
    ],
    Field(discriminator="filter_type"),
]


class FilteredDatasetModel(DSGBaseModel):
    """Filters to apply to a dataset"""

    dataset_id: str = Field(description="Dataset ID")
    filters: list[DimensionFilters]


class ColumnModel(DSGBaseModel):
    """Defines one column in a SQL aggregation statement."""

    dimension_name: str
    function: Any = Field(
        default=None, description="Function or name of function in pyspark.sql.functions."
    )
    alias: str | None = Field(default=None, description="Name of the resulting column.")

    @field_validator("function")
    @classmethod
    def handle_function(cls, function_name):
        if function_name is None:
            return function_name
        if not isinstance(function_name, str):
            return function_name

        func = getattr(F, function_name, None)
        if func is None:
            msg = f"function={function_name} is not defined in pyspark.sql.functions"
            raise ValueError(msg)
        return func

    @field_validator("alias")
    @classmethod
    def handle_alias(cls, alias, info: ValidationInfo):
        if alias is not None:
            return alias
        func = info.data.get("function")
        if func is not None:
            name = info.data["dimension_name"]
            return f"{func.__name__}__{name}"

        return alias

    @field_serializer("function")
    def serialize_function(self, function, _):
        if function is not None:
            return function.__name__
        return function

    def get_column_name(self):
        if self.alias is not None:
            return self.alias
        if self.function is None:
            return self.dimension_name
        return f"{self.function.__name__}__{self.dimension_name}"


class ColumnType(StrEnum):
    """Defines what the columns of a dataset table represent."""

    DIMENSION_TYPES = "dimension_types"
    DIMENSION_NAMES = "dimension_names"


class DimensionNamesModel(DSGBaseModel):
    """Defines the list of dimensions to which the value columns should be aggregated.
    If a value is empty, that dimension will be aggregated and dropped from the table.
    """

    model_config = make_model_config(protected_namespaces=())

    geography: list[Union[str, ColumnModel]]
    metric: list[Union[str, ColumnModel]]
    model_year: list[Union[str, ColumnModel]]
    scenario: list[Union[str, ColumnModel]]
    sector: list[Union[str, ColumnModel]]
    subsector: list[Union[str, ColumnModel]]
    time: list[Union[str, ColumnModel]]
    weather_year: list[Union[str, ColumnModel]]

    @model_validator(mode="before")
    def fix_columns(cls, values):
        for dim_type in DimensionType:
            field = dim_type.value
            container = values[field]
            for i, item in enumerate(container):
                if isinstance(item, str):
                    container[i] = ColumnModel(dimension_name=item)
        return values


class AggregationModel(DSGBaseModel):
    """Aggregate on one or more dimensions."""

    aggregation_function: Any = Field(
        default=None,
        description="Must be a function name in pyspark.sql.functions",
    )
    dimensions: DimensionNamesModel = Field(description="Dimensions on which to aggregate")

    @field_validator("aggregation_function")
    @classmethod
    def check_aggregation_function(cls, aggregation_function):
        if isinstance(aggregation_function, str):
            aggregation_function = getattr(F, aggregation_function, None)
            if aggregation_function is None:
                msg = f"{aggregation_function} is not defined in pyspark.sql.functions"
                raise ValueError(msg)
        elif aggregation_function is None:
            msg = "aggregation_function cannot be None"
            raise ValueError(msg)
        return aggregation_function

    @field_validator("dimensions")
    @classmethod
    def check_for_metric(cls, dimensions):
        if not dimensions.metric:
            msg = "An AggregationModel must include the metric dimension."
            raise ValueError(msg)
        return dimensions

    @field_serializer("aggregation_function")
    def serialize_aggregation_function(self, function, _):
        return function.__name__

    def iter_dimensions_to_keep(self) -> Generator[tuple[DimensionType, ColumnModel], None, None]:
        """Yield the dimension type and ColumnModel for each dimension to keep."""
        for field in DimensionNamesModel.model_fields:
            for val in getattr(self.dimensions, field):
                yield DimensionType(field), val

    def list_dropped_dimensions(self) -> list[DimensionType]:
        """Return a list of dimension types that will be dropped by the aggregation."""
        return [
            DimensionType(x)
            for x in DimensionNamesModel.model_fields
            if not getattr(self.dimensions, x)
        ]


class ReportType(StrEnum):
    """Pre-defined reports"""

    PEAK_LOAD = "peak_load"


class ReportInputModel(DSGBaseModel):
    report_type: ReportType
    inputs: Any = None


class DimensionMetadataModel(DSGBaseModel):
    """Defines the columns in a table for a dimension."""

    dimension_name: str
    column_names: list[str] = Field(
        description="Columns associated with this dimension. Could be a dimension name, "
        "the string-ified DimensionType, multiple strings as can happen with time, or dimension "
        "record IDs if the dimension is pivoted."
    )

    def make_key(self):
        return "__".join([self.dimension_name] + self.column_names)


class DatasetDimensionsMetadataModel(DSGBaseModel):
    """Records the dimensions and columns of a dataset as it is transformed by a query."""

    model_config = make_model_config(protected_namespaces=())

    geography: list[DimensionMetadataModel] = []
    metric: list[DimensionMetadataModel] = []
    model_year: list[DimensionMetadataModel] = []
    scenario: list[DimensionMetadataModel] = []
    sector: list[DimensionMetadataModel] = []
    subsector: list[DimensionMetadataModel] = []
    time: list[DimensionMetadataModel] = []
    weather_year: list[DimensionMetadataModel] = []

    def add_metadata(
        self, dimension_type: DimensionType, metadata: DimensionMetadataModel
    ) -> None:
        """Add dimension metadata. Skip duplicates."""
        container = getattr(self, dimension_type.value)
        if metadata.make_key() not in {x.make_key() for x in container}:
            container.append(metadata)

    def get_metadata(self, dimension_type: DimensionType) -> list[DimensionMetadataModel]:
        """Return the dimension metadata."""
        return getattr(self, dimension_type.value)

    def replace_metadata(
        self, dimension_type: DimensionType, metadata: list[DimensionMetadataModel]
    ) -> None:
        """Replace the dimension metadata."""
        setattr(self, dimension_type.value, metadata)

    def get_column_names(self, dimension_type: DimensionType) -> set[str]:
        """Return the column names for the given dimension type."""
        column_names = set()
        for item in getattr(self, dimension_type.value):
            column_names.update(item.column_names)
        return column_names

    def get_dimension_names(self, dimension_type: DimensionType) -> set[str]:
        """Return the dimension names for the given dimension type."""
        return {x.dimension_name for x in getattr(self, dimension_type.value)}

    def remove_metadata(self, dimension_type: DimensionType, dimension_name: str) -> None:
        """Remove the dimension metadata for the given dimension name."""
        container = getattr(self, dimension_type.value)
        for i, metadata in enumerate(container):
            if metadata.dimension_name == dimension_name:
                container.pop(i)
                break


class DatasetMetadataModel(DSGBaseModel):
    """Defines the metadata for a dataset serialized to file."""

    dimensions: DatasetDimensionsMetadataModel = DatasetDimensionsMetadataModel()
    table_format: TableFormatModel
    # This will be set at the query context level but not per-dataset.
    base_dimension_names: DatasetBaseDimensionNamesModel = DatasetBaseDimensionNamesModel()

    def get_value_format(self) -> ValueFormat:
        """Return the value format of the table."""
        return ValueFormat(self.table_format.format_type)


class CacheableQueryBaseModel(DSGBaseModel):
    def serialize_with_hash(self, *args, **kwargs) -> tuple[str, str]:
        """Return a JSON representation of the model along with a hash that uniquely identifies it."""
        text = self.model_dump_json(indent=2)
        return compute_hash(text.encode()), text


class SparkConfByDataset(DSGBaseModel):
    """Defines a custom Spark configuration to use while running a query on a dataset."""

    dataset_id: str
    conf: dict[str, Any]


class ProjectQueryDatasetParamsModel(CacheableQueryBaseModel):
    """Parameters in a project query that only apply to datasets"""

    dimension_filters: list[DimensionFilters] = Field(
        description="Filters to apply to all datasets",
        default=[],
    )


class DatasetType(StrEnum):
    """Defines the type of a dataset in a query."""

    PROJECTION = "projection"
    STANDALONE = "standalone"
    DERIVED = "derived"


class DatasetConstructionMethod(StrEnum):
    """Defines the type of construction method for DatasetType.PROJECTION."""

    EXPONENTIAL_GROWTH = "exponential_growth"
    ANNUAL_MULTIPLIER = "annual_multiplier"


class DatasetBaseModel(DSGBaseModel, abc.ABC):
    @abc.abstractmethod
    def get_dataset_id(self) -> str:
        """Return the primary dataset ID.

        Returns
        -------
        str
        """

    @abc.abstractmethod
    def list_source_dataset_ids(self) -> list[str]:
        """Return a list of all source dataset IDs."""


class StandaloneDatasetModel(DatasetBaseModel):
    """A dataset with energy use data."""

    dataset_type: Literal[DatasetType.STANDALONE] = Field(default=DatasetType.STANDALONE)
    dataset_id: str = Field(description="Dataset identifier")

    def get_dataset_id(self) -> str:
        return self.dataset_id

    def list_source_dataset_ids(self) -> list[str]:
        return [self.dataset_id]


class ProjectionDatasetModel(DatasetBaseModel):
    """A dataset with growth rates that can be applied to a standalone dataset."""

    dataset_type: Literal[DatasetType.PROJECTION] = Field(default=DatasetType.PROJECTION)
    dataset_id: str = Field(description="Identifier for the resulting dataset")
    initial_value_dataset_id: str = Field(description="Principal dataset identifier")
    growth_rate_dataset_id: str = Field(
        description="Growth rate dataset identifier to apply to the principal dataset"
    )
    construction_method: DatasetConstructionMethod = Field(
        default=DatasetConstructionMethod.EXPONENTIAL_GROWTH,
        description="Specifier for the code that applies the growth rate to the principal dataset",
    )
    base_year: int | None = Field(
        description="Base year of the dataset to use in growth rate application. Must be a year "
        "defined in the principal dataset's model year dimension. If None, there must be only "
        "one model year in that dimension and it will be used.",
        default=None,
    )

    def get_dataset_id(self) -> str:
        return self.initial_value_dataset_id

    def list_source_dataset_ids(self) -> list[str]:
        return [self.initial_value_dataset_id, self.growth_rate_dataset_id]


AbstractDatasetModel = Annotated[
    Union[StandaloneDatasetModel, ProjectionDatasetModel], Field(discriminator="dataset_type")
]


class DatasetModel(DSGBaseModel):
    """Specifies the datasets to use in a project query."""

    dataset_id: str = Field(description="Identifier for the resulting dataset")
    source_datasets: list[AbstractDatasetModel] = Field(
        description="Datasets from which to read. Each must be of type DatasetBaseModel.",
    )
    expression: str | None = Field(
        description="Expression to combine datasets. Default is to take a union of all datasets.",
        default=None,
    )
    params: ProjectQueryDatasetParamsModel = Field(
        description="Parameters affecting datasets. Used for caching intermediate tables.",
        default=ProjectQueryDatasetParamsModel(),
    )

    @field_validator("expression")
    @classmethod
    def handle_expression(cls, expression, info: ValidationInfo):
        if "source_datasets" not in info.data:
            return expression

        if expression is None:
            expression = " | ".join((x.dataset_id for x in info.data["source_datasets"]))
        return expression


class ProjectQueryParamsModel(CacheableQueryBaseModel):
    """Defines how to transform a project into a CompositeDataset"""

    project_id: str = Field(description="Project ID for query")
    dataset: DatasetModel = Field(description="Definition of the dataset to create.")
    excluded_dataset_ids: list[str] = Field(
        description="Datasets to exclude from query", default=[]
    )
    # TODO #203: default needs to change
    include_dsgrid_dataset_components: bool = Field(description="", default=False)
    version: str | None = Field(
        default=None,
        description="Version of project or dataset on which the query is based. "
        "Should not be set by the user",
    )
    mapping_plans: list[DatasetMappingPlan] = Field(
        default=[],
        description="Defines the order in which to map the dimensions of datasets.",
    )
    spark_conf_per_dataset: list[SparkConfByDataset] = Field(
        description="Apply these Spark configuration settings while a dataset is being processed.",
        default=[],
    )

    @model_validator(mode="before")
    @classmethod
    def check_unsupported_fields(cls, values):
        if values.get("include_dsgrid_dataset_components", False):
            msg = "Setting include_dsgrid_dataset_components=true is not supported yet"
            raise ValueError(msg)
        if values.get("drop_dimensions", []):
            msg = "drop_dimensions is not supported yet"
            raise ValueError(msg)
        if values.get("excluded_dataset_ids", []):
            msg = "excluded_dataset_ids is not supported yet"
            raise ValueError(msg)
        return values

    @model_validator(mode="after")
    def check_invalid_dataset_ids(self) -> Self:
        source_dataset_ids: set[str] = set()
        for src_dataset in self.dataset.source_datasets:
            source_dataset_ids.update(src_dataset.list_source_dataset_ids())
        for item in itertools.chain(self.mapping_plans, self.spark_conf_per_dataset):
            if item.dataset_id not in source_dataset_ids:
                msg = f"Dataset {item.dataset_id} is not a source dataset"
                raise ValueError(msg)
        return self

    @field_validator("mapping_plans", "spark_conf_per_dataset")
    @classmethod
    def check_duplicate_dataset_ids(cls, value: list) -> list:
        dataset_ids: set[str] = set()
        for item in value:
            if item.dataset_id in dataset_ids:
                msg = f"{item.dataset_id} is stored multiple times"
                raise ValueError(msg)
            dataset_ids.add(item.dataset_id)
        return value

    def set_dataset_mapper(self, new_mapper: DatasetMappingPlan) -> None:
        for i, mapper in enumerate(self.mapping_plans):
            if mapper.dataset_id == new_mapper.dataset_id:
                self.mapping_plans[i] = new_mapper
                return
        self.mapping_plans.append(new_mapper)

    def get_dataset_mapping_plan(self, dataset_id: str) -> DatasetMappingPlan | None:
        """Return the mapping plan for this dataset_id or None if the user did not
        specify one.
        """
        for mapper in self.mapping_plans:
            if dataset_id == mapper.dataset_id:
                return mapper
        return None

    def get_spark_conf(self, dataset_id: str) -> dict[str, Any]:
        """Return the Spark settings to apply while processing dataset_id."""
        for dataset in self.spark_conf_per_dataset:
            if dataset.dataset_id == dataset_id:
                return dataset.conf
        return {}


QUERY_FORMAT_VERSION = VersionInfo.parse("0.1.0")


class QueryResultParamsModel(CacheableQueryBaseModel):
    """Controls post-processing and storage of CompositeDatasets"""

    replace_ids_with_names: bool = Field(
        description="Replace dimension record IDs with their names in result tables.",
        default=False,
    )
    aggregations: list[AggregationModel] = Field(
        description="Defines how to aggregate dimensions",
        default=[],
    )
    aggregate_each_dataset: bool = Field(
        description="If True, aggregate each dataset before applying the expression to create one "
        "overall dataset. This parameter must be set to True for queries that will be adding or "
        "subtracting datasets with different dimensionality. Defaults to False, which performs "
        "one aggregation on the overall dataset. WARNING: "
        "For a standard query that performs a union of datasets, setting this value to True could "
        "produce rows with duplicate dimension combinations, especially if one or more "
        "dimensions are also dropped.",
        default=False,
    )
    reports: list[ReportInputModel] = Field(
        description="Run these pre-defined reports on the result.", default=[]
    )
    column_type: ColumnType = Field(
        description="Whether to make the result table columns dimension types. Default behavior "
        "is to use dimension names. In order to register a result table as a derived "
        f"dataset, this must be set to {ColumnType.DIMENSION_TYPES.value}.",
        default=ColumnType.DIMENSION_NAMES,
    )
    table_format: TableFormatModel = StackedTableFormatModel()
    output_format: str = Field(description="Output file format: csv or parquet", default="parquet")
    sort_columns: list[str] = Field(
        description="Sort the results by these dimension names.",
        default=[],
    )
    dimension_filters: list[DimensionFilters] = Field(
        description="Filters to apply to the result. Must contain columns in the result.",
        default=[],
    )
    # TODO #205: implement
    time_zone: str | Literal["geography"] | None = Field(
        description="Convert the results to this time zone. If 'geography', use the time zone "
        "of the geography dimension. The resulting time column will be time zone-naive with "
        "time zone recorded in a separate column.",
        default=None,
    )

    @model_validator(mode="after")
    def check_pivot_dimension_type(self) -> "QueryResultParamsModel":
        if self.table_format.format_type == ValueFormat.PIVOTED:
            pivoted_dim_type = self.table_format.pivoted_dimension_type
            for agg in self.aggregations:
                names = getattr(agg.dimensions, pivoted_dim_type.value)
                num_names = len(names)
                if num_names == 0:
                    msg = (
                        f"The pivoted dimension ({pivoted_dim_type}) "
                        "must be specified in all aggregations."
                    )
                    raise ValueError(msg)
                elif len(names) > 1:
                    msg = (
                        f"The pivoted dimension ({pivoted_dim_type}) "
                        f"cannot have more than one dimension name: {names}"
                    )
                    raise ValueError(msg)
        return self

    @field_validator("output_format")
    @classmethod
    def check_format(cls, fmt):
        allowed = {"csv", "parquet"}
        if fmt not in allowed:
            msg = f"output_format={fmt} is not supported. Allowed={allowed}"
            raise ValueError(msg)
        return fmt

    @model_validator(mode="after")
    def check_column_type(self) -> "QueryResultParamsModel":
        if self.column_type == ColumnType.DIMENSION_TYPES:
            for agg in self.aggregations:
                for dim_type in DimensionType:
                    columns = getattr(agg.dimensions, dim_type.value)
                    if len(columns) > 1:
                        msg = f"Multiple columns are incompatible with {self.column_type=}. {columns=}"
                        raise ValueError(msg)
        return self


class QueryBaseModel(CacheableQueryBaseModel, abc.ABC):
    """Base class for all queries"""

    name: str = Field(description="Name of query")
    # TODO #204: This field is not being used. Wait until development slows down.
    version: str = Field(
        description="Version of the query structure. Changes to the major or minor version invalidate cached tables.",
        default=str(QUERY_FORMAT_VERSION),  # TODO: str shouldn't be required
    )
    result: QueryResultParamsModel = Field(
        default=QueryResultParamsModel(),
        description="Controls the output results",
    )

    def serialize_cached_content(self) -> dict[str, Any]:
        """Return a JSON-able representation of the model that can be used for caching purposes."""
        return self.model_dump(mode="json", exclude={"name"})


class ProjectQueryModel(QueryBaseModel):
    """Represents a user query on a Project."""

    project: ProjectQueryParamsModel = Field(
        description="Defines the datasets to use and how to transform them.",
    )

    def serialize_cached_content(self) -> dict[str, Any]:
        # Exclude all result-oriented fields in order to facilitate re-using queries.
        exclude = {
            "spark_conf_per_dataset",  # Doesn't change the query.
            "version",  # We use the project major version as a separate field.
        }
        return self.project.model_dump(mode="json", exclude=exclude)


class DatasetQueryModel(QueryBaseModel):
    """Defines how to transform a dataset"""

    dataset_id: str = Field(description="Dataset ID for query")
    to_dimension_references: list[DimensionReferenceModel] = Field(
        description="Map the dataset to these dimensions. Mappings must exist in the registry. "
        "There cannot be duplicate mappings."
    )
    mapping_plan: DatasetMappingPlan | None = Field(
        default=None,
        description="Defines the order in which to map the dimensions of the dataset.",
    )
    time_based_data_adjustment: TimeBasedDataAdjustmentModel = Field(
        description="Defines how the rest of the dataframe is adjusted with respect to time. "
        "E.g., whether to drop associated data when dropping a leap day timestamp.",
        default=TimeBasedDataAdjustmentModel(),
    )
    wrap_time_allowed: bool = Field(
        default=False,
        description="Whether to allow dataset time to be wrapped to the destination time "
        "dimension, if different.",
    )
    result: QueryResultParamsModel = Field(
        default=QueryResultParamsModel(),
        description="Controls the output results",
    )


def make_dataset_query(
    name: str,
    dataset_id: str,
    to_dimension_references: list[DimensionReferenceModel],
    plan: DatasetMappingPlan | None = None,
) -> DatasetQueryModel:
    """Create a query to map a dataset to alternate dimensions.

    Parameters
    ----------
    name: str
    dataset_id: str
    to_dimension_references: list[DimensionReferenceModel]
    plan: DatasetMappingPlan | None
        Optional plan to control the mapping operation.
    """
    return DatasetQueryModel(
        name=name,
        dataset_id=dataset_id,
        to_dimension_references=to_dimension_references,
        mapping_plan=plan,
    )


def make_query_for_standalone_dataset(
    project_id: str,
    dataset_id: str,
    plan: DatasetMappingPlan | None = None,
    column_type: ColumnType = ColumnType.DIMENSION_NAMES,
) -> ProjectQueryModel:
    """Create a query to map a standalone dataset to a project's dimensions.

    Parameters
    ----------
    project_id: str
    dataset_id: str
    plan: DatasetMappingPlan | None
        Optional plan to control the mapping operation.
    column_type: ColumnType
        The type of columns in the result table. Default is ColumnType.DIMENSION_NAMES.
    """
    plans: list[DatasetMappingPlan] = []
    if plan is not None:
        plans.append(plan)
    return ProjectQueryModel(
        name=dataset_id,
        project=ProjectQueryParamsModel(
            project_id=project_id,
            dataset=DatasetModel(
                dataset_id=dataset_id,
                source_datasets=[StandaloneDatasetModel(dataset_id=dataset_id)],
            ),
            mapping_plans=plans,
        ),
        result=QueryResultParamsModel(
            column_type=column_type,
        ),
    )


class CreateCompositeDatasetQueryModel(QueryBaseModel):
    """Represents a user query to create a Result Dataset. This dataset requires a Project
    in order to retrieve dimension records and dimension mapping records.
    """

    dataset_id: str = Field(description="Composite Dataset ID for query")
    project: ProjectQueryParamsModel = Field(
        description="Defines the datasets to use and how to transform them."
    )
    result: QueryResultParamsModel = Field(
        description="Controls the output results",
        default=QueryResultParamsModel(),
    )

    def serialize_cached_content(self) -> dict[str, Any]:
        # Exclude all result-oriented fields in order to facilitate re-using queries.
        return self.project.model_dump(mode="json", exclude={"spark_conf_per_dataset"})


class CompositeDatasetQueryModel(QueryBaseModel):
    """Represents a user query on a dataset."""

    dataset_id: str = Field(description="Aggregated Dataset ID for query")
    result: QueryResultParamsModel = Field(
        description="Controls the output results", default=QueryResultParamsModel()
    )