dsgrid-toolkit 0.3.3__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157) hide show
  1. build_backend.py +93 -0
  2. dsgrid/__init__.py +22 -0
  3. dsgrid/api/__init__.py +0 -0
  4. dsgrid/api/api_manager.py +179 -0
  5. dsgrid/api/app.py +419 -0
  6. dsgrid/api/models.py +60 -0
  7. dsgrid/api/response_models.py +116 -0
  8. dsgrid/apps/__init__.py +0 -0
  9. dsgrid/apps/project_viewer/app.py +216 -0
  10. dsgrid/apps/registration_gui.py +444 -0
  11. dsgrid/chronify.py +32 -0
  12. dsgrid/cli/__init__.py +0 -0
  13. dsgrid/cli/common.py +120 -0
  14. dsgrid/cli/config.py +176 -0
  15. dsgrid/cli/download.py +13 -0
  16. dsgrid/cli/dsgrid.py +157 -0
  17. dsgrid/cli/dsgrid_admin.py +92 -0
  18. dsgrid/cli/install_notebooks.py +62 -0
  19. dsgrid/cli/query.py +729 -0
  20. dsgrid/cli/registry.py +1862 -0
  21. dsgrid/cloud/__init__.py +0 -0
  22. dsgrid/cloud/cloud_storage_interface.py +140 -0
  23. dsgrid/cloud/factory.py +31 -0
  24. dsgrid/cloud/fake_storage_interface.py +37 -0
  25. dsgrid/cloud/s3_storage_interface.py +156 -0
  26. dsgrid/common.py +36 -0
  27. dsgrid/config/__init__.py +0 -0
  28. dsgrid/config/annual_time_dimension_config.py +194 -0
  29. dsgrid/config/common.py +142 -0
  30. dsgrid/config/config_base.py +148 -0
  31. dsgrid/config/dataset_config.py +907 -0
  32. dsgrid/config/dataset_schema_handler_factory.py +46 -0
  33. dsgrid/config/date_time_dimension_config.py +136 -0
  34. dsgrid/config/dimension_config.py +54 -0
  35. dsgrid/config/dimension_config_factory.py +65 -0
  36. dsgrid/config/dimension_mapping_base.py +350 -0
  37. dsgrid/config/dimension_mappings_config.py +48 -0
  38. dsgrid/config/dimensions.py +1025 -0
  39. dsgrid/config/dimensions_config.py +71 -0
  40. dsgrid/config/file_schema.py +190 -0
  41. dsgrid/config/index_time_dimension_config.py +80 -0
  42. dsgrid/config/input_dataset_requirements.py +31 -0
  43. dsgrid/config/mapping_tables.py +209 -0
  44. dsgrid/config/noop_time_dimension_config.py +42 -0
  45. dsgrid/config/project_config.py +1462 -0
  46. dsgrid/config/registration_models.py +188 -0
  47. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  48. dsgrid/config/simple_models.py +49 -0
  49. dsgrid/config/supplemental_dimension.py +29 -0
  50. dsgrid/config/time_dimension_base_config.py +192 -0
  51. dsgrid/data_models.py +155 -0
  52. dsgrid/dataset/__init__.py +0 -0
  53. dsgrid/dataset/dataset.py +123 -0
  54. dsgrid/dataset/dataset_expression_handler.py +86 -0
  55. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  56. dsgrid/dataset/dataset_schema_handler_base.py +945 -0
  57. dsgrid/dataset/dataset_schema_handler_one_table.py +209 -0
  58. dsgrid/dataset/dataset_schema_handler_two_table.py +322 -0
  59. dsgrid/dataset/growth_rates.py +162 -0
  60. dsgrid/dataset/models.py +51 -0
  61. dsgrid/dataset/table_format_handler_base.py +257 -0
  62. dsgrid/dataset/table_format_handler_factory.py +17 -0
  63. dsgrid/dataset/unpivoted_table.py +121 -0
  64. dsgrid/dimension/__init__.py +0 -0
  65. dsgrid/dimension/base_models.py +230 -0
  66. dsgrid/dimension/dimension_filters.py +308 -0
  67. dsgrid/dimension/standard.py +252 -0
  68. dsgrid/dimension/time.py +352 -0
  69. dsgrid/dimension/time_utils.py +103 -0
  70. dsgrid/dsgrid_rc.py +88 -0
  71. dsgrid/exceptions.py +105 -0
  72. dsgrid/filesystem/__init__.py +0 -0
  73. dsgrid/filesystem/cloud_filesystem.py +32 -0
  74. dsgrid/filesystem/factory.py +32 -0
  75. dsgrid/filesystem/filesystem_interface.py +136 -0
  76. dsgrid/filesystem/local_filesystem.py +74 -0
  77. dsgrid/filesystem/s3_filesystem.py +118 -0
  78. dsgrid/loggers.py +132 -0
  79. dsgrid/minimal_patterns.cp313-win_amd64.pyd +0 -0
  80. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +949 -0
  81. dsgrid/notebooks/registration.ipynb +48 -0
  82. dsgrid/notebooks/start_notebook.sh +11 -0
  83. dsgrid/project.py +451 -0
  84. dsgrid/query/__init__.py +0 -0
  85. dsgrid/query/dataset_mapping_plan.py +142 -0
  86. dsgrid/query/derived_dataset.py +388 -0
  87. dsgrid/query/models.py +728 -0
  88. dsgrid/query/query_context.py +287 -0
  89. dsgrid/query/query_submitter.py +994 -0
  90. dsgrid/query/report_factory.py +19 -0
  91. dsgrid/query/report_peak_load.py +70 -0
  92. dsgrid/query/reports_base.py +20 -0
  93. dsgrid/registry/__init__.py +0 -0
  94. dsgrid/registry/bulk_register.py +165 -0
  95. dsgrid/registry/common.py +287 -0
  96. dsgrid/registry/config_update_checker_base.py +63 -0
  97. dsgrid/registry/data_store_factory.py +34 -0
  98. dsgrid/registry/data_store_interface.py +74 -0
  99. dsgrid/registry/dataset_config_generator.py +158 -0
  100. dsgrid/registry/dataset_registry_manager.py +950 -0
  101. dsgrid/registry/dataset_update_checker.py +16 -0
  102. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  103. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  104. dsgrid/registry/dimension_registry_manager.py +413 -0
  105. dsgrid/registry/dimension_update_checker.py +16 -0
  106. dsgrid/registry/duckdb_data_store.py +207 -0
  107. dsgrid/registry/filesystem_data_store.py +150 -0
  108. dsgrid/registry/filter_registry_manager.py +123 -0
  109. dsgrid/registry/project_config_generator.py +57 -0
  110. dsgrid/registry/project_registry_manager.py +1623 -0
  111. dsgrid/registry/project_update_checker.py +48 -0
  112. dsgrid/registry/registration_context.py +223 -0
  113. dsgrid/registry/registry_auto_updater.py +316 -0
  114. dsgrid/registry/registry_database.py +667 -0
  115. dsgrid/registry/registry_interface.py +446 -0
  116. dsgrid/registry/registry_manager.py +558 -0
  117. dsgrid/registry/registry_manager_base.py +367 -0
  118. dsgrid/registry/versioning.py +92 -0
  119. dsgrid/rust_ext/__init__.py +14 -0
  120. dsgrid/rust_ext/find_minimal_patterns.py +129 -0
  121. dsgrid/spark/__init__.py +0 -0
  122. dsgrid/spark/functions.py +589 -0
  123. dsgrid/spark/types.py +110 -0
  124. dsgrid/tests/__init__.py +0 -0
  125. dsgrid/tests/common.py +140 -0
  126. dsgrid/tests/make_us_data_registry.py +265 -0
  127. dsgrid/tests/register_derived_datasets.py +103 -0
  128. dsgrid/tests/utils.py +25 -0
  129. dsgrid/time/__init__.py +0 -0
  130. dsgrid/time/time_conversions.py +80 -0
  131. dsgrid/time/types.py +67 -0
  132. dsgrid/units/__init__.py +0 -0
  133. dsgrid/units/constants.py +113 -0
  134. dsgrid/units/convert.py +71 -0
  135. dsgrid/units/energy.py +145 -0
  136. dsgrid/units/power.py +87 -0
  137. dsgrid/utils/__init__.py +0 -0
  138. dsgrid/utils/dataset.py +830 -0
  139. dsgrid/utils/files.py +179 -0
  140. dsgrid/utils/filters.py +125 -0
  141. dsgrid/utils/id_remappings.py +100 -0
  142. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  143. dsgrid/utils/py_expression_eval/README.md +8 -0
  144. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  145. dsgrid/utils/py_expression_eval/tests.py +283 -0
  146. dsgrid/utils/run_command.py +70 -0
  147. dsgrid/utils/scratch_dir_context.py +65 -0
  148. dsgrid/utils/spark.py +918 -0
  149. dsgrid/utils/spark_partition.py +98 -0
  150. dsgrid/utils/timing.py +239 -0
  151. dsgrid/utils/utilities.py +221 -0
  152. dsgrid/utils/versioning.py +36 -0
  153. dsgrid_toolkit-0.3.3.dist-info/METADATA +193 -0
  154. dsgrid_toolkit-0.3.3.dist-info/RECORD +157 -0
  155. dsgrid_toolkit-0.3.3.dist-info/WHEEL +4 -0
  156. dsgrid_toolkit-0.3.3.dist-info/entry_points.txt +4 -0
  157. dsgrid_toolkit-0.3.3.dist-info/licenses/LICENSE +29 -0
