dsgrid-toolkit 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of dsgrid-toolkit might be problematic.

Files changed (152)
  1. dsgrid/__init__.py +22 -0
  2. dsgrid/api/__init__.py +0 -0
  3. dsgrid/api/api_manager.py +179 -0
  4. dsgrid/api/app.py +420 -0
  5. dsgrid/api/models.py +60 -0
  6. dsgrid/api/response_models.py +116 -0
  7. dsgrid/apps/__init__.py +0 -0
  8. dsgrid/apps/project_viewer/app.py +216 -0
  9. dsgrid/apps/registration_gui.py +444 -0
  10. dsgrid/chronify.py +22 -0
  11. dsgrid/cli/__init__.py +0 -0
  12. dsgrid/cli/common.py +120 -0
  13. dsgrid/cli/config.py +177 -0
  14. dsgrid/cli/download.py +13 -0
  15. dsgrid/cli/dsgrid.py +142 -0
  16. dsgrid/cli/dsgrid_admin.py +349 -0
  17. dsgrid/cli/install_notebooks.py +62 -0
  18. dsgrid/cli/query.py +711 -0
  19. dsgrid/cli/registry.py +1773 -0
  20. dsgrid/cloud/__init__.py +0 -0
  21. dsgrid/cloud/cloud_storage_interface.py +140 -0
  22. dsgrid/cloud/factory.py +31 -0
  23. dsgrid/cloud/fake_storage_interface.py +37 -0
  24. dsgrid/cloud/s3_storage_interface.py +156 -0
  25. dsgrid/common.py +35 -0
  26. dsgrid/config/__init__.py +0 -0
  27. dsgrid/config/annual_time_dimension_config.py +187 -0
  28. dsgrid/config/common.py +131 -0
  29. dsgrid/config/config_base.py +148 -0
  30. dsgrid/config/dataset_config.py +684 -0
  31. dsgrid/config/dataset_schema_handler_factory.py +41 -0
  32. dsgrid/config/date_time_dimension_config.py +108 -0
  33. dsgrid/config/dimension_config.py +54 -0
  34. dsgrid/config/dimension_config_factory.py +65 -0
  35. dsgrid/config/dimension_mapping_base.py +349 -0
  36. dsgrid/config/dimension_mappings_config.py +48 -0
  37. dsgrid/config/dimensions.py +775 -0
  38. dsgrid/config/dimensions_config.py +71 -0
  39. dsgrid/config/index_time_dimension_config.py +76 -0
  40. dsgrid/config/input_dataset_requirements.py +31 -0
  41. dsgrid/config/mapping_tables.py +209 -0
  42. dsgrid/config/noop_time_dimension_config.py +42 -0
  43. dsgrid/config/project_config.py +1457 -0
  44. dsgrid/config/registration_models.py +199 -0
  45. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  46. dsgrid/config/simple_models.py +49 -0
  47. dsgrid/config/supplemental_dimension.py +29 -0
  48. dsgrid/config/time_dimension_base_config.py +200 -0
  49. dsgrid/data_models.py +155 -0
  50. dsgrid/dataset/__init__.py +0 -0
  51. dsgrid/dataset/dataset.py +123 -0
  52. dsgrid/dataset/dataset_expression_handler.py +86 -0
  53. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  54. dsgrid/dataset/dataset_schema_handler_base.py +899 -0
  55. dsgrid/dataset/dataset_schema_handler_one_table.py +196 -0
  56. dsgrid/dataset/dataset_schema_handler_standard.py +303 -0
  57. dsgrid/dataset/growth_rates.py +162 -0
  58. dsgrid/dataset/models.py +44 -0
  59. dsgrid/dataset/table_format_handler_base.py +257 -0
  60. dsgrid/dataset/table_format_handler_factory.py +17 -0
  61. dsgrid/dataset/unpivoted_table.py +121 -0
  62. dsgrid/dimension/__init__.py +0 -0
  63. dsgrid/dimension/base_models.py +218 -0
  64. dsgrid/dimension/dimension_filters.py +308 -0
  65. dsgrid/dimension/standard.py +213 -0
  66. dsgrid/dimension/time.py +531 -0
  67. dsgrid/dimension/time_utils.py +88 -0
  68. dsgrid/dsgrid_rc.py +88 -0
  69. dsgrid/exceptions.py +105 -0
  70. dsgrid/filesystem/__init__.py +0 -0
  71. dsgrid/filesystem/cloud_filesystem.py +32 -0
  72. dsgrid/filesystem/factory.py +32 -0
  73. dsgrid/filesystem/filesystem_interface.py +136 -0
  74. dsgrid/filesystem/local_filesystem.py +74 -0
  75. dsgrid/filesystem/s3_filesystem.py +118 -0
  76. dsgrid/loggers.py +132 -0
  77. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +950 -0
  78. dsgrid/notebooks/registration.ipynb +48 -0
  79. dsgrid/notebooks/start_notebook.sh +11 -0
  80. dsgrid/project.py +451 -0
  81. dsgrid/query/__init__.py +0 -0
  82. dsgrid/query/dataset_mapping_plan.py +142 -0
  83. dsgrid/query/derived_dataset.py +384 -0
  84. dsgrid/query/models.py +726 -0
  85. dsgrid/query/query_context.py +287 -0
  86. dsgrid/query/query_submitter.py +847 -0
  87. dsgrid/query/report_factory.py +19 -0
  88. dsgrid/query/report_peak_load.py +70 -0
  89. dsgrid/query/reports_base.py +20 -0
  90. dsgrid/registry/__init__.py +0 -0
  91. dsgrid/registry/bulk_register.py +161 -0
  92. dsgrid/registry/common.py +287 -0
  93. dsgrid/registry/config_update_checker_base.py +63 -0
  94. dsgrid/registry/data_store_factory.py +34 -0
  95. dsgrid/registry/data_store_interface.py +69 -0
  96. dsgrid/registry/dataset_config_generator.py +156 -0
  97. dsgrid/registry/dataset_registry_manager.py +734 -0
  98. dsgrid/registry/dataset_update_checker.py +16 -0
  99. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  100. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  101. dsgrid/registry/dimension_registry_manager.py +413 -0
  102. dsgrid/registry/dimension_update_checker.py +16 -0
  103. dsgrid/registry/duckdb_data_store.py +185 -0
  104. dsgrid/registry/filesystem_data_store.py +141 -0
  105. dsgrid/registry/filter_registry_manager.py +123 -0
  106. dsgrid/registry/project_config_generator.py +57 -0
  107. dsgrid/registry/project_registry_manager.py +1616 -0
  108. dsgrid/registry/project_update_checker.py +48 -0
  109. dsgrid/registry/registration_context.py +223 -0
  110. dsgrid/registry/registry_auto_updater.py +316 -0
  111. dsgrid/registry/registry_database.py +662 -0
  112. dsgrid/registry/registry_interface.py +446 -0
  113. dsgrid/registry/registry_manager.py +544 -0
  114. dsgrid/registry/registry_manager_base.py +367 -0
  115. dsgrid/registry/versioning.py +92 -0
  116. dsgrid/spark/__init__.py +0 -0
  117. dsgrid/spark/functions.py +545 -0
  118. dsgrid/spark/types.py +50 -0
  119. dsgrid/tests/__init__.py +0 -0
  120. dsgrid/tests/common.py +139 -0
  121. dsgrid/tests/make_us_data_registry.py +204 -0
  122. dsgrid/tests/register_derived_datasets.py +103 -0
  123. dsgrid/tests/utils.py +25 -0
  124. dsgrid/time/__init__.py +0 -0
  125. dsgrid/time/time_conversions.py +80 -0
  126. dsgrid/time/types.py +67 -0
  127. dsgrid/units/__init__.py +0 -0
  128. dsgrid/units/constants.py +113 -0
  129. dsgrid/units/convert.py +71 -0
  130. dsgrid/units/energy.py +145 -0
  131. dsgrid/units/power.py +87 -0
  132. dsgrid/utils/__init__.py +0 -0
  133. dsgrid/utils/dataset.py +612 -0
  134. dsgrid/utils/files.py +179 -0
  135. dsgrid/utils/filters.py +125 -0
  136. dsgrid/utils/id_remappings.py +100 -0
  137. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  138. dsgrid/utils/py_expression_eval/README.md +8 -0
  139. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  140. dsgrid/utils/py_expression_eval/tests.py +283 -0
  141. dsgrid/utils/run_command.py +70 -0
  142. dsgrid/utils/scratch_dir_context.py +64 -0
  143. dsgrid/utils/spark.py +918 -0
  144. dsgrid/utils/spark_partition.py +98 -0
  145. dsgrid/utils/timing.py +239 -0
  146. dsgrid/utils/utilities.py +184 -0
  147. dsgrid/utils/versioning.py +36 -0
  148. dsgrid_toolkit-0.2.0.dist-info/METADATA +216 -0
  149. dsgrid_toolkit-0.2.0.dist-info/RECORD +152 -0
  150. dsgrid_toolkit-0.2.0.dist-info/WHEEL +4 -0
  151. dsgrid_toolkit-0.2.0.dist-info/entry_points.txt +4 -0
  152. dsgrid_toolkit-0.2.0.dist-info/licenses/LICENSE +29 -0
dsgrid/cloud/cloud_storage_interface.py ADDED
@@ -0,0 +1,140 @@
+ import abc
+ from contextlib import contextmanager
+
+
+ class CloudStorageInterface(abc.ABC):
+     """Defines the interface to synchronize data stored on a cloud storage system."""
+
+     @abc.abstractmethod
+     def check_lock_file(self, path):
+         """Check that a given lock file path exists and was created by the same username and uuid.
+
+         Raises an error if the existing lock file's username and uuid do not match.
+
+         Parameters
+         ----------
+         path : str
+             Lock file path
+
+         Raises
+         ------
+         DSGRegistryLockError
+             Raised if a conflicting lock file is found.
+         """
+
+     @abc.abstractmethod
+     def check_valid_lock_file(self, path):
+         """Check that a given lock file path is valid. Raises an error if invalid.
+
+         Parameters
+         ----------
+         path : str
+             Lock file path
+         """
+
+     @abc.abstractmethod
+     def get_lock_files(self, relative_path=None):
+         """Return a generator of lock files within the /.locks directory (non-recursive).
+
+         Parameters
+         ----------
+         relative_path : str
+             Relative path to search for lock files in. By default, None.
+         """
+
+     @abc.abstractmethod
+     def has_lock_files(self):
+         """Return True if a .lock file exists within the /.locks directory."""
+
+     @abc.abstractmethod
+     @contextmanager
+     def make_lock_file_managed(self, path):
+         """Context manager that makes a lock file for the given path and removes it on close.
+
+         Parameters
+         ----------
+         path : str
+             Lock file path
+
+         Raises
+         ------
+         DSGRegistryLockError
+             Raised if a registry.lock already exists.
+         """
+
+     @abc.abstractmethod
+     def make_lock_file(self, path):
+         """Make a lock file at the given path.
+
+         Parameters
+         ----------
+         path : str
+             Lock file path
+
+         Raises
+         ------
+         DSGRegistryLockError
+             Raised if a registry.lock already exists.
+         """
+
+     @abc.abstractmethod
+     def read_lock_file(self, path):
+         """Read a lock file and return a dictionary of its contents.
+
+         Parameters
+         ----------
+         path : str
+             Lock file path
+         """
+
+     @abc.abstractmethod
+     def remove_lock_file(self, path, force=False):
+         """Remove a lock file.
+
+         Parameters
+         ----------
+         path : str
+             Lock file path
+         force : bool
+             Force removal of a lock file that does not have the same UUID or username. By default, False.
+
+         Raises
+         ------
+         DSGRegistryLockError
+             Raised if a registry.lock already exists and was generated by a different user or UUID (when force is False).
+         """
+
+     @abc.abstractmethod
+     def sync_pull(self, remote_path, local_path, exclude=None, delete_local=False, is_file=False):
+         """Synchronize data from remote_path to local_path.
+
+         If delete_local is True, this deletes any files in local_path that do not exist in remote_path.
+
+         Parameters
+         ----------
+         remote_path : str
+             Remote registry path
+         local_path : str
+             Local registry path
+         exclude : list, optional
+             List of patterns to exclude. By default, None.
+             If excluding whole directories, the exclusion must end with /*, e.g., data/*.
+         delete_local : bool, optional
+             If True, delete files and directories that exist in local_path but not in remote_path.
+         is_file : bool, optional
+             Whether the path is a file (not a registry). By default, False.
+         """
+
+     @abc.abstractmethod
+     def sync_push(self, remote_path, local_path, exclude=None):
+         """Synchronize data from local_path to remote_path.
+
+         Parameters
+         ----------
+         remote_path : str
+             Remote registry path
+         local_path : str
+             Local registry path
+         exclude : list, optional
+             List of patterns to exclude. By default, None.
+             If excluding whole directories, the exclusion must end with /*, e.g., data/*.
+         """
dsgrid/cloud/factory.py ADDED
@@ -0,0 +1,31 @@
+ # from .s3_storage_interface import S3StorageInterface
+ from .fake_storage_interface import FakeStorageInterface
+
+ # from dsgrid.common import AWS_PROFILE_NAME
+
+
+ def make_cloud_storage_interface(local_path, remote_path, uuid, user, offline=False):
+     """Create a CloudStorageInterface appropriate for the path.
+
+     Parameters
+     ----------
+     local_path : str
+     remote_path : str
+     uuid : str
+         Unique ID to be used when generating cloud locks.
+     user : str
+         Username to be used when generating cloud locks.
+     offline : bool, optional
+         If True, don't perform any remote syncing operations.
+
+     Returns
+     -------
+     CloudStorageInterface
+     """
+     if not offline and remote_path.lower().startswith("s3"):
+         msg = f"Support for S3 is currently disabled: {remote_path=}"
+         raise NotImplementedError(msg)
+         # return S3StorageInterface(local_path, remote_path, uuid, user, profile=AWS_PROFILE_NAME)
+     return FakeStorageInterface()
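
A quick sketch of the routing logic, with made-up paths and identifiers: because the check is `not offline and ...`, passing offline=True bypasses the disabled S3 branch entirely, so the factory currently returns a FakeStorageInterface in every non-error case.

from dsgrid.cloud.factory import make_cloud_storage_interface
from dsgrid.cloud.fake_storage_interface import FakeStorageInterface

storage = make_cloud_storage_interface(
    local_path="/tmp/registry",
    remote_path="s3://nrel-dsgrid-registry",
    uuid="0000-1111-2222",
    user="jdoe",
    offline=True,
)
assert isinstance(storage, FakeStorageInterface)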
dsgrid/cloud/fake_storage_interface.py ADDED
@@ -0,0 +1,37 @@
+ from contextlib import contextmanager
+ from .cloud_storage_interface import CloudStorageInterface
+
+
+ class FakeStorageInterface(CloudStorageInterface):
+     """Fake interface for tests and local mode."""
+
+     def check_lock_file(self, path):
+         pass
+
+     def check_valid_lock_file(self, path):
+         pass
+
+     def get_lock_files(self, relative_path=None):
+         pass
+
+     def has_lock_files(self):
+         pass
+
+     @contextmanager
+     def make_lock_file_managed(self, path):
+         yield
+
+     def make_lock_file(self, path):
+         pass
+
+     def read_lock_file(self, path):
+         pass
+
+     def remove_lock_file(self, path, force=False):
+         pass
+
+     def sync_pull(self, remote_path, local_path, exclude=None, delete_local=False, is_file=False):
+         pass
+
+     def sync_push(self, remote_path, local_path, exclude=None):
+         pass
dsgrid/cloud/s3_storage_interface.py ADDED
@@ -0,0 +1,156 @@
+ from contextlib import contextmanager
+ from datetime import datetime
+ import json
+ import logging
+ import os
+ import sys
+ from pathlib import Path
+ import time
+
+ from dsgrid.exceptions import DSGMakeLockError, DSGRegistryLockError
+ from dsgrid.filesystem.local_filesystem import LocalFilesystem
+ from dsgrid.filesystem.s3_filesystem import S3Filesystem
+ from dsgrid.utils.run_command import check_run_command
+ from dsgrid.utils.timing import track_timing, timer_stats_collector
+ from .cloud_storage_interface import CloudStorageInterface
+
+ logger = logging.getLogger(__name__)
+
+
+ class S3StorageInterface(CloudStorageInterface):
+     """Interface to S3."""
+
+     def __init__(self, local_path, remote_path, uuid, user, profile):
+         self._local_path = local_path
+         self._remote_path = remote_path
+         self._uuid = uuid
+         self._user = user
+         self._local_filesystem = LocalFilesystem()
+         self._s3_filesystem = S3Filesystem(remote_path, profile)
+
+     def _sync(self, src, dst, exclude=None, is_file=False):
+         start = time.time()
+         aws_exec = self._get_aws_executable()
+         cmd = "cp" if is_file else "sync"
+         sync_command = f"{aws_exec} s3 {cmd} {src} {dst} --profile {self._s3_filesystem.profile}"
+         if exclude:
+             for x in exclude:
+                 sync_command += f" --exclude {x}"
+         logger.info("Running %s", sync_command)
+         check_run_command(sync_command)
+         logger.info("Command took %s seconds", time.time() - start)
+
+     @staticmethod
+     def _get_aws_executable():
+         # subprocess.run cannot find the aws executable on Windows if shell is not True.
+         # That's probably an aws cli bug. We can work around it here.
+         return "aws.cmd" if sys.platform == "win32" else "aws"
+
+     def check_lock_file(self, path):
+         self.check_valid_lock_file(path)
+         filepath = self._s3_filesystem.path(path)
+         if filepath.exists():
+             lock_contents = self.read_lock_file(filepath)
+             if (
+                 self._uuid != lock_contents["uuid"]
+                 or self._user != lock_contents["username"]
+             ):
+                 msg = f"Registry path {str(filepath)} is currently locked by {lock_contents['username']}, timestamp={lock_contents['timestamp']}, uuid={lock_contents['uuid']}."
+                 raise DSGRegistryLockError(msg)
+
+     def check_valid_lock_file(self, path):
+         path = Path(path)
+         # Check that the lock file is of type .lock.
+         if path.suffix != ".lock":
+             msg = f"Lock file path provided ({path}) must be a valid .lock path"
+             raise DSGMakeLockError(msg)
+         # Check that the lock file is in one of the expected directories.
+         relative_path = path.parent
+         if str(relative_path).startswith("s3:/nrel-dsgrid-registry/"):
+             relative_path = relative_path.relative_to("s3:/nrel-dsgrid-registry/")
+         if str(relative_path).startswith("/"):
+             relative_path = relative_path.relative_to("/")
+         if relative_path not in (Path("configs/.locks"), Path("data/.locks")):
+             msg = "Lock file path provided must have relative path of configs/.locks or data/.locks"
+             raise DSGMakeLockError(msg)
+         return True
+
+     def get_lock_files(self, relative_path=None):
+         if relative_path:
+             contents = self._s3_filesystem.path(
+                 f"{self._s3_filesystem.bucket}/{relative_path}"
+             ).glob(pattern="**/*.locks")
+         else:
+             contents = self._s3_filesystem.path(self._s3_filesystem.bucket).glob(
+                 pattern="**/*.locks"
+             )
+         return contents
+
+     def has_lock_files(self):
+         contents = self.get_lock_files()
+         return next(contents, None) is not None
+
+     @contextmanager
+     def make_lock_file_managed(self, path):
+         try:
+             self.make_lock_file(path)
+             yield
+         finally:
+             self.remove_lock_file(path)
+
+     def make_lock_file(self, path):
+         self.check_lock_file(path)
+         filepath = self._s3_filesystem.path(path)
+         lock_content = {
+             "username": self._user,
+             "uuid": self._uuid,
+             "timestamp": str(datetime.now()),
+         }
+         filepath.write_text(json.dumps(lock_content))
+
+     def read_lock_file(self, path):
+         return json.loads(self._s3_filesystem.path(path).read_text())
+
+     def remove_lock_file(self, path, force=False):
+         filepath = self._s3_filesystem.path(path)
+         if filepath.exists():
+             lockfile_contents = self.read_lock_file(filepath)
+             if not force:
+                 if (
+                     self._uuid != lockfile_contents["uuid"]
+                     or self._user != lockfile_contents["username"]
+                 ):
+                     msg = f"Registry path {str(filepath)} is currently locked by {lockfile_contents['username']}. Lock created at {lockfile_contents['timestamp']} with uuid={lockfile_contents['uuid']}."
+                     raise DSGRegistryLockError(msg)
+             if force:
+                 logger.warning(
+                     "Force removed lock file with user=%s and uuid=%s",
+                     lockfile_contents["username"],
+                     lockfile_contents["uuid"],
+                 )
+             filepath.unlink()
+
+     @track_timing(timer_stats_collector)
+     def sync_pull(self, remote_path, local_path, exclude=None, delete_local=False, is_file=False):
+         if delete_local:
+             local_contents = self._local_filesystem.rglob(local_path)
+             s3_contents = {
+                 str(self._s3_filesystem.path(x).relative_to(self._s3_filesystem.path(remote_path)))
+                 for x in self._s3_filesystem.rglob(remote_path)
+             }
+             for content in local_contents:
+                 relcontent = os.path.relpath(content, local_path)
+                 if relcontent not in s3_contents:
+                     self._local_filesystem.rm(content)
+                     logger.info("delete: %s because it is not in %s", relcontent, remote_path)
+         self._sync(remote_path, local_path, exclude, is_file)
+
+     @track_timing(timer_stats_collector)
+     def sync_push(self, remote_path, local_path, exclude=None):
+         self._sync(local_path, remote_path, exclude)
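
To see what _sync actually shells out to, here is the command string it assembles for a non-file pull with the default SYNC_EXCLUDE_LIST; the bucket, local path, and profile values are illustrative:

# Reconstructing the string built by S3StorageInterface._sync (values illustrative).
aws_exec = "aws"  # "aws.cmd" on Windows
src = "s3://nrel-dsgrid-registry/configs"
dst = "/home/jdoe/.dsgrid-registry/configs"
sync_command = f"{aws_exec} s3 sync {src} {dst} --profile nrel-aws-dsgrid"
for x in ["*.DS_Store", "**/*.lock"]:  # SYNC_EXCLUDE_LIST from dsgrid.common
    sync_command += f" --exclude {x}"
print(sync_command)
# aws s3 sync s3://nrel-dsgrid-registry/configs /home/jdoe/.dsgrid-registry/configs --profile nrel-aws-dsgrid --exclude *.DS_Store --exclude **/*.lock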
dsgrid/common.py ADDED
@@ -0,0 +1,35 @@
+ import enum
+ import os
+ from pathlib import Path
+
+ # AWS_PROFILE_NAME = "nrel-aws-dsgrid"
+ REMOTE_REGISTRY = "s3://nrel-dsgrid-registry"
+
+
+ def on_hpc():
+     # NREL_CLUSTER is not set when you ssh into a compute node.
+     return "NREL_CLUSTER" in os.environ or "SLURM_JOB_ID" in os.environ
+
+
+ if on_hpc():
+     LOCAL_REGISTRY = Path("/scratch") / os.environ["USER"] / ".dsgrid-registry"
+ else:
+     LOCAL_REGISTRY = Path.home() / ".dsgrid-registry"
+
+ LOCAL_REGISTRY_DATA = LOCAL_REGISTRY / "data"
+ PROJECT_FILENAME = "project.json5"
+ REGISTRY_FILENAME = "registry.json5"
+ DATASET_FILENAME = "dataset.json5"
+ DIMENSIONS_FILENAME = "dimensions.json5"
+ DEFAULT_DB_PASSWORD = "openSesame"
+ DEFAULT_SCRATCH_DIR = "__dsgrid_scratch__"
+ SCALING_FACTOR_COLUMN = "scaling_factor"
+ SYNC_EXCLUDE_LIST = ["*.DS_Store", "**/*.lock"]
+ VALUE_COLUMN = "value"
+
+
+ class BackendEngine(enum.StrEnum):
+     """Supported backend SQL processing engines"""
+
+     DUCKDB = "duckdb"
+     SPARK = "spark"
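
Because BackendEngine subclasses enum.StrEnum (available since Python 3.11), its members compare equal to their plain string values, which is what lets configuration files store "duckdb" or "spark" directly. A small standalone check:

import enum

class BackendEngine(enum.StrEnum):
    DUCKDB = "duckdb"
    SPARK = "spark"

assert BackendEngine.DUCKDB == "duckdb"
assert BackendEngine("spark") is BackendEngine.SPARK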
dsgrid/config/annual_time_dimension_config.py ADDED
@@ -0,0 +1,187 @@
+ import logging
+ from datetime import timedelta, datetime
+
+ import pandas as pd
+ from chronify.time_range_generator_factory import make_time_range_generator
+
+ from dsgrid.config.date_time_dimension_config import DateTimeDimensionConfig
+ from dsgrid.dimension.base_models import DimensionType
+ from dsgrid.dimension.time import AnnualTimeRange
+ from dsgrid.exceptions import DSGInvalidDataset
+ from dsgrid.time.types import AnnualTimestampType
+ from dsgrid.dimension.time_utils import is_leap_year
+ from dsgrid.spark.functions import (
+     cross_join,
+     handle_column_spaces,
+     select_expr,
+ )
+ from dsgrid.spark.types import (
+     DataFrame,
+     StructType,
+     StructField,
+     IntegerType,
+     StringType,
+     TimestampType,
+     F,
+ )
+ from dsgrid.utils.timing import timer_stats_collector, track_timing
+ from dsgrid.utils.spark import (
+     get_spark_session,
+     set_session_time_zone,
+ )
+ from .dimensions import AnnualTimeDimensionModel
+ from .time_dimension_base_config import TimeDimensionBaseConfig
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class AnnualTimeDimensionConfig(TimeDimensionBaseConfig):
+     """Provides an interface to an AnnualTimeDimensionModel."""
+
+     @staticmethod
+     def model_class() -> AnnualTimeDimensionModel:
+         return AnnualTimeDimensionModel
+
+     @track_timing(timer_stats_collector)
+     def check_dataset_time_consistency(self, load_data_df, time_columns) -> None:
+         logger.info("Check AnnualTimeDimensionConfig dataset time consistency.")
+         if len(time_columns) > 1:
+             msg = (
+                 "AnnualTimeDimensionConfig expects only one column from "
+                 f"get_load_data_time_columns, but has {time_columns}"
+             )
+             raise ValueError(msg)
+         time_col = time_columns[0]
+         time_ranges = self.get_time_ranges()
+         assert len(time_ranges) == 1, len(time_ranges)
+         time_range = time_ranges[0]
+         # TODO: need to support validation of multiple time ranges: DSGRID-173
+
+         expected_timestamps = time_range.list_time_range()
+         actual_timestamps = [
+             pd.Timestamp(str(x[time_col]), tz=self.get_tzinfo()).to_pydatetime()
+             for x in load_data_df.select(time_col)
+             .distinct()
+             .filter(f"{time_col} IS NOT NULL")
+             .sort(time_col)
+             .collect()
+         ]
+         if expected_timestamps != actual_timestamps:
+             mismatch = sorted(
+                 set(expected_timestamps).symmetric_difference(set(actual_timestamps))
+             )
+             msg = f"load_data {time_col}s do not match expected times. mismatch={mismatch}"
+             raise DSGInvalidDataset(msg)
+
+     def build_time_dataframe(self) -> DataFrame:
+         time_col = self.get_load_data_time_columns()
+         assert len(time_col) == 1, time_col
+         time_col = time_col[0]
+         schema = StructType([StructField(time_col, IntegerType(), False)])
+
+         model_time = self.list_expected_dataset_timestamps()
+         df_time = get_spark_session().createDataFrame(model_time, schema=schema)
+         return df_time
+
+     def get_frequency(self) -> timedelta:
+         return timedelta(days=365)
+
+     def get_time_ranges(self) -> list[AnnualTimeRange]:
+         ranges = []
+         frequency = self.get_frequency()
+         for start, end in self._build_time_ranges(
+             self.model.ranges, self.model.str_format, tz=self.get_tzinfo()
+         ):
+             start = pd.Timestamp(start)
+             end = pd.Timestamp(end)
+             ranges.append(
+                 AnnualTimeRange(
+                     start=start,
+                     end=end,
+                     frequency=frequency,
+                 )
+             )
+
+         return ranges
+
+     def get_start_times(self) -> list[pd.Timestamp]:
+         tz = self.get_tzinfo()
+         start_times = []
+         for trange in self.model.ranges:
+             start = datetime.strptime(trange.start, self.model.str_format)
+             assert start.tzinfo is None
+             start_times.append(start.replace(tzinfo=tz))
+
+         return start_times
+
+     def get_lengths(self) -> list[int]:
+         lengths = []
+         for trange in self.model.ranges:
+             start = datetime.strptime(trange.start, self.model.str_format)
+             end = datetime.strptime(trange.end, self.model.str_format)
+             lengths.append(end.year - start.year + 1)
+         return lengths
+
+     def get_load_data_time_columns(self) -> list[str]:
+         return list(AnnualTimestampType._fields)
+
+     def get_time_zone(self) -> None:
+         return None
+
+     def get_tzinfo(self) -> None:
+         return None
+
+     def get_time_interval_type(self) -> None:
+         return None
+
+     def list_expected_dataset_timestamps(self) -> list[AnnualTimestampType]:
+         timestamps = []
+         for time_range in self.model.ranges:
+             start, end = (int(time_range.start), int(time_range.end))
+             timestamps += [AnnualTimestampType(x) for x in range(start, end + 1)]
+         return timestamps
+
+
+ def map_annual_time_to_date_time(
+     df: DataFrame,
+     annual_dim: AnnualTimeDimensionConfig,
+     dt_dim: DateTimeDimensionConfig,
+     value_columns: set[str],
+ ) -> DataFrame:
+     """Map a DataFrame with an annual time dimension to a DateTime time dimension."""
+     annual_col = annual_dim.get_load_data_time_columns()[0]
+     myear_column = DimensionType.MODEL_YEAR.value
+     timestamps = make_time_range_generator(dt_dim.to_chronify()).list_timestamps()
+     time_cols = dt_dim.get_load_data_time_columns()
+     assert len(time_cols) == 1, time_cols
+     time_col = time_cols[0]
+     schema = StructType([StructField(time_col, TimestampType(), False)])
+     dt_df = get_spark_session().createDataFrame(
+         [(x.to_pydatetime(),) for x in timestamps], schema=schema
+     )
+
+     # Note that MeasurementType.TOTAL has already been verified.
+     with set_session_time_zone(dt_dim.model.datetime_format.timezone.tz_name):
+         years = (
+             select_expr(dt_df, [f"YEAR({handle_column_spaces(time_col)}) AS year"])
+             .distinct()
+             .collect()
+         )
+         if len(years) != 1:
+             msg = f"DateTime dimension has more than one year: {years=}"
+             raise NotImplementedError(msg)
+         if annual_dim.model.include_leap_day and is_leap_year(years[0].year):
+             measured_duration = timedelta(days=366)
+         else:
+             measured_duration = timedelta(days=365)
+
+     df2 = (
+         cross_join(df, dt_df)
+         .withColumn(myear_column, F.col(annual_col).cast(StringType()))
+         .drop(annual_col)
+     )
+     frequency = dt_dim.model.frequency
+     for column in value_columns:
+         df2 = df2.withColumn(column, F.col(column) / (measured_duration / frequency))
+     return df2
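
The final loop in map_annual_time_to_date_time spreads each annual total evenly across the DateTime timestamps: dividing by measured_duration / frequency divides by the number of intervals in the measured year. A quick arithmetic check for an hourly target in a non-leap year:

from datetime import timedelta

measured_duration = timedelta(days=365)  # include_leap_day=False, or a non-leap year
frequency = timedelta(hours=1)           # hourly DateTime dimension
intervals = measured_duration / frequency  # timedelta / timedelta yields a float
assert intervals == 8760.0
# An annual value of 8760 (e.g., MWh) becomes 1.0 per hourly timestamp:
assert 8760 / intervals == 1.0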