dsgrid-toolkit 0.2.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.

Potentially problematic release.


This version of dsgrid-toolkit might be problematic.

Files changed (152)
  1. dsgrid/__init__.py +22 -0
  2. dsgrid/api/__init__.py +0 -0
  3. dsgrid/api/api_manager.py +179 -0
  4. dsgrid/api/app.py +420 -0
  5. dsgrid/api/models.py +60 -0
  6. dsgrid/api/response_models.py +116 -0
  7. dsgrid/apps/__init__.py +0 -0
  8. dsgrid/apps/project_viewer/app.py +216 -0
  9. dsgrid/apps/registration_gui.py +444 -0
  10. dsgrid/chronify.py +22 -0
  11. dsgrid/cli/__init__.py +0 -0
  12. dsgrid/cli/common.py +120 -0
  13. dsgrid/cli/config.py +177 -0
  14. dsgrid/cli/download.py +13 -0
  15. dsgrid/cli/dsgrid.py +142 -0
  16. dsgrid/cli/dsgrid_admin.py +349 -0
  17. dsgrid/cli/install_notebooks.py +62 -0
  18. dsgrid/cli/query.py +711 -0
  19. dsgrid/cli/registry.py +1773 -0
  20. dsgrid/cloud/__init__.py +0 -0
  21. dsgrid/cloud/cloud_storage_interface.py +140 -0
  22. dsgrid/cloud/factory.py +31 -0
  23. dsgrid/cloud/fake_storage_interface.py +37 -0
  24. dsgrid/cloud/s3_storage_interface.py +156 -0
  25. dsgrid/common.py +35 -0
  26. dsgrid/config/__init__.py +0 -0
  27. dsgrid/config/annual_time_dimension_config.py +187 -0
  28. dsgrid/config/common.py +131 -0
  29. dsgrid/config/config_base.py +148 -0
  30. dsgrid/config/dataset_config.py +684 -0
  31. dsgrid/config/dataset_schema_handler_factory.py +41 -0
  32. dsgrid/config/date_time_dimension_config.py +108 -0
  33. dsgrid/config/dimension_config.py +54 -0
  34. dsgrid/config/dimension_config_factory.py +65 -0
  35. dsgrid/config/dimension_mapping_base.py +349 -0
  36. dsgrid/config/dimension_mappings_config.py +48 -0
  37. dsgrid/config/dimensions.py +775 -0
  38. dsgrid/config/dimensions_config.py +71 -0
  39. dsgrid/config/index_time_dimension_config.py +76 -0
  40. dsgrid/config/input_dataset_requirements.py +31 -0
  41. dsgrid/config/mapping_tables.py +209 -0
  42. dsgrid/config/noop_time_dimension_config.py +42 -0
  43. dsgrid/config/project_config.py +1457 -0
  44. dsgrid/config/registration_models.py +199 -0
  45. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  46. dsgrid/config/simple_models.py +49 -0
  47. dsgrid/config/supplemental_dimension.py +29 -0
  48. dsgrid/config/time_dimension_base_config.py +200 -0
  49. dsgrid/data_models.py +155 -0
  50. dsgrid/dataset/__init__.py +0 -0
  51. dsgrid/dataset/dataset.py +123 -0
  52. dsgrid/dataset/dataset_expression_handler.py +86 -0
  53. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  54. dsgrid/dataset/dataset_schema_handler_base.py +899 -0
  55. dsgrid/dataset/dataset_schema_handler_one_table.py +196 -0
  56. dsgrid/dataset/dataset_schema_handler_standard.py +303 -0
  57. dsgrid/dataset/growth_rates.py +162 -0
  58. dsgrid/dataset/models.py +44 -0
  59. dsgrid/dataset/table_format_handler_base.py +257 -0
  60. dsgrid/dataset/table_format_handler_factory.py +17 -0
  61. dsgrid/dataset/unpivoted_table.py +121 -0
  62. dsgrid/dimension/__init__.py +0 -0
  63. dsgrid/dimension/base_models.py +218 -0
  64. dsgrid/dimension/dimension_filters.py +308 -0
  65. dsgrid/dimension/standard.py +213 -0
  66. dsgrid/dimension/time.py +531 -0
  67. dsgrid/dimension/time_utils.py +88 -0
  68. dsgrid/dsgrid_rc.py +88 -0
  69. dsgrid/exceptions.py +105 -0
  70. dsgrid/filesystem/__init__.py +0 -0
  71. dsgrid/filesystem/cloud_filesystem.py +32 -0
  72. dsgrid/filesystem/factory.py +32 -0
  73. dsgrid/filesystem/filesystem_interface.py +136 -0
  74. dsgrid/filesystem/local_filesystem.py +74 -0
  75. dsgrid/filesystem/s3_filesystem.py +118 -0
  76. dsgrid/loggers.py +132 -0
  77. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +950 -0
  78. dsgrid/notebooks/registration.ipynb +48 -0
  79. dsgrid/notebooks/start_notebook.sh +11 -0
  80. dsgrid/project.py +451 -0
  81. dsgrid/query/__init__.py +0 -0
  82. dsgrid/query/dataset_mapping_plan.py +142 -0
  83. dsgrid/query/derived_dataset.py +384 -0
  84. dsgrid/query/models.py +726 -0
  85. dsgrid/query/query_context.py +287 -0
  86. dsgrid/query/query_submitter.py +847 -0
  87. dsgrid/query/report_factory.py +19 -0
  88. dsgrid/query/report_peak_load.py +70 -0
  89. dsgrid/query/reports_base.py +20 -0
  90. dsgrid/registry/__init__.py +0 -0
  91. dsgrid/registry/bulk_register.py +161 -0
  92. dsgrid/registry/common.py +287 -0
  93. dsgrid/registry/config_update_checker_base.py +63 -0
  94. dsgrid/registry/data_store_factory.py +34 -0
  95. dsgrid/registry/data_store_interface.py +69 -0
  96. dsgrid/registry/dataset_config_generator.py +156 -0
  97. dsgrid/registry/dataset_registry_manager.py +734 -0
  98. dsgrid/registry/dataset_update_checker.py +16 -0
  99. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  100. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  101. dsgrid/registry/dimension_registry_manager.py +413 -0
  102. dsgrid/registry/dimension_update_checker.py +16 -0
  103. dsgrid/registry/duckdb_data_store.py +185 -0
  104. dsgrid/registry/filesystem_data_store.py +141 -0
  105. dsgrid/registry/filter_registry_manager.py +123 -0
  106. dsgrid/registry/project_config_generator.py +57 -0
  107. dsgrid/registry/project_registry_manager.py +1616 -0
  108. dsgrid/registry/project_update_checker.py +48 -0
  109. dsgrid/registry/registration_context.py +223 -0
  110. dsgrid/registry/registry_auto_updater.py +316 -0
  111. dsgrid/registry/registry_database.py +662 -0
  112. dsgrid/registry/registry_interface.py +446 -0
  113. dsgrid/registry/registry_manager.py +544 -0
  114. dsgrid/registry/registry_manager_base.py +367 -0
  115. dsgrid/registry/versioning.py +92 -0
  116. dsgrid/spark/__init__.py +0 -0
  117. dsgrid/spark/functions.py +545 -0
  118. dsgrid/spark/types.py +50 -0
  119. dsgrid/tests/__init__.py +0 -0
  120. dsgrid/tests/common.py +139 -0
  121. dsgrid/tests/make_us_data_registry.py +204 -0
  122. dsgrid/tests/register_derived_datasets.py +103 -0
  123. dsgrid/tests/utils.py +25 -0
  124. dsgrid/time/__init__.py +0 -0
  125. dsgrid/time/time_conversions.py +80 -0
  126. dsgrid/time/types.py +67 -0
  127. dsgrid/units/__init__.py +0 -0
  128. dsgrid/units/constants.py +113 -0
  129. dsgrid/units/convert.py +71 -0
  130. dsgrid/units/energy.py +145 -0
  131. dsgrid/units/power.py +87 -0
  132. dsgrid/utils/__init__.py +0 -0
  133. dsgrid/utils/dataset.py +612 -0
  134. dsgrid/utils/files.py +179 -0
  135. dsgrid/utils/filters.py +125 -0
  136. dsgrid/utils/id_remappings.py +100 -0
  137. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  138. dsgrid/utils/py_expression_eval/README.md +8 -0
  139. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  140. dsgrid/utils/py_expression_eval/tests.py +283 -0
  141. dsgrid/utils/run_command.py +70 -0
  142. dsgrid/utils/scratch_dir_context.py +64 -0
  143. dsgrid/utils/spark.py +918 -0
  144. dsgrid/utils/spark_partition.py +98 -0
  145. dsgrid/utils/timing.py +239 -0
  146. dsgrid/utils/utilities.py +184 -0
  147. dsgrid/utils/versioning.py +36 -0
  148. dsgrid_toolkit-0.2.0.dist-info/METADATA +216 -0
  149. dsgrid_toolkit-0.2.0.dist-info/RECORD +152 -0
  150. dsgrid_toolkit-0.2.0.dist-info/WHEEL +4 -0
  151. dsgrid_toolkit-0.2.0.dist-info/entry_points.txt +4 -0
  152. dsgrid_toolkit-0.2.0.dist-info/licenses/LICENSE +29 -0
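The hunk below appears to correspond to dsgrid/config/dataset_config.py (entry 30 above, +684 lines).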
@@ -0,0 +1,684 @@
+ from dsgrid.dataset.models import UnpivotedTableFormatModel
+ import logging
+ from enum import Enum
+ from pathlib import Path
+ from typing import Any, Literal, Union
+
+ from pydantic import field_validator, model_validator, Field
+
+ from dsgrid.common import SCALING_FACTOR_COLUMN, VALUE_COLUMN
+ from dsgrid.config.common import make_base_dimension_template
+ from dsgrid.config.dimension_config import (
+     DimensionBaseConfig,
+     DimensionBaseConfigWithFiles,
+ )
+ from dsgrid.config.time_dimension_base_config import TimeDimensionBaseConfig
+ from dsgrid.dataset.models import (
+     PivotedTableFormatModel,
+     TableFormatModel,
+     TableFormatType,
+ )
+ from dsgrid.dimension.base_models import DimensionType, check_timezone_in_geography
+ from dsgrid.dimension.time import TimeDimensionType
+ from dsgrid.exceptions import DSGInvalidParameter
+ from dsgrid.registry.common import check_config_id_strict
+ from dsgrid.data_models import DSGBaseDatabaseModel, DSGBaseModel, DSGEnum, EnumValue
+ from dsgrid.exceptions import DSGInvalidDimension
+ from dsgrid.spark.types import (
+     DataFrame,
+     F,
+ )
+ from dsgrid.utils.spark import get_unique_values, read_dataframe
+ from dsgrid.utils.utilities import check_uniqueness
+ from .config_base import ConfigBase
+ from .dimensions import (
+     DimensionsListModel,
+     DimensionReferenceModel,
+     DimensionModel,
+     TimeDimensionBaseModel,
+ )
+
+
+ # Note that there is special handling for S3 at use sites.
+ ALLOWED_LOAD_DATA_FILENAMES = ("load_data.parquet", "load_data.csv", "table.parquet")
+ ALLOWED_LOAD_DATA_LOOKUP_FILENAMES = (
+     "load_data_lookup.parquet",
+     "lookup_table.parquet",
+     # The next two are only used for test data.
+     "load_data_lookup.csv",
+     "load_data_lookup.json",
+ )
+ ALLOWED_DATA_FILES = ALLOWED_LOAD_DATA_FILENAMES + ALLOWED_LOAD_DATA_LOOKUP_FILENAMES
+ ALLOWED_MISSING_DIMENSION_ASSOCATIONS_FILENAMES = (
+     "missing_associations.csv",
+     "missing_associations.parquet",
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ def check_load_data_filename(path: str | Path) -> Path:
+     """Return the load_data filename in path. Supports Parquet and CSV.
+
+     Parameters
+     ----------
+     path : str | Path
+
+     Returns
+     -------
+     Path
+
+     Raises
+     ------
+     ValueError
+         Raised if no supported load data filename exists.
+
+     """
+     path_ = path if isinstance(path, Path) else Path(path)
+     if str(path_).startswith("s3://"):
+         # Only Parquet is supported on AWS.
+         return path_ / "load_data.parquet"
+
+     for allowed_name in ALLOWED_LOAD_DATA_FILENAMES:
+         filename = path_ / allowed_name
+         if filename.exists():
+             return filename
+
+     # Use ValueError because this gets called in Pydantic model validation.
+     msg = f"no load_data file exists in {path_}"
+     raise ValueError(msg)
+
+
+ def check_load_data_lookup_filename(path: str | Path) -> Path:
+     """Return the load_data_lookup filename in path. Supports Parquet, CSV, and JSON.
+
+     Parameters
+     ----------
+     path : str | Path
+
+     Returns
+     -------
+     Path
+
+     Raises
+     ------
+     ValueError
+         Raised if no supported load data lookup filename exists.
+
+     """
+     path_ = path if isinstance(path, Path) else Path(path)
+     if str(path_).startswith("s3://"):
+         # Only Parquet is supported on AWS.
+         return path_ / "load_data_lookup.parquet"
+
+     for allowed_name in ALLOWED_LOAD_DATA_LOOKUP_FILENAMES:
+         filename = path_ / allowed_name
+         if filename.exists():
+             return filename
+
+     # Use ValueError because this gets called in Pydantic model validation.
+     msg = f"no load_data_lookup file exists in {path_}"
+     raise ValueError(msg)
+
+
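As a rough illustration of how these two helpers are used (the data directory below is hypothetical; only the filenames listed in the ALLOWED_* tuples above are recognized):

from pathlib import Path

from dsgrid.config.dataset_config import (
    check_load_data_filename,
    check_load_data_lookup_filename,
)

data_dir = Path("/data/my_dataset")  # hypothetical directory containing the dataset files
load_data_file = check_load_data_filename(data_dir)      # e.g., /data/my_dataset/load_data.parquet
lookup_file = check_load_data_lookup_filename(data_dir)  # e.g., /data/my_dataset/load_data_lookup.parquet
# Each helper raises ValueError if none of its allowed filenames exists in the directory.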
+ class InputDatasetType(DSGEnum):
+     MODELED = "modeled"
+     HISTORICAL = "historical"
+     BENCHMARK = "benchmark"
+
+
+ class DataSchemaType(str, Enum):
+     """Data schema types."""
+
+     STANDARD = "standard"
+     ONE_TABLE = "one_table"
+
+
+ class DSGDatasetParquetType(DSGEnum):
+     """Dataset parquet types."""
+
+     LOAD_DATA = EnumValue(
+         value="load_data",
+         description="""
+         In STANDARD data_schema_type, load_data is a file with ID, timestamp, and metric value columns.
+         In ONE_TABLE data_schema_type, load_data is a file with multiple data dimension and metric value columns.
+         """,
+     )
+     LOAD_DATA_LOOKUP = EnumValue(
+         value="load_data_lookup",
+         description="""
+         load_data_lookup is a file with multiple data dimension columns and an ID column that maps to load_data file.
+         """,
+     )
+     # # These are not currently supported by dsgrid but may be needed in the near future
+     # DATASET_DIMENSION_MAPPING = EnumValue(
+     #     value="dataset_dimension_mapping",
+     #     description="",
+     # )  # optional
+     # PROJECT_DIMENSION_MAPPING = EnumValue(
+     #     value="project_dimension_mapping",
+     #     description="",
+     # )  # optional
+
+
+ class DataClassificationType(DSGEnum):
+     """Data risk classification type.
+
+     See https://uit.stanford.edu/guide/riskclassifications for more information.
+     """
+
+     # TODO: can we get NREL/DOE definitions for these instead of Stanford's?
+
+     LOW = EnumValue(
+         value="low",
+         description="Low risk data that does not require special data management",
+     )
+     MODERATE = EnumValue(
+         value="moderate",
+         description=(
+             "The moderate class includes all data under an NDA, data classified as business sensitive, "
+             "data classified as Critical Energy Infrastructure Information (CEII), or data with Personally Identifiable Information (PII)."
+         ),
+     )
+
+
+ class StandardDataSchemaModel(DSGBaseModel):
+     data_schema_type: Literal[DataSchemaType.STANDARD]
+     table_format: TableFormatModel
+
+     @model_validator(mode="before")
+     @classmethod
+     def handle_legacy(cls, values: dict) -> dict:
+         if "load_data_column_dimension" in values:
+             values["table_format"] = PivotedTableFormatModel(
+                 pivoted_dimension_type=values.pop("load_data_column_dimension")
+             )
+         return values
+
+
+ class OneTableDataSchemaModel(DSGBaseModel):
+     data_schema_type: Literal[DataSchemaType.ONE_TABLE]
+     table_format: TableFormatModel
+
+     @model_validator(mode="before")
+     @classmethod
+     def handle_legacy(cls, values: dict) -> dict:
+         if "load_data_column_dimension" in values:
+             values["table_format"] = PivotedTableFormatModel(
+                 pivoted_dimension_type=values.pop("load_data_column_dimension")
+             )
+         return values
+
+
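A minimal sketch of what the handle_legacy validators above do; the "metric" value and the pivoted result are assumptions based on the models imported at the top of the file:

from dsgrid.config.dataset_config import DataSchemaType, StandardDataSchemaModel

# Older configs specified the pivoted dimension directly instead of a table_format block.
legacy = {
    "data_schema_type": DataSchemaType.STANDARD,
    "load_data_column_dimension": "metric",  # hypothetical legacy value
}
schema = StandardDataSchemaModel.model_validate(legacy)
# handle_legacy() pops the legacy key and wraps it in a PivotedTableFormatModel,
# so schema.table_format now describes a pivoted table.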
+ class DatasetQualifierType(str, Enum):
+     QUANTITY = "quantity"
+     GROWTH_RATE = "growth_rate"
+
+
+ class GrowthRateType(str, Enum):
+     EXPONENTIAL_ANNUAL = "exponential_annual"
+     EXPONENTIAL_MONTHLY = "exponential_monthly"
+
+
+ class QuantityModel(DSGBaseModel):
+     dataset_qualifier_type: Literal[DatasetQualifierType.QUANTITY]
+
+
+ class GrowthRateModel(DSGBaseModel):
+     dataset_qualifier_type: Literal[DatasetQualifierType.GROWTH_RATE]
+     growth_rate_type: GrowthRateType = Field(
+         title="growth_rate_type",
+         description="Type of growth rates, e.g., exponential_annual",
+     )
+
+
+ class DatasetConfigModel(DSGBaseDatabaseModel):
+     """Represents dataset configurations."""
+
+     dataset_id: str = Field(
+         title="dataset_id",
+         description="Unique dataset identifier.",
+     )
+     dataset_type: InputDatasetType = Field(
+         title="dataset_type",
+         description="Input dataset type.",
+         json_schema_extra={
+             "options": InputDatasetType.format_for_docs(),
+         },
+     )
+     dataset_qualifier_metadata: Union[QuantityModel, GrowthRateModel] = Field(
+         default=QuantityModel(dataset_qualifier_type=DatasetQualifierType.QUANTITY),
+         title="dataset_qualifier_metadata",
+         description="Additional metadata to include related to the dataset_qualifier",
+         discriminator="dataset_qualifier_type",
+     )
+     sector_description: str | None = Field(
+         default=None,
+         title="sector_description",
+         description="Sectoral description (e.g., residential, commercial, industrial, "
+         "transportation, electricity)",
+     )
+     data_source: str = Field(
+         title="data_source",
+         description="Data source name, e.g. 'ComStock'.",
+         # TODO: it would be nice to extend the description here with a CLI example of how to list the project's data source IDs.
+     )
+     data_schema: Union[StandardDataSchemaModel, OneTableDataSchemaModel] = Field(
+         title="data_schema",
+         description="Schema (table layouts) used for writing out the dataset",
+         discriminator="data_schema_type",
+     )
+     description: str = Field(
+         title="description",
+         description="A detailed description of the dataset.",
+     )
+     origin_creator: str = Field(
+         title="origin_creator",
+         description="Origin data creator's name (first and last)",
+     )
+     origin_organization: str = Field(
+         title="origin_organization",
+         description="Origin organization name, e.g., NREL",
+     )
+     origin_contributors: list[str] = Field(
+         title="origin_contributors",
+         description="List of origin data contributor's first and last names"
+         """ e.g., ["Harry Potter", "Ronald Weasley"]""",
+         default=[],
+     )
+     origin_project: str = Field(
+         title="origin_project",
+         description="Origin project name",
+     )
+     origin_date: str = Field(
+         title="origin_date",
+         description="Date the source data was generated",
+     )
+     origin_version: str = Field(
+         title="origin_version",
+         description="Version of the origin data",
+     )
+     source: str = Field(
+         title="source",
+         description="Source of the data (text description or link)",
+     )
+     data_classification: DataClassificationType = Field(
+         title="data_classification",
+         description="Data security classification (e.g., low, moderate, high)",
+         json_schema_extra={
+             "options": DataClassificationType.format_for_docs(),
+         },
+     )
+     tags: list[str] = Field(
+         title="tags",
+         description="List of data tags",
+         default=[],
+     )
+     enable_unit_conversion: bool = Field(
+         default=True,
+         description="If the dataset uses its dimension mapping for the metric dimension to also "
+         "perform unit conversion, then this value should be false.",
+     )
+     # This field must be listed before dimensions.
+     use_project_geography_time_zone: bool = Field(
+         default=False,
+         description="If true, time zones will be applied from the project's geography dimension. "
+         "If false, the dataset's geography dimension records must provide a time zone column.",
+     )
+     dimensions: DimensionsListModel = Field(
+         title="dimensions",
+         description="List of dimensions that make up the dimensions of dataset. They will be "
+         "automatically registered during dataset registration and then converted "
+         "to dimension_references.",
+         default=[],
+     )
335
+ dimension_references: list[DimensionReferenceModel] = Field(
336
+ title="dimensions",
337
+ description="List of registered dimension references that make up the dimensions of dataset.",
338
+ default=[],
339
+ # TODO: Add to notes - link to registering dimensions page
340
+ # TODO: Add to notes - link to example of how to list dimensions to find existing registered dimensions
341
+ )
342
+ user_defined_metadata: dict[str, Any] = Field(
343
+ title="user_defined_metadata",
344
+ description="Additional user defined metadata fields",
345
+ default={},
346
+ )
347
+ trivial_dimensions: list[DimensionType] = Field(
348
+ title="trivial_dimensions",
349
+ default=[],
350
+ description="List of trivial dimensions (i.e., 1-element dimensions) that "
351
+ "do not exist in the load_data_lookup. List the dimensions by dimension type. "
352
+ "Trivial dimensions are 1-element dimensions that are not present in the parquet data "
353
+ "columns. Instead they are added by dsgrid as an alias column.",
354
+ )
355
+
356
+ # This function can be deleted once all dataset repositories have been updated.
357
+ @model_validator(mode="before")
358
+ @classmethod
359
+ def handle_legacy_fields(cls, values):
360
+ if "dataset_version" in values:
361
+ val = values.pop("dataset_version")
362
+ if val is not None:
363
+ values["version"] = val
364
+
365
+ if "data_schema_type" in values:
366
+ if "data_schema_type" in values["data_schema"]:
367
+ msg = f"Unknown data_schema format: {values=}"
368
+ raise ValueError(msg)
369
+ values["data_schema"]["data_schema_type"] = values.pop("data_schema_type")
370
+
371
+ if "leap_day_adjustment" in values:
372
+ if values["leap_day_adjustment"] != "none":
373
+ msg = f"Unknown leap day adjustment: {values=}"
374
+ raise ValueError(msg)
375
+ values.pop("leap_day_adjustment")
376
+
377
+ return values
378
+
379
+ @field_validator("dataset_id")
380
+ @classmethod
381
+ def check_dataset_id(cls, dataset_id):
382
+ """Check dataset ID validity"""
383
+ check_config_id_strict(dataset_id, "Dataset")
384
+ return dataset_id
385
+
386
+ @field_validator("trivial_dimensions")
387
+ @classmethod
388
+ def check_time_not_trivial(cls, trivial_dimensions):
389
+ for dim in trivial_dimensions:
390
+ if dim == DimensionType.TIME:
391
+ msg = "The time dimension is currently not a dsgrid supported trivial dimension."
392
+ raise ValueError(msg)
393
+ return trivial_dimensions
394
+
395
+ @field_validator("dimensions")
396
+ @classmethod
397
+ def check_files(cls, values: list) -> list:
398
+ """Validate dimension files are unique across all dimensions"""
399
+ check_uniqueness(
400
+ (
401
+ x.filename
402
+ for x in values
403
+ if isinstance(x, DimensionModel) and x.filename is not None
404
+ ),
405
+ "dimension record filename",
406
+ )
407
+ return values
408
+
409
+ @field_validator("dimensions")
410
+ @classmethod
411
+ def check_names(cls, values: list) -> list:
412
+ """Validate dimension names are unique across all dimensions."""
413
+ check_uniqueness(
414
+ [dim.name for dim in values],
415
+ "dimension record name",
416
+ )
417
+ return values
418
+
419
+ @model_validator(mode="after")
420
+ def check_time_zone(self) -> "DatasetConfigModel":
421
+ """Validate whether required time zone information is present."""
422
+ geo_requires_time_zone = False
423
+ time_dim = None
424
+ if not self.use_project_geography_time_zone:
425
+ for dimension in self.dimensions:
426
+ if dimension.dimension_type == DimensionType.TIME:
427
+ assert isinstance(dimension, TimeDimensionBaseModel)
428
+ geo_requires_time_zone = dimension.is_time_zone_required_in_geography()
429
+ time_dim = dimension
430
+ break
431
+
432
+ if geo_requires_time_zone:
433
+ for dimension in self.dimensions:
434
+ if dimension.dimension_type == DimensionType.GEOGRAPHY:
435
+ check_timezone_in_geography(
436
+ dimension,
437
+ err_msg=f"Dataset with time dimension {time_dim} requires that its "
438
+ "geography dimension records include a time_zone column.",
439
+ )
440
+
441
+ return self
442
+
443
+
+ def make_unvalidated_dataset_config(
+     dataset_id,
+     metric_type: str,
+     table_format: dict[str, str] | None = None,
+     data_classification=DataClassificationType.MODERATE.value,
+     dataset_type=InputDatasetType.MODELED,
+     included_dimensions: list[DimensionType] | None = None,
+     time_type: TimeDimensionType | None = None,
+     use_project_geography_time_zone: bool = False,
+     dimension_references: list[DimensionReferenceModel] | None = None,
+     trivial_dimensions: list[DimensionType] | None = None,
+ ) -> dict[str, Any]:
+     """Create a dataset config as a dictionary, skipping validation."""
+     table_format_ = table_format or UnpivotedTableFormatModel().model_dump()
+     trivial_dimensions_ = trivial_dimensions or []
+     exclude_dimension_types = {x.dimension_type for x in dimension_references or []}
+     if included_dimensions is not None:
+         for dim_type in set(DimensionType).difference(included_dimensions):
+             exclude_dimension_types.add(dim_type)
+
+     dimensions = make_base_dimension_template(
+         [metric_type],
+         exclude_dimension_types=exclude_dimension_types,
+         time_type=time_type,
+     )
+     return {
+         "dataset_id": dataset_id,
+         "dataset_type": dataset_type.value,
+         "data_schema": {
+             "data_schema_type": DataSchemaType.ONE_TABLE.value,
+             "table_format": table_format_,
+         },
+         "version": "1.0.0",
+         "description": "",
+         "origin_creator": "",
+         "origin_organization": "",
+         "origin_date": "",
+         "origin_project": "",
+         "origin_version": "",
+         "data_source": "",
+         "source": "",
+         "data_classification": data_classification,
+         "use_project_geography_time_zone": use_project_geography_time_zone,
+         "dimensions": dimensions,
+         "dimension_references": [x.model_dump(mode="json") for x in dimension_references or []],
+         "tags": [],
+         "user_defined_metadata": {},
+         "trivial_dimensions": [x.value for x in trivial_dimensions_],
+     }
+
+
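A sketch of calling this helper; the dataset ID, metric name, and TimeDimensionType.DATETIME member are illustrative assumptions:

from dsgrid.config.dataset_config import make_unvalidated_dataset_config
from dsgrid.dimension.time import TimeDimensionType

config = make_unvalidated_dataset_config(
    "my_dataset",                          # hypothetical dataset_id
    metric_type="electricity_use",         # hypothetical metric dimension name
    time_type=TimeDimensionType.DATETIME,  # assumed member of TimeDimensionType
)
# The result is a plain dict (no DatasetConfigModel validation) with placeholder
# metadata fields that can be written to dataset.json5 and completed by hand.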
+ class DatasetConfig(ConfigBase):
+     """Provides an interface to a DatasetConfigModel."""
+
+     def __init__(self, model):
+         super().__init__(model)
+         self._dimensions = {}  # ConfigKey to DimensionConfig
+         self._dataset_path: Path | None = None
+
+     @staticmethod
+     def config_filename():
+         return "dataset.json5"
+
+     @property
+     def config_id(self):
+         return self._model.dataset_id
+
+     @staticmethod
+     def model_class():
+         return DatasetConfigModel
+
+     @classmethod
+     def load_from_user_path(cls, config_file, dataset_path) -> "DatasetConfig":
+         config = cls.load(config_file)
+         schema_type = config.get_data_schema_type()
+         if str(dataset_path).startswith("s3://"):
+             # TODO: This may need to handle AWS s3 at some point.
+             msg = "Registering a dataset from an S3 path is not supported."
+             raise DSGInvalidParameter(msg)
+         if not dataset_path.exists():
+             msg = f"Dataset {dataset_path} does not exist"
+             raise DSGInvalidParameter(msg)
+         dataset_path = str(dataset_path)
+         if schema_type == DataSchemaType.STANDARD:
+             check_load_data_filename(dataset_path)
+             check_load_data_lookup_filename(dataset_path)
+         elif schema_type == DataSchemaType.ONE_TABLE:
+             check_load_data_filename(dataset_path)
+         else:
+             msg = f"data_schema_type={schema_type} not supported."
+             raise DSGInvalidParameter(msg)
+
+         config.dataset_path = dataset_path
+         return config
+
+     @property
+     def dataset_path(self) -> Path | None:
+         """Return the directory containing the dataset file(s)."""
+         return self._dataset_path
+
+     @dataset_path.setter
+     def dataset_path(self, dataset_path: Path | str | None) -> None:
+         """Set the dataset path."""
+         if isinstance(dataset_path, str):
+             dataset_path = Path(dataset_path)
+         self._dataset_path = dataset_path
+
+     @property
+     def load_data_path(self):
+         assert self._dataset_path is not None
+         return check_load_data_filename(self._dataset_path)
+
+     @property
+     def load_data_lookup_path(self):
+         assert self._dataset_path is not None
+         return check_load_data_lookup_filename(self._dataset_path)
+
+     def update_dimensions(self, dimensions):
+         """Update all dataset dimensions."""
+         self._dimensions.update(dimensions)
+
+     @property
+     def dimensions(self):
+         return self._dimensions
+
+     def get_dimension(self, dimension_type: DimensionType) -> DimensionBaseConfig | None:
+         """Return the dimension matching dimension_type."""
+         for dim_config in self.dimensions.values():
+             if dim_config.model.dimension_type == dimension_type:
+                 return dim_config
+         return None
+
+     def get_time_dimension(self) -> TimeDimensionBaseConfig | None:
+         """Return the time dimension of the dataset."""
+         dim = self.get_dimension(DimensionType.TIME)
+         assert dim is None or isinstance(dim, TimeDimensionBaseConfig)
+         return dim
+
+     def get_dimension_with_records(
+         self, dimension_type: DimensionType
+     ) -> DimensionBaseConfigWithFiles | None:
+         """Return the dimension matching dimension_type."""
+         for dim_config in self.dimensions.values():
+             if dim_config.model.dimension_type == dimension_type and isinstance(
+                 dim_config, DimensionBaseConfigWithFiles
+             ):
+                 return dim_config
+         return None
+
+     def get_pivoted_dimension_type(self) -> DimensionType | None:
+         """Return the table's pivoted dimension type or None if the table isn't pivoted."""
+         if self.get_table_format_type() != TableFormatType.PIVOTED:
+             return None
+         return self.model.data_schema.table_format.pivoted_dimension_type
+
+     def get_pivoted_dimension_columns(self) -> list[str]:
+         """Return the table's pivoted dimension columns or an empty list if the table isn't
+         pivoted.
+         """
+         if self.get_table_format_type() != TableFormatType.PIVOTED:
+             return []
+         dim_type = self.model.data_schema.table_format.pivoted_dimension_type
+         dim = self.get_dimension_with_records(dim_type)
+         assert dim is not None
+         return sorted(list(dim.get_unique_ids()))
+
+     def get_value_columns(self) -> list[str]:
+         """Return the table's columns that contain values."""
+         match self.get_table_format_type():
+             case TableFormatType.PIVOTED:
+                 return self.get_pivoted_dimension_columns()
+             case TableFormatType.UNPIVOTED:
+                 return [VALUE_COLUMN]
+             case _:
+                 raise NotImplementedError(str(self.get_table_format_type()))
+
+     def get_data_schema_type(self) -> DataSchemaType:
+         """Return the schema type of the table."""
+         return DataSchemaType(self.model.data_schema.data_schema_type)
+
+     def get_table_format_type(self) -> TableFormatType:
+         """Return the format type of the table."""
+         return TableFormatType(self._model.data_schema.table_format.format_type)
+
+     def add_trivial_dimensions(self, df: DataFrame):
+         """Add trivial 1-element dimensions to load_data_lookup."""
+         for dim in self._dimensions.values():
+             if dim.model.dimension_type in self.model.trivial_dimensions:
+                 self._check_trivial_record_length(dim.model.records)
+                 val = dim.model.records[0].id
+                 col = dim.model.dimension_type.value
+                 df = df.withColumn(col, F.lit(val))
+         return df
+
+     def remove_trivial_dimensions(self, df):
+         trivial_cols = {d.value for d in self.model.trivial_dimensions}
+         select_cols = [col for col in df.columns if col not in trivial_cols]
+         return df[select_cols]
+
+     def _check_trivial_record_length(self, records):
+         """Check that trivial dimensions have only 1 record."""
+         if len(records) > 1:
+             msg = f"Trivial dimensions must have only 1 record but {len(records)} records found for dimension: {records}"
+             raise DSGInvalidDimension(msg)
+
+
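A sketch of typical use of the class above; the config and data paths are hypothetical:

from pathlib import Path

from dsgrid.config.dataset_config import DatasetConfig

config = DatasetConfig.load_from_user_path(
    Path("my_dataset/dataset.json5"),  # hypothetical config file
    Path("my_dataset/data"),           # hypothetical directory with the load_data files
)
print(config.get_data_schema_type())   # DataSchemaType.STANDARD or DataSchemaType.ONE_TABLE
print(config.get_table_format_type())  # TableFormatType.PIVOTED or TableFormatType.UNPIVOTED
print(config.load_data_path)           # resolved via check_load_data_filename()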
+ def get_unique_dimension_record_ids(
+     path: Path,
+     schema_type: DataSchemaType,
+     pivoted_dimension_type: DimensionType | None,
+     time_columns: set[str],
+ ) -> dict[DimensionType, list[str]]:
+     """Get the unique dimension record IDs from a table."""
+     if schema_type == DataSchemaType.STANDARD:
+         ld = read_dataframe(check_load_data_filename(path))
+         lk = read_dataframe(check_load_data_lookup_filename(path))
+         df = ld.join(lk, on="id").drop("id")
+     elif schema_type == DataSchemaType.ONE_TABLE:
+         ld_path = check_load_data_filename(path)
+         df = read_dataframe(ld_path)
+     else:
+         msg = f"Unsupported schema type: {schema_type}"
+         raise NotImplementedError(msg)
+
+     ids_by_dimension_type: dict[DimensionType, list[str]] = {}
+     for dimension_type in DimensionType:
+         if dimension_type.value in df.columns:
+             ids_by_dimension_type[dimension_type] = sorted(
+                 get_unique_values(df, dimension_type.value)
+             )
+     if pivoted_dimension_type is not None:
+         if pivoted_dimension_type.value in df.columns:
+             msg = f"{pivoted_dimension_type=} cannot be in the dataframe columns."
+             raise DSGInvalidParameter(msg)
+         dimension_type_columns = {x.value for x in DimensionType}
+         dimension_type_columns.update(time_columns)
+         dimension_type_columns.update({"id", SCALING_FACTOR_COLUMN})
+         pivoted_columns = set(df.columns) - dimension_type_columns
+         ids_by_dimension_type[pivoted_dimension_type] = sorted(pivoted_columns)
+
+     return ids_by_dimension_type
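And a sketch of calling the module-level helper above; the paths and time column name are hypothetical, and a Spark (or compatible) session is assumed to be available for read_dataframe:

from pathlib import Path

from dsgrid.config.dataset_config import DataSchemaType, get_unique_dimension_record_ids

ids = get_unique_dimension_record_ids(
    Path("my_dataset/data"),      # hypothetical dataset directory
    DataSchemaType.ONE_TABLE,
    pivoted_dimension_type=None,  # unpivoted table, so no pivoted value columns
    time_columns={"timestamp"},   # hypothetical time column name
)
# Maps each DimensionType found in the table's columns to its sorted unique record IDs.
for dimension_type, record_ids in ids.items():
    print(dimension_type.value, len(record_ids))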