File without changes
@@ -0,0 +1,140 @@
1
+ import abc
2
+ from contextlib import contextmanager
3
+
4
+
5
class CloudStorageInterface(abc.ABC):
    """Defines interface to synchronize data stored on a cloud storage system.

    Implementations provide cooperative locking via ".lock" files and
    push/pull synchronization between a local and a remote registry.
    """

    @abc.abstractmethod
    def check_lock_file(self, path):
        """Check whether a lock file exists at path and, if it does, that it was
        created by the same username and UUID as this interface instance.

        Parameters
        ----------
        path : str
            Lock file path

        Raises
        ------
        DSGRegistryLockError
            Raised if a lock file exists whose username or UUID does not match.
        """

    @abc.abstractmethod
    def check_valid_lock_file(self, path):
        """Check that a lock file path is valid (correct suffix and location).

        Parameters
        ----------
        path : str
            Lock file path

        Raises
        ------
        DSGMakeLockError
            Raised if the path is not a valid lock file location.
        """

    @abc.abstractmethod
    def get_lock_files(self, relative_path=None):
        """Return a generator of lock files within the /.locks directory (non-recursive).

        Parameters
        ----------
        relative_path : str, optional
            Relative path to search for lock files in. By default, None,
            meaning the whole registry is searched.
        """

    @abc.abstractmethod
    def has_lock_files(self):
        """Return True if a .lock file exists within the /.locks directory."""

    @abc.abstractmethod
    @contextmanager
    def make_lock_file_managed(self, path):
        """Context manager that makes a lock file at the given path and removes
        it on exit.

        Parameters
        ----------
        path : str
            Lock file path

        Raises
        ------
        DSGRegistryLockError
            Raised if a registry lock already exists at the path.
        """

    @abc.abstractmethod
    def make_lock_file(self, path):
        """Make a lock file at the given file path.

        Parameters
        ----------
        path : str
            Lock file path

        Raises
        ------
        DSGRegistryLockError
            Raised if a registry lock already exists at the path.
        """

    @abc.abstractmethod
    def read_lock_file(self, path):
        """Read a lock file and return a dictionary of its contents.

        Parameters
        ----------
        path : str
            Lock file path
        """

    @abc.abstractmethod
    def remove_lock_file(self, path, force=False):
        """Remove a lock file.

        Parameters
        ----------
        path : str
            Lock file path
        force : bool
            If True, force removal of a lock file that does not have the same
            UUID or username, by default False.

        Raises
        ------
        DSGRegistryLockError
            Raised if the existing lock was generated by a different user or
            a different UUID and force is False.
        """

    @abc.abstractmethod
    def sync_pull(self, remote_path, local_path, exclude=None, delete_local=False, is_file=False):
        """Synchronize data from remote_path to local_path.

        If delete_local is True, this deletes any files in local_path that do
        not exist in remote_path.

        Parameters
        ----------
        remote_path : str
            Remote registry path
        local_path : str
            Local registry path
        delete_local : bool, optional
            If True, delete files and directories that exist in local_path but
            not in remote_path, by default False.
        exclude : list, optional
            List of patterns to exclude, by default None.
            If excluding whole directories, the exclusion must end with /* , e.g. data/*
        is_file : bool, optional
            If True, the path is a single file (not a registry). By default False.
        """

    @abc.abstractmethod
    def sync_push(self, remote_path, local_path, exclude=None):
        """Synchronize data from local_path to remote_path.

        Parameters
        ----------
        remote_path : str
            Remote registry path
        local_path : str
            Local registry path
        exclude : list, optional
            List of patterns to exclude, by default None.
            If excluding whole directories, the exclusion must end with /* , e.g. data/*
        """
