dsgrid-toolkit 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (152)
  1. dsgrid/__init__.py +22 -0
  2. dsgrid/api/__init__.py +0 -0
  3. dsgrid/api/api_manager.py +179 -0
  4. dsgrid/api/app.py +420 -0
  5. dsgrid/api/models.py +60 -0
  6. dsgrid/api/response_models.py +116 -0
  7. dsgrid/apps/__init__.py +0 -0
  8. dsgrid/apps/project_viewer/app.py +216 -0
  9. dsgrid/apps/registration_gui.py +444 -0
  10. dsgrid/chronify.py +22 -0
  11. dsgrid/cli/__init__.py +0 -0
  12. dsgrid/cli/common.py +120 -0
  13. dsgrid/cli/config.py +177 -0
  14. dsgrid/cli/download.py +13 -0
  15. dsgrid/cli/dsgrid.py +142 -0
  16. dsgrid/cli/dsgrid_admin.py +349 -0
  17. dsgrid/cli/install_notebooks.py +62 -0
  18. dsgrid/cli/query.py +711 -0
  19. dsgrid/cli/registry.py +1773 -0
  20. dsgrid/cloud/__init__.py +0 -0
  21. dsgrid/cloud/cloud_storage_interface.py +140 -0
  22. dsgrid/cloud/factory.py +31 -0
  23. dsgrid/cloud/fake_storage_interface.py +37 -0
  24. dsgrid/cloud/s3_storage_interface.py +156 -0
  25. dsgrid/common.py +35 -0
  26. dsgrid/config/__init__.py +0 -0
  27. dsgrid/config/annual_time_dimension_config.py +187 -0
  28. dsgrid/config/common.py +131 -0
  29. dsgrid/config/config_base.py +148 -0
  30. dsgrid/config/dataset_config.py +684 -0
  31. dsgrid/config/dataset_schema_handler_factory.py +41 -0
  32. dsgrid/config/date_time_dimension_config.py +108 -0
  33. dsgrid/config/dimension_config.py +54 -0
  34. dsgrid/config/dimension_config_factory.py +65 -0
  35. dsgrid/config/dimension_mapping_base.py +349 -0
  36. dsgrid/config/dimension_mappings_config.py +48 -0
  37. dsgrid/config/dimensions.py +775 -0
  38. dsgrid/config/dimensions_config.py +71 -0
  39. dsgrid/config/index_time_dimension_config.py +76 -0
  40. dsgrid/config/input_dataset_requirements.py +31 -0
  41. dsgrid/config/mapping_tables.py +209 -0
  42. dsgrid/config/noop_time_dimension_config.py +42 -0
  43. dsgrid/config/project_config.py +1457 -0
  44. dsgrid/config/registration_models.py +199 -0
  45. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  46. dsgrid/config/simple_models.py +49 -0
  47. dsgrid/config/supplemental_dimension.py +29 -0
  48. dsgrid/config/time_dimension_base_config.py +200 -0
  49. dsgrid/data_models.py +155 -0
  50. dsgrid/dataset/__init__.py +0 -0
  51. dsgrid/dataset/dataset.py +123 -0
  52. dsgrid/dataset/dataset_expression_handler.py +86 -0
  53. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  54. dsgrid/dataset/dataset_schema_handler_base.py +899 -0
  55. dsgrid/dataset/dataset_schema_handler_one_table.py +196 -0
  56. dsgrid/dataset/dataset_schema_handler_standard.py +303 -0
  57. dsgrid/dataset/growth_rates.py +162 -0
  58. dsgrid/dataset/models.py +44 -0
  59. dsgrid/dataset/table_format_handler_base.py +257 -0
  60. dsgrid/dataset/table_format_handler_factory.py +17 -0
  61. dsgrid/dataset/unpivoted_table.py +121 -0
  62. dsgrid/dimension/__init__.py +0 -0
  63. dsgrid/dimension/base_models.py +218 -0
  64. dsgrid/dimension/dimension_filters.py +308 -0
  65. dsgrid/dimension/standard.py +213 -0
  66. dsgrid/dimension/time.py +531 -0
  67. dsgrid/dimension/time_utils.py +88 -0
  68. dsgrid/dsgrid_rc.py +88 -0
  69. dsgrid/exceptions.py +105 -0
  70. dsgrid/filesystem/__init__.py +0 -0
  71. dsgrid/filesystem/cloud_filesystem.py +32 -0
  72. dsgrid/filesystem/factory.py +32 -0
  73. dsgrid/filesystem/filesystem_interface.py +136 -0
  74. dsgrid/filesystem/local_filesystem.py +74 -0
  75. dsgrid/filesystem/s3_filesystem.py +118 -0
  76. dsgrid/loggers.py +132 -0
  77. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +950 -0
  78. dsgrid/notebooks/registration.ipynb +48 -0
  79. dsgrid/notebooks/start_notebook.sh +11 -0
  80. dsgrid/project.py +451 -0
  81. dsgrid/query/__init__.py +0 -0
  82. dsgrid/query/dataset_mapping_plan.py +142 -0
  83. dsgrid/query/derived_dataset.py +384 -0
  84. dsgrid/query/models.py +726 -0
  85. dsgrid/query/query_context.py +287 -0
  86. dsgrid/query/query_submitter.py +847 -0
  87. dsgrid/query/report_factory.py +19 -0
  88. dsgrid/query/report_peak_load.py +70 -0
  89. dsgrid/query/reports_base.py +20 -0
  90. dsgrid/registry/__init__.py +0 -0
  91. dsgrid/registry/bulk_register.py +161 -0
  92. dsgrid/registry/common.py +287 -0
  93. dsgrid/registry/config_update_checker_base.py +63 -0
  94. dsgrid/registry/data_store_factory.py +34 -0
  95. dsgrid/registry/data_store_interface.py +69 -0
  96. dsgrid/registry/dataset_config_generator.py +156 -0
  97. dsgrid/registry/dataset_registry_manager.py +734 -0
  98. dsgrid/registry/dataset_update_checker.py +16 -0
  99. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  100. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  101. dsgrid/registry/dimension_registry_manager.py +413 -0
  102. dsgrid/registry/dimension_update_checker.py +16 -0
  103. dsgrid/registry/duckdb_data_store.py +185 -0
  104. dsgrid/registry/filesystem_data_store.py +141 -0
  105. dsgrid/registry/filter_registry_manager.py +123 -0
  106. dsgrid/registry/project_config_generator.py +57 -0
  107. dsgrid/registry/project_registry_manager.py +1616 -0
  108. dsgrid/registry/project_update_checker.py +48 -0
  109. dsgrid/registry/registration_context.py +223 -0
  110. dsgrid/registry/registry_auto_updater.py +316 -0
  111. dsgrid/registry/registry_database.py +662 -0
  112. dsgrid/registry/registry_interface.py +446 -0
  113. dsgrid/registry/registry_manager.py +544 -0
  114. dsgrid/registry/registry_manager_base.py +367 -0
  115. dsgrid/registry/versioning.py +92 -0
  116. dsgrid/spark/__init__.py +0 -0
  117. dsgrid/spark/functions.py +545 -0
  118. dsgrid/spark/types.py +50 -0
  119. dsgrid/tests/__init__.py +0 -0
  120. dsgrid/tests/common.py +139 -0
  121. dsgrid/tests/make_us_data_registry.py +204 -0
  122. dsgrid/tests/register_derived_datasets.py +103 -0
  123. dsgrid/tests/utils.py +25 -0
  124. dsgrid/time/__init__.py +0 -0
  125. dsgrid/time/time_conversions.py +80 -0
  126. dsgrid/time/types.py +67 -0
  127. dsgrid/units/__init__.py +0 -0
  128. dsgrid/units/constants.py +113 -0
  129. dsgrid/units/convert.py +71 -0
  130. dsgrid/units/energy.py +145 -0
  131. dsgrid/units/power.py +87 -0
  132. dsgrid/utils/__init__.py +0 -0
  133. dsgrid/utils/dataset.py +612 -0
  134. dsgrid/utils/files.py +179 -0
  135. dsgrid/utils/filters.py +125 -0
  136. dsgrid/utils/id_remappings.py +100 -0
  137. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  138. dsgrid/utils/py_expression_eval/README.md +8 -0
  139. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  140. dsgrid/utils/py_expression_eval/tests.py +283 -0
  141. dsgrid/utils/run_command.py +70 -0
  142. dsgrid/utils/scratch_dir_context.py +64 -0
  143. dsgrid/utils/spark.py +918 -0
  144. dsgrid/utils/spark_partition.py +98 -0
  145. dsgrid/utils/timing.py +239 -0
  146. dsgrid/utils/utilities.py +184 -0
  147. dsgrid/utils/versioning.py +36 -0
  148. dsgrid_toolkit-0.2.0.dist-info/METADATA +216 -0
  149. dsgrid_toolkit-0.2.0.dist-info/RECORD +152 -0
  150. dsgrid_toolkit-0.2.0.dist-info/WHEEL +4 -0
  151. dsgrid_toolkit-0.2.0.dist-info/entry_points.txt +4 -0
  152. dsgrid_toolkit-0.2.0.dist-info/licenses/LICENSE +29 -0
dsgrid/config/dimensions.py
@@ -0,0 +1,775 @@
1
+ import abc
2
+ import csv
3
+ import importlib
4
+ import logging
5
+ import os
6
+ from datetime import datetime, timedelta
7
+ from typing import Any, Union, Literal
8
+ import copy
9
+
10
+ from pydantic import field_serializer, field_validator, model_validator, Field, ValidationInfo
11
+ from pydantic.functional_validators import BeforeValidator
12
+ from typing_extensions import Annotated
13
+
14
+ from dsgrid.data_models import DSGBaseDatabaseModel, DSGBaseModel
15
+ from dsgrid.dimension.base_models import DimensionType, DimensionCategory
16
+ from dsgrid.dimension.time import (
17
+ TimeIntervalType,
18
+ MeasurementType,
19
+ TimeZone,
20
+ TimeDimensionType,
21
+ RepresentativePeriodFormat,
22
+ DatetimeFormat,
23
+ )
24
+ from dsgrid.registry.common import REGEX_VALID_REGISTRY_NAME
25
+ from dsgrid.utils.files import compute_file_hash
26
+ from dsgrid.utils.utilities import convert_record_dicts_to_classes
27
+
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ class DimensionBaseModel(DSGBaseDatabaseModel):
33
+ """Common attributes for all dimensions"""
34
+
35
+ name: str = Field(
36
+ title="name",
37
+ description="Dimension name",
38
+ )
39
+ dimension_type: DimensionType = Field(
40
+ title="dimension_type",
41
+ alias="type",
42
+ description="Type of the dimension",
43
+ json_schema_extra={
44
+ "options": DimensionType.format_for_docs(),
45
+ },
46
+ )
47
+ dimension_id: str | None = Field(
48
+ default=None,
49
+ title="dimension_id",
50
+ description="Unique identifier, generated by dsgrid",
51
+ json_schema_extra={
52
+ "dsg_internal": True,
53
+ "updateable": False,
54
+ },
55
+ )
56
+ module: str = Field(
57
+ title="module",
58
+ description="Python module with the dimension class",
59
+ default="dsgrid.dimension.standard",
60
+ )
61
+ class_name: str = Field(
62
+ title="class_name",
63
+ description="Dimension record model class name. "
64
+ "The dimension class defines the expected and allowable fields (and their data types)"
65
+ " for the dimension records file."
66
+ "All dimension records must have a 'id' and 'name' field."
67
+ "Some dimension classes support additional fields that can be used for mapping,"
68
+ " querying, display, etc."
69
+ "dsgrid in online-mode only supports dimension classes defined in the"
70
+ " :mod:`dsgrid.dimension.standard` module. If dsgrid does not currently support a"
71
+ " dimension class that you require, please contact the dsgrid-coordination team to"
72
+ " request a new class feature",
73
+ alias="class",
74
+ )
75
+ cls: Any = Field(
76
+ default=None,
77
+ title="cls",
78
+ description="Dimension record model class",
79
+ alias="dimension_class",
80
+ json_schema_extra={
81
+ "dsgrid_internal": True,
82
+ },
83
+ )
84
+ description: str = Field(
85
+ title="description",
86
+ description="A description of the dimension records that is helpful, memorable, and "
87
+ "identifiable",
88
+ )
89
+ id: int | None = Field(
90
+ default=None,
91
+ description="Registry database ID",
92
+ json_schema_extra={
93
+ "dsgrid_internal": True,
94
+ },
95
+ )
96
+
97
+ @field_validator("name")
98
+ @classmethod
99
+ def check_name(cls, name: str) -> str:
100
+ if REGEX_VALID_REGISTRY_NAME.search(name) is None:
101
+ msg = f"dimension name={name} does not meet the requirements"
102
+ raise ValueError(msg)
103
+ return name
104
+
105
+ @field_validator("description")
106
+ @classmethod
107
+ def check_description(cls, description):
108
+ if description == "":
109
+ msg = f'Empty description field for dimension: "{cls}"'
110
+ raise ValueError(msg)
111
+
112
+ # TODO: improve validation for allowable dimension record names.
113
+ prohibited_names = [
114
+ "county",
115
+ "counties",
116
+ "year",
117
+ "hourly",
118
+ "comstock",
119
+ "resstock",
120
+ "tempo",
121
+ "model",
122
+ "source",
123
+ "data-source",
124
+ "dimension",
125
+ ]
126
+ prohibited_names = prohibited_names + [x + "s" for x in prohibited_names]
127
+ if description.lower() in prohibited_names:
128
+ msg = f"""
129
+ Dimension description '{description}' is insufficient. Please be more descriptive.
130
+ Hint: try adding a vintage or other distinguishing text that will make this dimension memorable,
131
+ identifiable, and reusable for other datasets and projects.
132
+ e.g., 'Time dimension, 2012 hourly EST, period-beginning, no DST, no Leap Day Adjustment, total value'
133
+ is a good description.
134
+ """
135
+ raise ValueError(msg)
136
+ return description
137
+
138
+ @field_validator("module")
139
+ @classmethod
140
+ def check_module(cls, module: str) -> str:
141
+ if not module.startswith("dsgrid"):
142
+ msg = "Only dsgrid modules are supported as a dimension module."
143
+ raise ValueError(msg)
144
+ return module
145
+
146
+ @field_validator("class_name")
147
+ @classmethod
148
+ def get_dimension_class_name(cls, class_name, info: ValidationInfo):
149
+ """Set class_name based on inputs."""
150
+ if "module" not in info.data:
151
+ return class_name
152
+
153
+ mod = importlib.import_module(info.data["module"])
154
+ if not hasattr(mod, class_name):
155
+ if class_name is None:
156
+ msg = (
157
+ f'There is no class "{class_name}" in module: {mod}.'
158
+ "\nIf you are using a unique dimension name, you must "
159
+ "specify the dimension class."
160
+ )
161
+ else:
162
+ msg = f"dimension class {class_name} not in {mod}"
163
+ raise ValueError(msg)
164
+
165
+ return class_name
166
+
167
+ @field_validator("cls")
168
+ @classmethod
169
+ def get_dimension_class(cls, dim_class, info: ValidationInfo):
170
+ if "module" not in info.data or "class_name" not in info.data:
171
+ return dim_class
172
+
173
+ if dim_class is not None:
174
+ msg = f"cls={dim_class} should not be set"
175
+ raise ValueError(msg)
176
+
177
+ return getattr(
178
+ importlib.import_module(info.data["module"]),
179
+ info.data["class_name"],
180
+ )
181
+
182
+ @property
183
+ def label(self) -> str:
184
+ """Return a label for the dimension to be used in user messages."""
185
+ return f"{self.dimension_type} {self.name}"
186
+
187
+
188
+ class DimensionModel(DimensionBaseModel):
189
+ """Defines a non-time dimension"""
190
+
191
+ filename: str | None = Field(
192
+ title="filename",
193
+ alias="file",
194
+ default=None,
195
+ description="Filename containing dimension records. Only assigned for user input and "
196
+ "output purposes. The registry database stores records in the dimension JSON document.",
197
+ )
198
+ file_hash: str | None = Field(
199
+ title="file_hash",
200
+ description="Hash of the contents of the file",
201
+ json_schema_extra={
202
+ "dsgrid_internal": True,
203
+ },
204
+ default=None,
205
+ )
206
+ records: list = Field(
207
+ title="records",
208
+ description="Dimension records in filename that get loaded at runtime",
209
+ json_schema_extra={
210
+ "dsgrid_internal": True,
211
+ },
212
+ default=[],
213
+ )
214
+
215
+ @field_validator("filename")
216
+ @classmethod
217
+ def check_file(cls, filename: str) -> str:
218
+ """Validate that dimension file exists and has no errors"""
219
+ if filename is not None:
220
+ if not os.path.isfile(filename):
221
+ msg = f"file {filename} does not exist"
222
+ raise ValueError(msg)
223
+ if filename.startswith("s3://"):
224
+ msg = "records must exist in the local filesystem, not on S3"
225
+ raise ValueError(msg)
226
+ if not filename.endswith(".csv"):
227
+ msg = f"only CSV is supported: {filename}"
228
+ raise ValueError(msg)
229
+
230
+ return filename
231
+
232
+ @field_validator("file_hash")
233
+ @classmethod
234
+ def compute_file_hash(cls, file_hash: str, info: ValidationInfo) -> str:
235
+ if info.data.get("filename") is None:
236
+ return file_hash
237
+
238
+ if file_hash is None:
239
+ file_hash = compute_file_hash(info.data["filename"])
240
+ return file_hash
241
+
242
+ @field_validator("records")
243
+ @classmethod
244
+ def add_records(
245
+ cls, records: list[dict[str, Any]], info: ValidationInfo
246
+ ) -> list[dict[str, Any]]:
247
+ """Add records from the file."""
248
+ dim_class = info.data.get("cls")
249
+ if "filename" not in info.data or dim_class is None:
250
+ return records
251
+
252
+ if records:
253
+ if isinstance(records[0], dict):
254
+ records = convert_record_dicts_to_classes(
255
+ records, dim_class, check_duplicates=["id"]
256
+ )
257
+ return records
258
+
259
+ with open(info.data["filename"], encoding="utf-8-sig") as f_in:
260
+ records = convert_record_dicts_to_classes(
261
+ csv.DictReader(f_in), dim_class, check_duplicates=["id"]
262
+ )
263
+ return records
264
+
265
+ @field_serializer("cls", "filename")
266
+ def serialize_cls(self, val: str, _) -> None:
267
+ return None
268
+
269
+
270
+ class TimeRangeModel(DSGBaseModel):
271
+ """Defines a continuous range of time."""
272
+
273
+ # This uses str instead of datetime because this object doesn't have the ability
274
+ # to serialize/deserialize by itself (no str-format).
275
+ # We use the DatetimeRange object during processing.
276
+ start: str = Field(
277
+ title="start",
278
+ description="First timestamp in the data",
279
+ )
280
+ end: str = Field(
281
+ title="end",
282
+ description="Last timestamp in the data (inclusive)",
283
+ )
284
+
285
+
286
+ class MonthRangeModel(DSGBaseModel):
287
+ """Defines a continuous range of time."""
288
+
289
+ # This uses str instead of datetime because this object doesn't have the ability
290
+ # to serialize/deserialize by itself (no str-format).
291
+ # We use the DatetimeRange object during processing.
292
+ start: int = Field(
293
+ title="start",
294
+ description="First month in the data (January is 1, December is 12)",
295
+ )
296
+ end: int = Field(
297
+ title="end",
298
+ description="Last month in the data (inclusive)",
299
+ )
300
+
301
+
302
+ class IndexRangeModel(DSGBaseModel):
303
+ """Defines a continuous range of indices."""
304
+
305
+ start: int = Field(
306
+ title="start",
307
+ description="First of indices",
308
+ )
309
+ end: int = Field(
310
+ title="end",
311
+ description="Last of indices (inclusive)",
312
+ )
313
+
314
+
315
+ class TimeDimensionBaseModel(DimensionBaseModel, abc.ABC):
316
+ """Defines a base model common to all time dimensions."""
317
+
318
+ time_type: TimeDimensionType = Field(
319
+ title="time_type",
320
+ default=TimeDimensionType.DATETIME,
321
+ description="Type of time dimension",
322
+ json_schema_extra={
323
+ "options": TimeDimensionType.format_for_docs(),
324
+ },
325
+ )
326
+
327
+ @field_serializer("cls")
328
+ def serialize_cls(self, val, _):
329
+ return None
330
+
331
+ @abc.abstractmethod
332
+ def is_time_zone_required_in_geography(self):
333
+ """Returns True if the geography dimension records must contain a time_zone column."""
334
+
335
+
336
+ class AlignedTime(DSGBaseModel):
337
+ """Data has absolute timestamps that are aligned with the same start and end
338
+ for each geography."""
339
+
340
+ format_type: Literal[DatetimeFormat.ALIGNED] = DatetimeFormat.ALIGNED
341
+ timezone: TimeZone = Field(
342
+ title="timezone",
343
+ description="Time zone of data",
344
+ )
345
+
346
+
347
+ class LocalTimeAsStrings(DSGBaseModel):
348
+ """Data has absolute timestamps formatted as strings with offsets from UTC.
349
+ They are aligned for each geography when adjusted for time zone but staggered
350
+ in an absolute time scale."""
351
+
352
+ format_type: Literal[DatetimeFormat.LOCAL_AS_STRINGS] = DatetimeFormat.LOCAL_AS_STRINGS
353
+
354
+ data_str_format: str = Field(
355
+ title="data_str_format",
356
+ default="yyyy-MM-dd HH:mm:ssZZZZZ",
357
+ description="Timestamp string format (for parsing the time column of the dataframe). "
358
+ "The string format is used to parse the timestamps in the dataframe while in Spark, "
359
+ "(e.g., yyyy-MM-dd HH:mm:ssZZZZZ). "
360
+ "Cheatsheet reference: `<https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html>`_.",
361
+ )
362
+
363
+ @field_validator("data_str_format")
364
+ @classmethod
365
+ def check_data_str_format(cls, data_str_format):
366
+ msg = "DatetimeFormat.LOCAL_AS_STRINGS is not fully implemented."
367
+ raise NotImplementedError(msg)
368
+ dsf = data_str_format
369
+ if (
370
+ "x" not in dsf
371
+ and "X" not in dsf
372
+ and "Z" not in dsf
373
+ and "z" not in dsf
374
+ and "V" not in dsf
375
+ and "O" not in dsf
376
+ ):
377
+ msg = "data_str_format must provide time zone or zone offset."
378
+ raise ValueError(msg)
379
+ return data_str_format
380
+
381
+
382
+ class DateTimeDimensionModel(TimeDimensionBaseModel):
383
+ """Defines a time dimension where timestamps translate to datetime objects."""
384
+
385
+ datetime_format: Union[AlignedTime, LocalTimeAsStrings] = Field(
386
+ title="datetime_format",
387
+ discriminator="format_type",
388
+ description="""
389
+ Format of the datetime used to define the data format, alignment between geography,
390
+ and time zone information.
391
+ """,
392
+ )
393
+
394
+ measurement_type: MeasurementType = Field(
395
+ title="measurement_type",
396
+ default=MeasurementType.TOTAL,
397
+ description="""
398
+ The type of measurement represented by a value associated with a timestamp:
399
+ mean, min, max, measured, total
400
+ """,
401
+ json_schema_extra={
402
+ "options": MeasurementType.format_for_docs(),
403
+ },
404
+ )
405
+
406
+ str_format: str = Field(
407
+ title="str_format",
408
+ default="%Y-%m-%d %H:%M:%s",
409
+ description="Timestamp string format (for parsing the time ranges). "
410
+ "The string format is used to parse the timestamps provided in the time ranges."
411
+ "Cheatsheet reference: `<https://strftime.org/>`_.",
412
+ )
413
+ frequency: timedelta = Field(
414
+ title="frequency",
415
+ description="Resolution of the timestamps",
416
+ )
417
+ ranges: list[TimeRangeModel] = Field(
418
+ title="time_ranges",
419
+ description="Defines the continuous ranges of time in the data, inclusive of start and end time.",
420
+ )
421
+ time_interval_type: TimeIntervalType = Field(
422
+ title="time_interval",
423
+ description="The range of time that the value associated with a timestamp represents, e.g., period-beginning",
424
+ json_schema_extra={
425
+ "options": TimeIntervalType.format_descriptions_for_docs(),
426
+ },
427
+ )
428
+
429
+ @model_validator(mode="before")
430
+ @classmethod
431
+ def handle_legacy_fields(cls, values):
432
+ if "leap_day_adjustment" in values:
433
+ if values["leap_day_adjustment"] != "none":
434
+ msg = f"Unknown data_schema format: {values=}"
435
+ raise ValueError(msg)
436
+ values.pop("leap_day_adjustment")
437
+
438
+ if "timezone" in values:
439
+ values["datetime_format"] = {
440
+ "format_type": DatetimeFormat.ALIGNED.value,
441
+ "timezone": values["timezone"],
442
+ }
443
+ values.pop("timezone")
444
+
445
+ return values
446
+
447
+ # @model_validator(mode="after")
448
+ # def check_frequency(self) -> "DateTimeDimensionModel":
449
+ # if self.frequency in [timedelta(days=365), timedelta(days=366)]:
450
+ # raise ValueError(
451
+ # f"frequency={self.frequency}, datetime config does not allow 365 or 366 days frequency, "
452
+ # "use class=AnnualTime, time_type=annual to specify a year series."
453
+ # )
454
+ # return self
455
+
456
+ @field_validator("frequency")
457
+ @classmethod
458
+ def check_frequency(cls, frequency: timedelta) -> timedelta:
459
+ if frequency in [timedelta(days=365), timedelta(days=366)]:
460
+ msg = (
461
+ f"{frequency=}, datetime config does not allow 365 or 366 days frequency, "
462
+ "use class=AnnualTime, time_type=annual to specify a year series."
463
+ )
464
+ raise ValueError(msg)
465
+ return frequency
466
+
467
+ @field_validator("ranges")
468
+ @classmethod
469
+ def check_times(
470
+ cls, ranges: list[TimeRangeModel], info: ValidationInfo
471
+ ) -> list[TimeRangeModel]:
472
+ if "str_format" not in info.data or "frequency" not in info.data:
473
+ return ranges
474
+ return _check_time_ranges(ranges, info.data["str_format"], info.data["frequency"])
475
+
476
+ def is_time_zone_required_in_geography(self) -> bool:
477
+ return False
478
+
479
+
480
+ class AnnualTimeDimensionModel(TimeDimensionBaseModel):
481
+ """Defines an annual time dimension where timestamps are years."""
482
+
483
+ time_type: TimeDimensionType = Field(default=TimeDimensionType.ANNUAL)
484
+ measurement_type: MeasurementType = Field(
485
+ title="measurement_type",
486
+ default=MeasurementType.TOTAL,
487
+ description="""
488
+ The type of measurement represented by a value associated with a timestamp:
489
+ e.g., mean, total
490
+ """,
491
+ json_schema_extra={
492
+ "options": MeasurementType.format_for_docs(),
493
+ },
494
+ )
495
+ str_format: str = Field(
496
+ title="str_format",
497
+ default="%Y",
498
+ description="Timestamp string format. "
499
+ "The string format is used to parse the timestamps provided in the time ranges. "
500
+ "Cheatsheet reference: `<https://strftime.org/>`_.",
501
+ )
502
+ ranges: list[TimeRangeModel] = Field(
503
+ default=[],
504
+ title="time_ranges",
505
+ description="Defines the contiguous ranges of time in the data, inclusive of start and end time.",
506
+ )
507
+ include_leap_day: bool = Field(
508
+ title="include_leap_day",
509
+ default=False,
510
+ description="Whether annual time includes leap day.",
511
+ )
512
+
513
+ @field_validator("ranges")
514
+ @classmethod
515
+ def check_times(
516
+ cls, ranges: list[TimeRangeModel], info: ValidationInfo
517
+ ) -> list[TimeRangeModel]:
518
+ return _check_annual_ranges(ranges, info.data["str_format"])
519
+
520
+ @field_validator("measurement_type")
521
+ @classmethod
522
+ def check_measurement_type(cls, measurement_type: MeasurementType) -> MeasurementType:
523
+ # This restriction exists because any other measurement type would require a frequency,
524
+ # and that isn't part of the model definition.
525
+ if measurement_type != MeasurementType.TOTAL:
526
+ msg = f"Annual time currently only supports MeasurementType total: {measurement_type}"
527
+ raise ValueError(msg)
528
+ return measurement_type
529
+
530
+ def is_time_zone_required_in_geography(self) -> bool:
531
+ return False
532
+
533
+
534
+ class RepresentativePeriodTimeDimensionModel(TimeDimensionBaseModel):
535
+ """Defines a representative time dimension."""
536
+
537
+ time_type: TimeDimensionType = Field(default=TimeDimensionType.REPRESENTATIVE_PERIOD)
538
+ measurement_type: MeasurementType = Field(
539
+ title="measurement_type",
540
+ default=MeasurementType.TOTAL,
541
+ description="""
542
+ The type of measurement represented by a value associated with a timestamp:
543
+ e.g., mean, total
544
+ """,
545
+ json_schema_extra={
546
+ "options": MeasurementType.format_for_docs(),
547
+ },
548
+ )
549
+ format: RepresentativePeriodFormat = Field(
550
+ title="format",
551
+ description="Format of the timestamps in the load data",
552
+ )
553
+ ranges: list[MonthRangeModel] = Field(
554
+ title="ranges",
555
+ description="Defines the continuous ranges of time in the data, inclusive of start and end time.",
556
+ )
557
+ time_interval_type: TimeIntervalType = Field(
558
+ title="time_interval",
559
+ description="The range of time that the value associated with a timestamp represents",
560
+ )
561
+
562
+ def is_time_zone_required_in_geography(self) -> bool:
563
+ return True
564
+
565
+
566
+ class IndexTimeDimensionModel(TimeDimensionBaseModel):
567
+ """Defines a time dimension where timestamps are indices."""
568
+
569
+ time_type: TimeDimensionType = Field(default=TimeDimensionType.INDEX)
570
+ measurement_type: MeasurementType = Field(
571
+ title="measurement_type",
572
+ default=MeasurementType.TOTAL,
573
+ description="""
574
+ The type of measurement represented by a value associated with a timestamp:
575
+ e.g., mean, total
576
+ """,
577
+ json_schema_extra={
578
+ "options": MeasurementType.format_for_docs(),
579
+ },
580
+ )
581
+ ranges: list[IndexRangeModel] = Field(
582
+ title="ranges",
583
+ description="Defines the continuous ranges of indices of the data, inclusive of start and end index.",
584
+ )
585
+ frequency: timedelta = Field(
586
+ title="frequency",
587
+ description="Resolution of the timestamps for which the ranges represent.",
588
+ )
589
+ starting_timestamps: list[str] = Field(
590
+ title="starting timestamps",
591
+ description="Starting timestamp for for each of the ranges.",
592
+ )
593
+ str_format: str = Field(
594
+ title="str_format",
595
+ default="%Y-%m-%d %H:%M:%s",
596
+ description="Timestamp string format. "
597
+ "The string format is used to parse the starting timestamp provided. "
598
+ "Cheatsheet reference: `<https://strftime.org/>`_.",
599
+ )
600
+ time_interval_type: TimeIntervalType = Field(
601
+ title="time_interval",
602
+ description="The range of time that the value associated with a timestamp represents, e.g., period-beginning",
603
+ json_schema_extra={
604
+ "options": TimeIntervalType.format_descriptions_for_docs(),
605
+ },
606
+ )
607
+
608
+ @field_validator("starting_timestamps")
609
+ @classmethod
610
+ def check_timestamps(cls, starting_timestamps, info: ValidationInfo) -> list[str]:
611
+ if len(starting_timestamps) != len(info.data["ranges"]):
612
+ msg = f"{starting_timestamps=} must match the number of ranges."
613
+ raise ValueError(msg)
614
+ return starting_timestamps
615
+
616
+ @field_validator("ranges")
617
+ @classmethod
618
+ def check_indices(cls, ranges: list[IndexRangeModel]) -> list[IndexRangeModel]:
619
+ return _check_index_ranges(ranges)
620
+
621
+ def is_time_zone_required_in_geography(self) -> bool:
622
+ return True
623
+
624
+
625
+ class NoOpTimeDimensionModel(TimeDimensionBaseModel):
626
+ """Defines a NoOp time dimension."""
627
+
628
+ time_type: TimeDimensionType = TimeDimensionType.NOOP
629
+
630
+ def is_time_zone_required_in_geography(self) -> bool:
631
+ return False
632
+
633
+
634
+ class DimensionReferenceModel(DSGBaseModel):
635
+ """Reference to a dimension stored in the registry"""
636
+
637
+ dimension_type: DimensionType = Field(
638
+ title="dimension_type",
639
+ alias="type",
640
+ description="Type of the dimension",
641
+ json_schema_extra={
642
+ "options": DimensionType.format_for_docs(),
643
+ },
644
+ )
645
+ dimension_id: str = Field(
646
+ title="dimension_id",
647
+ description="Unique ID of the dimension in the registry. "
648
+ "The dimension ID is generated by dsgrid when a dimension is registered. "
649
+ "Only alphanumerics and dashes are supported.",
650
+ )
651
+ version: str = Field(
652
+ title="version",
653
+ # TODO: add notes about warnings for outdated versions DSGRID-189 & DSGRID-148
654
+ description="Version of the dimension. "
655
+ "The version string must be in semver format (e.g., '1.0.0') and it must be "
656
+ " a valid/existing version in the registry.",
657
+ )
658
+
659
+
660
+ def handle_dimension_union(values):
661
+ values = copy.deepcopy(values)
662
+ for i, value in enumerate(values):
663
+ if isinstance(value, DimensionBaseModel):
664
+ continue
665
+
666
+ dim_type = value.get("type")
667
+ if dim_type is None:
668
+ dim_type = value["dimension_type"]
669
+ # NOTE: Errors inside DimensionModel or DateTimeDimensionModel will be duplicated by Pydantic
670
+ if dim_type == DimensionType.TIME.value:
671
+ if value["time_type"] == TimeDimensionType.DATETIME.value:
672
+ values[i] = DateTimeDimensionModel(**value)
673
+ elif value["time_type"] == TimeDimensionType.ANNUAL.value:
674
+ values[i] = AnnualTimeDimensionModel(**value)
675
+ elif value["time_type"] == TimeDimensionType.REPRESENTATIVE_PERIOD.value:
676
+ values[i] = RepresentativePeriodTimeDimensionModel(**value)
677
+ elif value["time_type"] == TimeDimensionType.INDEX.value:
678
+ values[i] = IndexTimeDimensionModel(**value)
679
+ elif value["time_type"] == TimeDimensionType.NOOP.value:
680
+ values[i] = NoOpTimeDimensionModel(**value)
681
+ else:
682
+ options = [x.value for x in TimeDimensionType]
683
+ msg = f"{value['time_type']} not supported, valid options: {options}"
684
+ raise ValueError(msg)
685
+ else:
686
+ values[i] = DimensionModel(**value)
687
+ return values
688
+
689
+
690
+ DimensionsListModel = Annotated[
691
+ list[
692
+ Union[
693
+ DimensionModel,
694
+ DateTimeDimensionModel,
695
+ AnnualTimeDimensionModel,
696
+ RepresentativePeriodTimeDimensionModel,
697
+ IndexTimeDimensionModel,
698
+ NoOpTimeDimensionModel,
699
+ ]
700
+ ],
701
+ BeforeValidator(handle_dimension_union),
702
+ ]
703
+
704
+
705
+ def _check_time_ranges(ranges: list[TimeRangeModel], str_format: str, frequency: timedelta):
706
+ assert isinstance(frequency, timedelta)
707
+ for trange in ranges:
708
+ # Make sure start and end time parse.
709
+ start = datetime.strptime(trange.start, str_format)
710
+ end = datetime.strptime(trange.end, str_format)
711
+ # Make sure start and end is tz-naive.
712
+ if start.tzinfo is not None or end.tzinfo is not None:
713
+ msg = f"datetime range {trange} start and end need to be tz-naive. Pass in the time zone info via datetime_format"
714
+ raise ValueError(msg)
715
+ if end < start:
716
+ msg = f"datetime range {trange} end must not be less than start."
717
+ raise ValueError(msg)
718
+ if (end - start) % frequency != timedelta(0):
719
+ msg = f"datetime range {trange} is inconsistent with {frequency}"
720
+ raise ValueError(msg)
721
+
722
+ return ranges
723
+
724
+
725
+ def _check_annual_ranges(ranges: list[TimeRangeModel], str_format: str):
726
+ for trange in ranges:
727
+ # Make sure start and end time parse.
728
+ start = datetime.strptime(trange.start, str_format)
729
+ end = datetime.strptime(trange.end, str_format)
730
+ if end < start:
731
+ msg = f"annual time range {trange} end must not be less than start."
732
+ raise ValueError(msg)
733
+
734
+ return ranges
735
+
736
+
737
+ def _check_index_ranges(ranges: list[IndexRangeModel]):
738
+ for trange in ranges:
739
+ if trange.end < trange.start:
740
+ msg = f"index range {trange} end must not be less than start."
741
+ raise ValueError(msg)
742
+
743
+ return ranges
744
+
745
+
746
+ class DimensionCommonModel(DSGBaseModel):
747
+ """Common attributes for all dimensions"""
748
+
749
+ name: str
750
+ dimension_type: DimensionType
751
+ dimension_id: str
752
+ class_name: str
753
+ description: str
754
+
755
+
756
+ class ProjectDimensionModel(DimensionCommonModel):
757
+ """Common attributes for all dimensions that are assigned to a project"""
758
+
759
+ category: DimensionCategory
760
+
761
+
762
+ def create_dimension_common_model(model) -> DimensionCommonModel:
763
+ """Constructs an instance of DimensionBaseModel from subclasses in order to give the API
764
+ one common model for all dimensions. Avoids the complexity of dealing with
765
+ DimensionBaseModel validators.
766
+ """
767
+ fields = set(DimensionCommonModel.model_fields)
768
+ data = {x: getattr(model, x) for x in type(model).model_fields if x in fields}
769
+ return DimensionCommonModel(**data)
770
+
771
+
772
+ def create_project_dimension_model(model, category: DimensionCategory) -> ProjectDimensionModel:
773
+ data = create_dimension_common_model(model).model_dump()
774
+ data["category"] = category.value
775
+ return ProjectDimensionModel(**data)
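
Usage note (illustrative, not part of the packaged file): the sketch below shows how a list of dimension definitions might be validated through DimensionsListModel, whose BeforeValidator (handle_dimension_union) dispatches each dict to the matching model class based on its "type" and "time_type" keys. The record class name, enum string values, time zone, and timestamps are assumptions chosen for illustration and would need to match real registry inputs; only the structure mirrors the models above.

# Hypothetical sketch. Assumed values: a "Time" record class in
# dsgrid.dimension.standard and the enum strings "time", "datetime", "aligned",
# "EasternStandard", and "period_beginning"; none of these are verified here.
from pydantic import TypeAdapter

from dsgrid.config.dimensions import DimensionsListModel

raw = [
    {
        "name": "hourly-est-2012",
        "type": "time",
        "time_type": "datetime",
        "class": "Time",
        "description": "Time dimension, 2012 hourly EST, period-beginning, total value",
        "datetime_format": {"format_type": "aligned", "timezone": "EasternStandard"},
        "frequency": "PT1H",  # pydantic parses ISO 8601 durations into timedelta
        "str_format": "%Y-%m-%d %H:%M:%S",
        "ranges": [{"start": "2012-01-01 00:00:00", "end": "2012-12-31 23:00:00"}],
        "time_interval_type": "period_beginning",
    },
]

# handle_dimension_union converts each dict to the appropriate model
# (DateTimeDimensionModel here) before Pydantic validates the fields.
dimensions = TypeAdapter(DimensionsListModel).validate_python(raw)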