@@ -0,0 +1,31 @@
1
+ # from .s3_storage_interface import S3StorageInterface
2
+ from .fake_storage_interface import FakeStorageInterface
3
+
4
+ # from dsgrid.common import AWS_PROFILE_NAME
5
+
6
+
7
def make_cloud_storage_interface(local_path, remote_path, uuid, user, offline=False):
    """Create a CloudStorageInterface appropriate for the given remote path.

    Parameters
    ----------
    local_path : str
    remote_path : str
    uuid : str
        Unique ID to be used when generating cloud locks.
    user : str
        Username to be used when generating cloud locks.
    offline : bool, optional
        If True, don't perform any remote syncing operations.

    Returns
    -------
    CloudStorageInterface

    Raises
    ------
    NotImplementedError
        Raised when an S3 remote path is requested in online mode.
    """
    wants_s3 = remote_path.lower().startswith("s3")
    if wants_s3 and not offline:
        msg = f"Support for S3 is currently disabled: {remote_path=}"
        raise NotImplementedError(msg)
    # S3 support is disabled for now; always fall back to the no-op interface.
    # return S3StorageInterface(local_path, remote_path, uuid, user, profile=AWS_PROFILE_NAME)
    return FakeStorageInterface()
@@ -0,0 +1,37 @@
1
+ from contextlib import contextmanager
2
+ from .cloud_storage_interface import CloudStorageInterface
3
+
4
+
5
class FakeStorageInterface(CloudStorageInterface):
    """No-op storage interface for tests and local (offline) mode.

    Every locking and syncing operation succeeds without touching any
    remote storage.
    """

    def check_lock_file(self, path):
        """Nothing to check; no remote locks exist."""
        return None

    def check_valid_lock_file(self, path):
        """Accept every lock file path."""
        return None

    def get_lock_files(self, relative_path=None):
        """No lock files to report."""
        return None

    def has_lock_files(self):
        """No lock files ever exist."""
        return None

    @contextmanager
    def make_lock_file_managed(self, path):
        """Yield immediately; there is no lock to create or remove."""
        yield

    def make_lock_file(self, path):
        """Do nothing; no lock file is written."""
        return None

    def read_lock_file(self, path):
        """No lock file contents to read."""
        return None

    def remove_lock_file(self, path, force=False):
        """Do nothing; there is no lock file to remove."""
        return None

    def sync_pull(self, remote_path, local_path, exclude=None, delete_local=False, is_file=False):
        """Do nothing; no remote data to pull."""
        return None

    def sync_push(self, remote_path, local_path, exclude=None):
        """Do nothing; no remote data to push."""
        return None
@@ -0,0 +1,156 @@
1
+ from contextlib import contextmanager
2
+ from datetime import datetime
3
+ import json
4
+ import logging
5
+ import os
6
+ import sys
7
+ from pathlib import Path
8
+ import time
9
+
10
+ from dsgrid.exceptions import DSGMakeLockError, DSGRegistryLockError
11
+ from dsgrid.filesystem.local_filesystem import LocalFilesystem
12
+ from dsgrid.filesystem.s3_filesystem import S3Filesystem
13
+ from dsgrid.utils.run_command import check_run_command
14
+ from dsgrid.utils.timing import track_timing, timer_stats_collector
15
+ from .cloud_storage_interface import CloudStorageInterface
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class S3StorageInterface(CloudStorageInterface):
    """Synchronizes registry data between a local path and a remote S3 path.

    Cooperative locking is implemented with JSON ".lock" files stored in the
    remote registry; bulk data transfers shell out to the AWS CLI.
    """

    def __init__(self, local_path, remote_path, uuid, user, profile):
        self._local_path = local_path
        self._remote_path = remote_path
        # uuid and user identify the owner of any locks created by this
        # instance; other owners' locks cannot be removed without force.
        self._uuid = uuid
        self._user = user
        self._local_filesystem = LocalFilesystem()
        self._s3_filesystem = S3Filesystem(remote_path, profile)

    def _sync(self, src, dst, exclude=None, is_file=False):
        """Copy src to dst with the AWS CLI ("cp" for a file, "sync" for a tree).

        Parameters
        ----------
        src : str
        dst : str
        exclude : list, optional
            Patterns passed through as --exclude options, by default None.
        is_file : bool, optional
            If True, copy a single file instead of synchronizing a tree.
        """
        start = time.time()
        aws_exec = self._get_aws_executable()
        cmd = "cp" if is_file else "sync"
        sync_command = f"{aws_exec} s3 {cmd} {src} {dst} --profile {self._s3_filesystem.profile}"
        if exclude:
            for x in exclude:
                sync_command = sync_command + f" --exclude {x}"
        logger.info("Running %s", sync_command)
        check_run_command(sync_command)
        logger.info("Command took %s seconds", time.time() - start)

    @staticmethod
    def _get_aws_executable():
        """Return the AWS CLI executable name for the current platform."""
        # subprocess.run cannot find the aws executable on Windows if shell is not True.
        # That's probably an aws cli bug. We can workaround it here.
        return "aws.cmd" if sys.platform == "win32" else "aws"

    def check_lock_file(self, path):
        """Raise DSGRegistryLockError if path is locked by another user/UUID."""
        self.check_valid_lock_file(path)
        filepath = self._s3_filesystem.path(path)
        if filepath.exists():
            lock_contents = self.read_lock_file(filepath)
            if (
                not self._uuid == lock_contents["uuid"]
                or not self._user == lock_contents["username"]
            ):
                msg = f"Registry path {str(filepath)} is currently locked by {lock_contents['username']}, timestamp={lock_contents['timestamp']}, uuid={lock_contents['uuid']}."
                raise DSGRegistryLockError(msg)

    def check_valid_lock_file(self, path):
        """Validate that path is a .lock file under configs/.locks or data/.locks.

        Raises
        ------
        DSGMakeLockError
            Raised if the path has the wrong suffix or the wrong parent
            directory.
        """
        path = Path(path)
        # check that lock file is of type .lock
        if path.suffix != ".lock":
            msg = f"Lock file path provided ({path}) must be a valid .lock path"
            raise DSGMakeLockError(msg)
        # Normalize to a path relative to the registry root. Note that
        # Path("s3://...") collapses the double slash to "s3:/...".
        relative_path = Path(path).parent
        if str(relative_path).startswith("s3:/nrel-dsgrid-registry/"):
            relative_path = relative_path.relative_to("s3:/nrel-dsgrid-registry/")
        if str(relative_path).startswith("/"):
            relative_path = relative_path.relative_to("/")
        if relative_path not in (Path("configs/.locks"), Path("data/.locks")):
            # Bug fix: the original code constructed this exception without
            # raising it, silently accepting invalid lock locations.
            raise DSGMakeLockError(
                "Lock file path provided must have relative path of configs/.locks or data/.locks"
            )
        return True

    def get_lock_files(self, relative_path=None):
        """Return a generator of lock files found in the registry.

        NOTE(review): the glob pattern matches "*.locks", but lock files use
        the ".lock" suffix (see check_valid_lock_file) -- confirm the intended
        pattern before changing it.
        """
        if relative_path:
            contents = self._s3_filesystem.path(
                f"{self._s3_filesystem.bucket}/{relative_path}"
            ).glob(pattern="**/*.locks")
        else:
            contents = self._s3_filesystem.path(self._s3_filesystem.bucket).glob(
                pattern="**/*.locks"
            )
        return contents

    def has_lock_files(self):
        """Return True if any lock file exists in the registry."""
        contents = self.get_lock_files()
        return next(contents, None) is not None

    @contextmanager
    def make_lock_file_managed(self, path):
        """Context manager that creates a lock file and removes it on exit."""
        try:
            self.make_lock_file(path)
            yield
        finally:
            self.remove_lock_file(path)

    def make_lock_file(self, path):
        """Write a JSON lock file at path recording this user, UUID, and time.

        Raises
        ------
        DSGRegistryLockError
            Raised if a lock owned by a different user/UUID already exists.
        """
        self.check_lock_file(path)
        filepath = self._s3_filesystem.path(path)
        lock_content = {
            "username": self._user,
            "uuid": self._uuid,
            "timestamp": str(datetime.now()),
        }
        # filepath is already an S3 path object; no need to wrap it again.
        filepath.write_text(json.dumps(lock_content))

    def read_lock_file(self, path):
        """Return the lock file's contents as a dictionary."""
        lockfile_contents = json.loads(self._s3_filesystem.path(path).read_text())
        return lockfile_contents

    def remove_lock_file(self, path, force=False):
        """Remove a lock file.

        Parameters
        ----------
        path : str
            Lock file path
        force : bool
            If True, remove the lock even if it is owned by a different user
            or UUID, by default False.

        Raises
        ------
        DSGRegistryLockError
            Raised if the lock was created by a different user or UUID and
            force is False.
        """
        filepath = self._s3_filesystem.path(path)
        if filepath.exists():
            lockfile_contents = self.read_lock_file(filepath)
            if not force:
                # Bug fix: use "or" to match the ownership test in
                # check_lock_file; the original "and" only protected locks
                # when BOTH the uuid and the username differed.
                if (
                    not self._uuid == lockfile_contents["uuid"]
                    or not self._user == lockfile_contents["username"]
                ):
                    msg = f"Registry path {str(filepath)} is currently locked by {lockfile_contents['username']}. Lock created as {lockfile_contents['timestamp']} with uuid={lockfile_contents['uuid']}."
                    raise DSGRegistryLockError(msg)
            if force:
                logger.warning(
                    "Force removed lock file with user=%s and uuid=%s",
                    lockfile_contents["username"],
                    lockfile_contents["uuid"],
                )
            filepath.unlink()

    @track_timing(timer_stats_collector)
    def sync_pull(self, remote_path, local_path, exclude=None, delete_local=False, is_file=False):
        """Synchronize data from remote_path to local_path.

        If delete_local is True, first delete any local files that do not
        exist in remote_path.
        """
        if delete_local:
            local_contents = self._local_filesystem.rglob(local_path)
            s3_contents = {
                str(self._s3_filesystem.path(x).relative_to(self._s3_filesystem.path(remote_path)))
                for x in self._s3_filesystem.rglob(remote_path)
            }
            for content in local_contents:
                relcontent = os.path.relpath(content, local_path)
                if relcontent not in s3_contents:
                    self._local_filesystem.rm(content)
                    logger.info("delete: %s because it is not in %s", relcontent, remote_path)
        self._sync(remote_path, local_path, exclude, is_file)

    @track_timing(timer_stats_collector)
    def sync_push(self, remote_path, local_path, exclude=None):
        """Synchronize data from local_path to remote_path."""
        self._sync(local_path, remote_path, exclude)
dsgrid/common.py ADDED
@@ -0,0 +1,36 @@
1
+ import enum
2
+ import os
3
+ from pathlib import Path
4
+
5
# AWS_PROFILE_NAME = "nrel-aws-dsgrid"
# Default remote registry location on S3 (S3 syncing is currently disabled;
# see dsgrid/cloud/factory.py).
REMOTE_REGISTRY = "s3://nrel-dsgrid-registry"
7
+
8
+
9
def on_hpc():
    """Return True if the process appears to be running on an NREL HPC system."""
    # NREL_CLUSTER is not set when you ssh into a compute node, so also
    # accept the presence of a Slurm job ID.
    hpc_markers = ("NREL_CLUSTER", "SLURM_JOB_ID")
    return any(marker in os.environ for marker in hpc_markers)
12
+
13
+
14
# On NREL HPC systems, keep the registry on /scratch/$USER instead of the
# home directory.
if on_hpc():
    LOCAL_REGISTRY = Path("/scratch") / os.environ["USER"] / ".dsgrid-registry"
else:
    LOCAL_REGISTRY = Path.home() / ".dsgrid-registry"

# Location of dataset files within the local registry.
LOCAL_REGISTRY_DATA = LOCAL_REGISTRY / "data"
# Canonical file names for registry artifacts (JSON5 format).
PROJECT_FILENAME = "project.json5"
REGISTRY_FILENAME = "registry.json5"
DATASET_FILENAME = "dataset.json5"
DIMENSIONS_FILENAME = "dimensions.json5"
DEFAULT_DB_PASSWORD = "openSesame"
DEFAULT_SCRATCH_DIR = "__dsgrid_scratch__"
# Well-known column names used across dsgrid tables.
SCALING_FACTOR_COLUMN = "scaling_factor"
# Patterns excluded from registry sync operations.
SYNC_EXCLUDE_LIST = ["*.DS_Store", "**/*.lock"]
TIME_ZONE_COLUMN = "time_zone"
VALUE_COLUMN = "value"
30
+
31
+
32
class BackendEngine(enum.StrEnum):
    """Supported backend SQL processing engines"""

    # In-process analytics database.
    DUCKDB = "duckdb"
    # Distributed processing via Apache Spark.
    SPARK = "spark"
File without changes
@@ -0,0 +1,194 @@
1
+ import logging
2
+ from datetime import timedelta
3
+ from dateutil.relativedelta import relativedelta
4
+
5
+ import pandas as pd
6
+ from chronify.time_range_generator_factory import make_time_range_generator
7
+
8
+ from dsgrid.config.date_time_dimension_config import DateTimeDimensionConfig
9
+ from dsgrid.dimension.base_models import DimensionType
10
+ from dsgrid.dimension.time import AnnualTimeRange
11
+ from dsgrid.exceptions import DSGInvalidDataset
12
+ from dsgrid.time.types import AnnualTimestampType
13
+ from dsgrid.dimension.time_utils import is_leap_year, build_annual_ranges
14
+ from dsgrid.spark.functions import (
15
+ cross_join,
16
+ handle_column_spaces,
17
+ select_expr,
18
+ )
19
+ from dsgrid.spark.types import (
20
+ DataFrame,
21
+ StructType,
22
+ StructField,
23
+ IntegerType,
24
+ StringType,
25
+ TimestampType,
26
+ F,
27
+ )
28
+ from dsgrid.utils.timing import timer_stats_collector, track_timing
29
+ from dsgrid.utils.spark import (
30
+ get_spark_session,
31
+ set_session_time_zone,
32
+ )
33
+ from .dimensions import AnnualTimeDimensionModel
34
+ from .time_dimension_base_config import TimeDimensionBaseConfig
35
+
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
class AnnualTimeDimensionConfig(TimeDimensionBaseConfig):
    """Provides an interface to an AnnualTimeDimensionModel.

    Note: Annual time does not currently support Chronify conversion because the annual time
    to datetime mapping is not yet available in Chronify.
    """

    @staticmethod
    def model_class() -> AnnualTimeDimensionModel:
        """Return the data model class backing this config."""
        return AnnualTimeDimensionModel

    @track_timing(timer_stats_collector)
    def check_dataset_time_consistency(self, load_data_df, time_columns) -> None:
        """Verify that the distinct, non-null years in load_data_df exactly match
        the years defined by this dimension's (single) time range.

        Raises DSGInvalidDataset on any mismatch and ValueError if more than one
        time column is supplied.
        """
        logger.info("Check AnnualTimeDimensionConfig dataset time consistency.")
        if len(time_columns) > 1:
            msg = (
                "AnnualTimeDimensionConfig expects only one column from "
                f"get_load_data_time_columns, but has {time_columns}"
            )
            raise ValueError(msg)
        time_col = time_columns[0]
        time_ranges = self.get_time_ranges()
        assert len(time_ranges) == 1, len(time_ranges)
        time_range = time_ranges[0]
        # TODO: need to support validation of multiple time ranges: DSGRID-173

        expected_timestamps = time_range.list_time_range()
        # Convert each distinct year value in the data to a datetime for
        # comparison with the expected range.
        actual_timestamps = [
            pd.Timestamp(str(x[time_col]), tz=self.get_tzinfo()).to_pydatetime()
            for x in load_data_df.select(time_col)
            .distinct()
            .filter(f"{time_col} IS NOT NULL")
            .sort(time_col)
            .collect()
        ]
        if expected_timestamps != actual_timestamps:
            mismatch = sorted(
                set(expected_timestamps).symmetric_difference(set(actual_timestamps))
            )
            msg = f"load_data {time_col}s do not match expected times. mismatch={mismatch}"
            raise DSGInvalidDataset(msg)

    def build_time_dataframe(self) -> DataFrame:
        """Return a single-column Spark DataFrame of the expected years (integers)."""
        time_col = self.get_load_data_time_columns()
        assert len(time_col) == 1, time_col
        time_col = time_col[0]
        schema = StructType([StructField(time_col, IntegerType(), False)])

        model_time = self.list_expected_dataset_timestamps()
        df_time = get_spark_session().createDataFrame(model_time, schema=schema)
        return df_time

    def get_frequency(self) -> relativedelta:
        """Return the shared frequency of all ranges as a relativedelta in years.

        Raises ValueError if the ranges do not all use the same frequency.
        """
        freqs = [trange.frequency for trange in self.model.ranges]
        if len(set(freqs)) > 1:
            msg = f"AnnualTimeDimensionConfig.get_frequency found multiple frequencies: {freqs}"
            raise ValueError(msg)
        return relativedelta(years=freqs[0])

    def get_time_ranges(self) -> list[AnnualTimeRange]:
        """Return one AnnualTimeRange per configured range."""
        ranges = []
        for start, end, freq in build_annual_ranges(self.model.ranges, tz=self.get_tzinfo()):
            ranges.append(
                AnnualTimeRange(
                    start=start,
                    end=end,
                    frequency=freq,
                )
            )

        return ranges

    def get_start_times(self) -> list[pd.Timestamp]:
        """Return the start time of each configured range."""
        start_times = []
        for start, _, _ in build_annual_ranges(self.model.ranges, tz=self.get_tzinfo()):
            start_times.append(start)

        return start_times

    def get_lengths(self) -> list[int]:
        """Return the number of years covered by each configured range.

        NOTE(review): when (end - start) is not a multiple of the frequency,
        this treats the end year as exclusive, whereas
        list_expected_dataset_timestamps includes every year <= end.year --
        confirm which behavior is intended.
        """
        lengths = []
        for start, end, freq in build_annual_ranges(self.model.ranges, tz=self.get_tzinfo()):
            if (end.year - start.year) % freq == 0:
                length = (end.year - start.year) // freq + 1
            else:
                # In case where end year is not inclusive
                length = (end.year - start.year) // freq
            lengths.append(length)
        return lengths

    def get_load_data_time_columns(self) -> list[str]:
        """Return the time column names expected in load data tables."""
        return list(AnnualTimestampType._fields)

    def get_time_zone(self) -> None:
        # Annual time carries no time zone.
        return None

    def get_tzinfo(self) -> None:
        # Annual time carries no tzinfo.
        return None

    def get_time_interval_type(self) -> None:
        # Annual time has no interval type (period begin/end does not apply).
        return None

    def list_expected_dataset_timestamps(self) -> list[AnnualTimestampType]:
        """Return every expected year (as AnnualTimestampType) across all ranges,
        stepping by each range's frequency with the end year inclusive."""
        timestamps = []
        for start, end, freq in build_annual_ranges(self.model.ranges, tz=self.get_tzinfo()):
            year = start.year
            while year <= end.year:
                timestamps.append(AnnualTimestampType(year))
                year += freq
        return timestamps
+ return timestamps
150
+
151
+
152
def map_annual_time_to_date_time(
    df: DataFrame,
    annual_dim: AnnualTimeDimensionConfig,
    dt_dim: DateTimeDimensionConfig,
    value_columns: set[str],
) -> DataFrame:
    """Map a DataFrame with an annual time dimension to a DateTime time dimension.

    Cross-joins each annual row with every timestamp of the DateTime dimension,
    moves the annual column into the model_year column, and divides each value
    column so that annual totals are spread evenly across the year's intervals.

    Parameters
    ----------
    df : DataFrame
        Load data with an annual time column.
    annual_dim : AnnualTimeDimensionConfig
        Source annual time dimension.
    dt_dim : DateTimeDimensionConfig
        Target datetime dimension; must cover exactly one year.
    value_columns : set[str]
        Names of the value columns to rescale.

    Raises
    ------
    NotImplementedError
        Raised if the DateTime dimension spans more than one year.
    """
    annual_col = annual_dim.get_load_data_time_columns()[0]
    myear_column = DimensionType.MODEL_YEAR.value
    timestamps = make_time_range_generator(dt_dim.to_chronify()).list_timestamps()
    time_cols = dt_dim.get_load_data_time_columns()
    assert len(time_cols) == 1, time_cols
    time_col = time_cols[0]
    schema = StructType([StructField(time_col, TimestampType(), False)])
    dt_df = get_spark_session().createDataFrame(
        [(x.to_pydatetime(),) for x in timestamps], schema=schema
    )

    # Note that MeasurementType.TOTAL has already been verified, i.e.,
    # each value associated with an annual time represents the total over that year.
    with set_session_time_zone(dt_dim.model.time_zone_format.time_zone):
        years = (
            select_expr(dt_df, [f"YEAR({handle_column_spaces(time_col)}) AS year"])
            .distinct()
            .collect()
        )
        if len(years) != 1:
            # Bug fix: the original message was not an f-string, so the
            # offending years were never interpolated into the error.
            msg = f"DateTime dimension has more than one year: {years=}"
            raise NotImplementedError(msg)
        if annual_dim.model.include_leap_day and is_leap_year(years[0].year):
            measured_duration = timedelta(days=366)
        else:
            measured_duration = timedelta(days=365)

    df2 = (
        cross_join(df, dt_df)
        .withColumn(myear_column, F.col(annual_col).cast(StringType()))
        .drop(annual_col)
    )
    frequency: timedelta = dt_dim.get_frequency()
    # Divide each annual total by the number of datetime intervals per year so
    # that summing over the year reproduces the original total.
    for column in value_columns:
        df2 = df2.withColumn(column, F.col(column) / (measured_duration / frequency))
    return df2