dsgrid-toolkit 0.3.3__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. build_backend.py +93 -0
  2. dsgrid/__init__.py +22 -0
  3. dsgrid/api/__init__.py +0 -0
  4. dsgrid/api/api_manager.py +179 -0
  5. dsgrid/api/app.py +419 -0
  6. dsgrid/api/models.py +60 -0
  7. dsgrid/api/response_models.py +116 -0
  8. dsgrid/apps/__init__.py +0 -0
  9. dsgrid/apps/project_viewer/app.py +216 -0
  10. dsgrid/apps/registration_gui.py +444 -0
  11. dsgrid/chronify.py +32 -0
  12. dsgrid/cli/__init__.py +0 -0
  13. dsgrid/cli/common.py +120 -0
  14. dsgrid/cli/config.py +176 -0
  15. dsgrid/cli/download.py +13 -0
  16. dsgrid/cli/dsgrid.py +157 -0
  17. dsgrid/cli/dsgrid_admin.py +92 -0
  18. dsgrid/cli/install_notebooks.py +62 -0
  19. dsgrid/cli/query.py +729 -0
  20. dsgrid/cli/registry.py +1862 -0
  21. dsgrid/cloud/__init__.py +0 -0
  22. dsgrid/cloud/cloud_storage_interface.py +140 -0
  23. dsgrid/cloud/factory.py +31 -0
  24. dsgrid/cloud/fake_storage_interface.py +37 -0
  25. dsgrid/cloud/s3_storage_interface.py +156 -0
  26. dsgrid/common.py +36 -0
  27. dsgrid/config/__init__.py +0 -0
  28. dsgrid/config/annual_time_dimension_config.py +194 -0
  29. dsgrid/config/common.py +142 -0
  30. dsgrid/config/config_base.py +148 -0
  31. dsgrid/config/dataset_config.py +907 -0
  32. dsgrid/config/dataset_schema_handler_factory.py +46 -0
  33. dsgrid/config/date_time_dimension_config.py +136 -0
  34. dsgrid/config/dimension_config.py +54 -0
  35. dsgrid/config/dimension_config_factory.py +65 -0
  36. dsgrid/config/dimension_mapping_base.py +350 -0
  37. dsgrid/config/dimension_mappings_config.py +48 -0
  38. dsgrid/config/dimensions.py +1025 -0
  39. dsgrid/config/dimensions_config.py +71 -0
  40. dsgrid/config/file_schema.py +190 -0
  41. dsgrid/config/index_time_dimension_config.py +80 -0
  42. dsgrid/config/input_dataset_requirements.py +31 -0
  43. dsgrid/config/mapping_tables.py +209 -0
  44. dsgrid/config/noop_time_dimension_config.py +42 -0
  45. dsgrid/config/project_config.py +1462 -0
  46. dsgrid/config/registration_models.py +188 -0
  47. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  48. dsgrid/config/simple_models.py +49 -0
  49. dsgrid/config/supplemental_dimension.py +29 -0
  50. dsgrid/config/time_dimension_base_config.py +192 -0
  51. dsgrid/data_models.py +155 -0
  52. dsgrid/dataset/__init__.py +0 -0
  53. dsgrid/dataset/dataset.py +123 -0
  54. dsgrid/dataset/dataset_expression_handler.py +86 -0
  55. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  56. dsgrid/dataset/dataset_schema_handler_base.py +945 -0
  57. dsgrid/dataset/dataset_schema_handler_one_table.py +209 -0
  58. dsgrid/dataset/dataset_schema_handler_two_table.py +322 -0
  59. dsgrid/dataset/growth_rates.py +162 -0
  60. dsgrid/dataset/models.py +51 -0
  61. dsgrid/dataset/table_format_handler_base.py +257 -0
  62. dsgrid/dataset/table_format_handler_factory.py +17 -0
  63. dsgrid/dataset/unpivoted_table.py +121 -0
  64. dsgrid/dimension/__init__.py +0 -0
  65. dsgrid/dimension/base_models.py +230 -0
  66. dsgrid/dimension/dimension_filters.py +308 -0
  67. dsgrid/dimension/standard.py +252 -0
  68. dsgrid/dimension/time.py +352 -0
  69. dsgrid/dimension/time_utils.py +103 -0
  70. dsgrid/dsgrid_rc.py +88 -0
  71. dsgrid/exceptions.py +105 -0
  72. dsgrid/filesystem/__init__.py +0 -0
  73. dsgrid/filesystem/cloud_filesystem.py +32 -0
  74. dsgrid/filesystem/factory.py +32 -0
  75. dsgrid/filesystem/filesystem_interface.py +136 -0
  76. dsgrid/filesystem/local_filesystem.py +74 -0
  77. dsgrid/filesystem/s3_filesystem.py +118 -0
  78. dsgrid/loggers.py +132 -0
  79. dsgrid/minimal_patterns.cp313-win_amd64.pyd +0 -0
  80. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +949 -0
  81. dsgrid/notebooks/registration.ipynb +48 -0
  82. dsgrid/notebooks/start_notebook.sh +11 -0
  83. dsgrid/project.py +451 -0
  84. dsgrid/query/__init__.py +0 -0
  85. dsgrid/query/dataset_mapping_plan.py +142 -0
  86. dsgrid/query/derived_dataset.py +388 -0
  87. dsgrid/query/models.py +728 -0
  88. dsgrid/query/query_context.py +287 -0
  89. dsgrid/query/query_submitter.py +994 -0
  90. dsgrid/query/report_factory.py +19 -0
  91. dsgrid/query/report_peak_load.py +70 -0
  92. dsgrid/query/reports_base.py +20 -0
  93. dsgrid/registry/__init__.py +0 -0
  94. dsgrid/registry/bulk_register.py +165 -0
  95. dsgrid/registry/common.py +287 -0
  96. dsgrid/registry/config_update_checker_base.py +63 -0
  97. dsgrid/registry/data_store_factory.py +34 -0
  98. dsgrid/registry/data_store_interface.py +74 -0
  99. dsgrid/registry/dataset_config_generator.py +158 -0
  100. dsgrid/registry/dataset_registry_manager.py +950 -0
  101. dsgrid/registry/dataset_update_checker.py +16 -0
  102. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  103. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  104. dsgrid/registry/dimension_registry_manager.py +413 -0
  105. dsgrid/registry/dimension_update_checker.py +16 -0
  106. dsgrid/registry/duckdb_data_store.py +207 -0
  107. dsgrid/registry/filesystem_data_store.py +150 -0
  108. dsgrid/registry/filter_registry_manager.py +123 -0
  109. dsgrid/registry/project_config_generator.py +57 -0
  110. dsgrid/registry/project_registry_manager.py +1623 -0
  111. dsgrid/registry/project_update_checker.py +48 -0
  112. dsgrid/registry/registration_context.py +223 -0
  113. dsgrid/registry/registry_auto_updater.py +316 -0
  114. dsgrid/registry/registry_database.py +667 -0
  115. dsgrid/registry/registry_interface.py +446 -0
  116. dsgrid/registry/registry_manager.py +558 -0
  117. dsgrid/registry/registry_manager_base.py +367 -0
  118. dsgrid/registry/versioning.py +92 -0
  119. dsgrid/rust_ext/__init__.py +14 -0
  120. dsgrid/rust_ext/find_minimal_patterns.py +129 -0
  121. dsgrid/spark/__init__.py +0 -0
  122. dsgrid/spark/functions.py +589 -0
  123. dsgrid/spark/types.py +110 -0
  124. dsgrid/tests/__init__.py +0 -0
  125. dsgrid/tests/common.py +140 -0
  126. dsgrid/tests/make_us_data_registry.py +265 -0
  127. dsgrid/tests/register_derived_datasets.py +103 -0
  128. dsgrid/tests/utils.py +25 -0
  129. dsgrid/time/__init__.py +0 -0
  130. dsgrid/time/time_conversions.py +80 -0
  131. dsgrid/time/types.py +67 -0
  132. dsgrid/units/__init__.py +0 -0
  133. dsgrid/units/constants.py +113 -0
  134. dsgrid/units/convert.py +71 -0
  135. dsgrid/units/energy.py +145 -0
  136. dsgrid/units/power.py +87 -0
  137. dsgrid/utils/__init__.py +0 -0
  138. dsgrid/utils/dataset.py +830 -0
  139. dsgrid/utils/files.py +179 -0
  140. dsgrid/utils/filters.py +125 -0
  141. dsgrid/utils/id_remappings.py +100 -0
  142. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  143. dsgrid/utils/py_expression_eval/README.md +8 -0
  144. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  145. dsgrid/utils/py_expression_eval/tests.py +283 -0
  146. dsgrid/utils/run_command.py +70 -0
  147. dsgrid/utils/scratch_dir_context.py +65 -0
  148. dsgrid/utils/spark.py +918 -0
  149. dsgrid/utils/spark_partition.py +98 -0
  150. dsgrid/utils/timing.py +239 -0
  151. dsgrid/utils/utilities.py +221 -0
  152. dsgrid/utils/versioning.py +36 -0
  153. dsgrid_toolkit-0.3.3.dist-info/METADATA +193 -0
  154. dsgrid_toolkit-0.3.3.dist-info/RECORD +157 -0
  155. dsgrid_toolkit-0.3.3.dist-info/WHEEL +4 -0
  156. dsgrid_toolkit-0.3.3.dist-info/entry_points.txt +4 -0
  157. dsgrid_toolkit-0.3.3.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,209 @@
+ import logging
+ from typing import Self
+
+ from dsgrid.common import TIME_ZONE_COLUMN, VALUE_COLUMN
+ from dsgrid.config.dataset_config import DatasetConfig
+ from dsgrid.config.project_config import ProjectConfig
+ from dsgrid.config.simple_models import DimensionSimpleModel
+ from dsgrid.config.time_dimension_base_config import TimeDimensionBaseConfig
+ from dsgrid.dataset.models import ValueFormat
+ from dsgrid.query.models import DatasetQueryModel
+ from dsgrid.registry.data_store_interface import DataStoreInterface
+ from dsgrid.spark.types import (
+     DataFrame,
+     StringType,
+ )
+ from dsgrid.utils.dataset import (
+     convert_types_if_necessary,
+ )
+ from dsgrid.config.file_schema import read_data_file
+ from dsgrid.utils.scratch_dir_context import ScratchDirContext
+ from dsgrid.utils.spark import check_for_nulls
+ from dsgrid.utils.timing import timer_stats_collector, track_timing
+ from dsgrid.dataset.dataset_schema_handler_base import DatasetSchemaHandlerBase
+ from dsgrid.dimension.base_models import DatasetDimensionRequirements, DimensionType
+ from dsgrid.exceptions import DSGInvalidDataset
+ from dsgrid.query.query_context import QueryContext
+
+ logger = logging.getLogger(__name__)
+
+
+ class OneTableDatasetSchemaHandler(DatasetSchemaHandlerBase):
+     """define interface/required behaviors for ONE_TABLE dataset schema"""
+
+     def __init__(self, load_data_df, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self._load_data = load_data_df
+
+     @classmethod
+     def load(
+         cls,
+         config: DatasetConfig,
+         *args,
+         store: DataStoreInterface | None = None,
+         scratch_dir_context: ScratchDirContext | None = None,
+         **kwargs,
+     ) -> Self:
+         if store is None:
+             if config.data_file_schema is None:
+                 msg = "Cannot load dataset without data file schema or store"
+                 raise DSGInvalidDataset(msg)
+             df = read_data_file(config.data_file_schema, scratch_dir_context=scratch_dir_context)
+         else:
+             df = store.read_table(config.model.dataset_id, config.model.version)
+         load_data_df = config.add_trivial_dimensions(df)
+         load_data_df = convert_types_if_necessary(load_data_df)
+         return cls(load_data_df, config, *args, **kwargs)
+
+     @track_timing(timer_stats_collector)
+     def check_consistency(
+         self,
+         missing_dimension_associations: dict[str, DataFrame],
+         scratch_dir_context: ScratchDirContext,
+         requirements: DatasetDimensionRequirements,
+     ) -> None:
+         self._check_one_table_data_consistency()
+         self._check_dimension_associations(
+             missing_dimension_associations, scratch_dir_context, requirements
+         )
+
+     @track_timing(timer_stats_collector)
+     def check_time_consistency(self):
+         time_dim = self._config.get_time_dimension()
+         if time_dim is not None:
+             if time_dim.supports_chronify():
+                 self._check_dataset_time_consistency_with_chronify()
+             else:
+                 self._check_dataset_time_consistency(self._load_data)
+
+     @track_timing(timer_stats_collector)
+     def _check_one_table_data_consistency(self):
+         """Dimension check in load_data, excludes time:
+         * check that data matches record for each dimension.
+         * check that all data dimension combinations exist. Time is handled separately.
+         * Check for any NULL values in dimension columns.
+         """
+         logger.info("Check one table dataset consistency.")
+         dimension_types = set()
+         time_dim = self._config.get_time_dimension()
+         time_columns: set[str] = set()
+         if time_dim is not None:
+             time_columns = set(time_dim.get_load_data_time_columns())
+         assert (
+             self._config.get_value_format() == ValueFormat.STACKED
+         ), self._config.get_value_format()
+         self._check_load_data_unpivoted_value_column(self._load_data)
+         allowed_columns = DimensionType.get_allowed_dimension_column_names().union(time_columns)
+         allowed_columns.add(VALUE_COLUMN)
+         allowed_columns.add(TIME_ZONE_COLUMN)
+
+         schema = self._load_data.schema
+         for column in self._load_data.columns:
+             if column not in allowed_columns:
+                 msg = f"{column=} is not expected in load_data"
+                 raise DSGInvalidDataset(msg)
+             if not (
+                 column in time_columns or column == VALUE_COLUMN or column == TIME_ZONE_COLUMN
+             ):
+                 dim_type = DimensionType.from_column(column)
+                 if schema[column].dataType != StringType():
+                     msg = f"dimension column {column} must have data type = StringType"
+                     raise DSGInvalidDataset(msg)
+                 dimension_types.add(dim_type)
+         check_for_nulls(self._load_data)
+
+     def get_base_load_data_table(self) -> DataFrame:
+         return self._load_data
+
+     def _get_load_data_table(self) -> DataFrame:
+         return self._load_data
+
+     @track_timing(timer_stats_collector)
+     def filter_data(self, dimensions: list[DimensionSimpleModel], store: DataStoreInterface):
+         assert (
+             self._config.get_value_format() == ValueFormat.STACKED
+         ), self._config.get_value_format()
+         load_df = self._load_data
+         df_columns = set(load_df.columns)
+         stacked_columns = set()
+         for dim in dimensions:
+             column = dim.dimension_type.value
+             if column in df_columns:
+                 load_df = load_df.filter(load_df[column].isin(dim.record_ids))
+                 stacked_columns.add(column)
+
+         drop_columns = []
+         for dim in self._config.model.trivial_dimensions:
+             col = dim.value
+             count = load_df.select(col).distinct().count()
+             assert count == 1, f"{dim}: {count}"
+             drop_columns.append(col)
+         load_df = load_df.drop(*drop_columns)
+
+         store.replace_table(load_df, self.dataset_id, self._config.model.version)
+         logger.info("Rewrote simplified %s", self._config.model.dataset_id)
+
+     def make_project_dataframe(
+         self, context: QueryContext, project_config: ProjectConfig
+     ) -> DataFrame:
+         plan = context.model.project.get_dataset_mapping_plan(self.dataset_id)
+         if plan is None:
+             plan = self.build_default_dataset_mapping_plan()
+         with context.dataset_mapping_manager(self.dataset_id, plan) as mapping_manager:
+             ld_df = mapping_manager.try_read_checkpointed_table()
+             if ld_df is None:
+                 ld_df = self._load_data
+                 ld_df = self._prefilter_stacked_dimensions(context, ld_df)
+                 ld_df = self._prefilter_time_dimension(context, ld_df)
+
+             ld_df = self._remap_dimension_columns(
+                 ld_df,
+                 mapping_manager,
+                 filtered_records=context.get_record_ids(),
+             )
+             ld_df = self._apply_fraction(ld_df, {VALUE_COLUMN}, mapping_manager)
+             project_metric_records = self._get_project_metric_records(project_config)
+             ld_df = self._convert_units(ld_df, project_metric_records, mapping_manager)
+             input_dataset = project_config.get_dataset(self._config.model.dataset_id)
+             ld_df = self._convert_time_dimension(
+                 load_data_df=ld_df,
+                 to_time_dim=project_config.get_base_time_dimension(),
+                 value_column=VALUE_COLUMN,
+                 mapping_manager=mapping_manager,
+                 wrap_time_allowed=input_dataset.wrap_time_allowed,
+                 time_based_data_adjustment=input_dataset.time_based_data_adjustment,
+                 to_geo_dim=project_config.get_base_dimension(DimensionType.GEOGRAPHY),
+             )
+             return self._finalize_table(context, ld_df, project_config)
+
+     def make_mapped_dataframe(
+         self, context: QueryContext, time_dimension: TimeDimensionBaseConfig | None = None
+     ) -> DataFrame:
+         query = context.model
+         assert isinstance(query, DatasetQueryModel)
+         plan = query.mapping_plan
+         if plan is None:
+             plan = self.build_default_dataset_mapping_plan()
+         geography_dimension = self._get_mapping_to_dimension(DimensionType.GEOGRAPHY)
+         metric_dimension = self._get_mapping_to_dimension(DimensionType.METRIC)
+         with context.dataset_mapping_manager(self.dataset_id, plan) as mapping_manager:
+             ld_df = mapping_manager.try_read_checkpointed_table()
+             if ld_df is None:
+                 ld_df = self._load_data
+
+             ld_df = self._remap_dimension_columns(ld_df, mapping_manager)
+             ld_df = self._apply_fraction(ld_df, {VALUE_COLUMN}, mapping_manager)
+             if metric_dimension is not None:
+                 metric_records = metric_dimension.get_records_dataframe()
+                 ld_df = self._convert_units(ld_df, metric_records, mapping_manager)
+             if time_dimension is not None:
+                 ld_df = self._convert_time_dimension(
+                     load_data_df=ld_df,
+                     to_time_dim=time_dimension,
+                     value_column=VALUE_COLUMN,
+                     mapping_manager=mapping_manager,
+                     wrap_time_allowed=query.wrap_time_allowed,
+                     time_based_data_adjustment=query.time_based_data_adjustment,
+                     to_geo_dim=geography_dimension,
+                 )
+             return ld_df
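For orientation, the ONE_TABLE layout validated above keeps every dimension as a string column in a single stacked table, alongside the time column(s) and one value column. A minimal sketch of such a table follows; the column names ("model_year", "geography", "metric", "timestamp", "value") are illustrative assumptions, not taken from the package.

    # Illustrative sketch only: a toy stacked load_data table of the shape
    # OneTableDatasetSchemaHandler._check_one_table_data_consistency expects.
    # Column names here are assumptions for the example.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    load_data = spark.createDataFrame(
        [
            ("2030", "CO", "res_heating", "2018-01-01 00:00:00", 1.23),
            ("2030", "CO", "res_heating", "2018-01-01 01:00:00", 1.19),
        ],
        ["model_year", "geography", "metric", "timestamp", "value"],
    )
    # Dimension columns are strings, "timestamp" is the time column, and
    # "value" holds the load values that the null and column checks cover.
    load_data.show()
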
@@ -0,0 +1,322 @@
+ import logging
+ from typing import Self
+
+ from dsgrid.common import SCALING_FACTOR_COLUMN, VALUE_COLUMN
+ from dsgrid.config.dataset_config import DatasetConfig
+ from dsgrid.config.project_config import ProjectConfig
+ from dsgrid.config.simple_models import DimensionSimpleModel
+ from dsgrid.config.time_dimension_base_config import TimeDimensionBaseConfig
+ from dsgrid.dataset.models import ValueFormat
+ from dsgrid.dataset.dataset_schema_handler_base import DatasetSchemaHandlerBase
+ from dsgrid.dimension.base_models import DatasetDimensionRequirements, DimensionType
+ from dsgrid.exceptions import DSGInvalidDataset
+ from dsgrid.query.models import DatasetQueryModel
+ from dsgrid.query.query_context import QueryContext
+ from dsgrid.registry.data_store_interface import DataStoreInterface
+ from dsgrid.spark.functions import (
+     cache,
+     coalesce,
+     collect_list,
+     except_all,
+     intersect,
+     unpersist,
+ )
+ from dsgrid.spark.types import (
+     DataFrame,
+     StringType,
+ )
+ from dsgrid.utils.dataset import (
+     apply_scaling_factor,
+     convert_types_if_necessary,
+ )
+ from dsgrid.config.file_schema import read_data_file
+ from dsgrid.utils.scratch_dir_context import ScratchDirContext
+ from dsgrid.utils.spark import check_for_nulls
+ from dsgrid.utils.timing import Timer, timer_stats_collector, track_timing
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class TwoTableDatasetSchemaHandler(DatasetSchemaHandlerBase):
+     """Handler for TWO_TABLE dataset format (load_data + load_data_lookup tables)."""
+
+     def __init__(self, load_data_df, load_data_lookup, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self._load_data = load_data_df
+         self._load_data_lookup = load_data_lookup
+
+     @classmethod
+     def load(
+         cls,
+         config: DatasetConfig,
+         *args,
+         store: DataStoreInterface | None = None,
+         scratch_dir_context: ScratchDirContext | None = None,
+         **kwargs,
+     ) -> Self:
+         if store is None:
+             if config.data_file_schema is None:
+                 msg = "Cannot load dataset without data file schema or store"
+                 raise DSGInvalidDataset(msg)
+             if config.lookup_file_schema is None:
+                 msg = "TWO_TABLE format requires lookup_data_file"
+                 raise DSGInvalidDataset(msg)
+             load_data_df = read_data_file(
+                 config.data_file_schema, scratch_dir_context=scratch_dir_context
+             )
+             load_data_lookup = read_data_file(
+                 config.lookup_file_schema, scratch_dir_context=scratch_dir_context
+             )
+         else:
+             load_data_df = store.read_table(config.model.dataset_id, config.model.version)
+             load_data_lookup = store.read_lookup_table(
+                 config.model.dataset_id, config.model.version
+             )
+
+         load_data_df = convert_types_if_necessary(load_data_df)
+         load_data_lookup = config.add_trivial_dimensions(load_data_lookup)
+         load_data_lookup = convert_types_if_necessary(load_data_lookup)
+         return cls(load_data_df, load_data_lookup, config, *args, **kwargs)
+
+     @track_timing(timer_stats_collector)
+     def check_consistency(
+         self,
+         missing_dimension_associations: dict[str, DataFrame],
+         scratch_dir_context: ScratchDirContext,
+         requirements: DatasetDimensionRequirements,
+     ) -> None:
+         self._check_lookup_data_consistency()
+         self._check_dataset_internal_consistency()
+         self._check_dimension_associations(
+             missing_dimension_associations, scratch_dir_context, requirements
+         )
+
+     @track_timing(timer_stats_collector)
+     def check_time_consistency(self):
+         time_dim = self._config.get_time_dimension()
+         if time_dim is None:
+             return None
+
+         if time_dim.supports_chronify():
+             self._check_dataset_time_consistency_with_chronify()
+         else:
+             self._check_dataset_time_consistency(self._get_load_data_table())
+
+     def get_base_load_data_table(self) -> DataFrame:
+         return self._load_data
+
+     def _get_load_data_table(self) -> DataFrame:
+         return self._load_data.join(self._load_data_lookup, on="id")
+
+     def make_project_dataframe(
+         self, context: QueryContext, project_config: ProjectConfig
+     ) -> DataFrame:
+         lk_df = self._load_data_lookup
+         lk_df = self._prefilter_stacked_dimensions(context, lk_df)
+
+         plan = context.model.project.get_dataset_mapping_plan(self.dataset_id)
+         if plan is None:
+             plan = self.build_default_dataset_mapping_plan()
+         with context.dataset_mapping_manager(self.dataset_id, plan) as mapping_manager:
+             ld_df = mapping_manager.try_read_checkpointed_table()
+             if ld_df is None:
+                 ld_df = self._load_data
+                 ld_df = self._prefilter_stacked_dimensions(context, ld_df)
+                 ld_df = self._prefilter_time_dimension(context, ld_df)
+                 ld_df = ld_df.join(lk_df, on="id").drop("id")
+
+             ld_df = self._remap_dimension_columns(
+                 ld_df,
+                 mapping_manager,
+                 filtered_records=context.get_record_ids(),
+             )
+             if SCALING_FACTOR_COLUMN in ld_df.columns:
+                 ld_df = apply_scaling_factor(ld_df, VALUE_COLUMN, mapping_manager)
+
+             ld_df = self._apply_fraction(ld_df, {VALUE_COLUMN}, mapping_manager)
+             project_metric_records = self._get_project_metric_records(project_config)
+             ld_df = self._convert_units(ld_df, project_metric_records, mapping_manager)
+             input_dataset = project_config.get_dataset(self._config.model.dataset_id)
+             ld_df = self._convert_time_dimension(
+                 load_data_df=ld_df,
+                 to_time_dim=project_config.get_base_time_dimension(),
+                 value_column=VALUE_COLUMN,
+                 mapping_manager=mapping_manager,
+                 wrap_time_allowed=input_dataset.wrap_time_allowed,
+                 time_based_data_adjustment=input_dataset.time_based_data_adjustment,
+                 to_geo_dim=project_config.get_base_dimension(DimensionType.GEOGRAPHY),
+             )
+             return self._finalize_table(context, ld_df, project_config)
+
+     def make_mapped_dataframe(
+         self,
+         context: QueryContext,
+         time_dimension: TimeDimensionBaseConfig | None = None,
+     ) -> DataFrame:
+         query = context.model
+         assert isinstance(query, DatasetQueryModel)
+         plan = query.mapping_plan
+         if plan is None:
+             plan = self.build_default_dataset_mapping_plan()
+         geography_dimension = self._get_mapping_to_dimension(DimensionType.GEOGRAPHY)
+         metric_dimension = self._get_mapping_to_dimension(DimensionType.METRIC)
+         with context.dataset_mapping_manager(self.dataset_id, plan) as mapping_manager:
+             ld_df = mapping_manager.try_read_checkpointed_table()
+             if ld_df is None:
+                 ld_df = self._load_data
+                 lk_df = self._load_data_lookup
+                 ld_df = ld_df.join(lk_df, on="id").drop("id")
+
+             ld_df = self._remap_dimension_columns(
+                 ld_df,
+                 mapping_manager,
+             )
+             if SCALING_FACTOR_COLUMN in ld_df.columns:
+                 ld_df = apply_scaling_factor(ld_df, VALUE_COLUMN, mapping_manager)
+
+             ld_df = self._apply_fraction(ld_df, {VALUE_COLUMN}, mapping_manager)
+             if metric_dimension is not None:
+                 metric_records = metric_dimension.get_records_dataframe()
+                 ld_df = self._convert_units(ld_df, metric_records, mapping_manager)
+             if time_dimension is not None:
+                 ld_df = self._convert_time_dimension(
+                     load_data_df=ld_df,
+                     to_time_dim=time_dimension,
+                     value_column=VALUE_COLUMN,
+                     mapping_manager=mapping_manager,
+                     wrap_time_allowed=query.wrap_time_allowed,
+                     time_based_data_adjustment=query.time_based_data_adjustment,
+                     to_geo_dim=geography_dimension,
+                 )
+             return ld_df
+
+     @track_timing(timer_stats_collector)
+     def _check_lookup_data_consistency(self):
+         """Dimension check in load_data_lookup, excludes time.
+
+         Checks:
+         - Data matches record for each dimension.
+         - All data dimension combinations exist. Time is handled separately.
+         - No NULL values in dimension columns.
+         """
+         logger.info("Check lookup data consistency.")
+         found_id = False
+         dimension_types = set()
+         for col in self._load_data_lookup.columns:
+             if col == "id":
+                 found_id = True
+                 continue
+             if col == SCALING_FACTOR_COLUMN:
+                 continue
+             if self._load_data_lookup.schema[col].dataType != StringType():
+                 msg = f"dimension column {col} must have data type = StringType"
+                 raise DSGInvalidDataset(msg)
+             dimension_types.add(DimensionType.from_column(col))
+
+         if not found_id:
+             msg = "load_data_lookup does not include an 'id' column"
+             raise DSGInvalidDataset(msg)
+
+         check_for_nulls(self._load_data_lookup)
+         load_data_dimensions = set(self._list_dimension_types_in_load_data(self._load_data))
+         expected_dimensions = {
+             d
+             for d in DimensionType.get_dimension_types_allowed_as_columns()
+             if d not in load_data_dimensions
+         }
+         missing_dimensions = expected_dimensions.difference(dimension_types)
+         if missing_dimensions:
+             msg = (
+                 f"load_data_lookup is missing dimensions: {missing_dimensions}. "
+                 "If these are trivial dimensions, make sure to specify them in the Dataset Config."
+             )
+             raise DSGInvalidDataset(msg)
+
+     @track_timing(timer_stats_collector)
+     def _check_dataset_internal_consistency(self):
+         """Check load_data dimensions and id series."""
+         logger.info("Check dataset internal consistency.")
+         assert (
+             self._config.get_value_format() == ValueFormat.STACKED
+         ), self._config.get_value_format()
+         self._check_load_data_unpivoted_value_column(self._load_data)
+
+         time_dim = self._config.get_time_dimension()
+         time_columns: set[str] = set()
+         if time_dim is not None:
+             time_columns = set(time_dim.get_load_data_time_columns())
+         allowed_columns = (
+             DimensionType.get_allowed_dimension_column_names()
+             .union(time_columns)
+             .union({VALUE_COLUMN, "id", "scaling_factor"})
+         )
+
+         found_id = False
+         for column in self._load_data.columns:
+             if column not in allowed_columns:
+                 msg = f"{column=} is not expected in load_data"
+                 raise DSGInvalidDataset(msg)
+             if column == "id":
+                 found_id = True
+
+         if not found_id:
+             msg = "load_data does not include an 'id' column"
+             raise DSGInvalidDataset(msg)
+
+         check_for_nulls(self._load_data)
+         ld_ids = self._load_data.select("id").distinct()
+         ldl_ids = self._load_data_lookup.select("id").distinct()
+         ldl_id_count = ldl_ids.count()
+         data_id_count = ld_ids.count()
+         joined = ld_ids.join(ldl_ids, on="id")
+         count = joined.count()
+
+         if data_id_count != count or ldl_id_count != count:
+             with Timer(timer_stats_collector, "show load_data and load_data_lookup ID diff"):
+                 diff = except_all(ld_ids.unionAll(ldl_ids), intersect(ld_ids, ldl_ids))
+                 # Only run the query once (with Spark). Number of rows shouldn't be a problem.
+                 cache(diff)
+                 diff_count = diff.count()
+                 limit = 100
+                 diff_list = diff.limit(limit).collect()
+                 unpersist(diff)
+                 logger.error(
+                     "load_data and load_data_lookup have %s different IDs. Limited to %s: %s",
+                     diff_count,
+                     limit,
+                     diff_list,
+                 )
+             msg = f"Data IDs for {self._config.config_id} data/lookup are inconsistent"
+             raise DSGInvalidDataset(msg)
+
+     @track_timing(timer_stats_collector)
+     def filter_data(self, dimensions: list[DimensionSimpleModel], store: DataStoreInterface):
+         lookup = self._load_data_lookup
+         cache(lookup)
+         load_df = self._load_data
+         lookup_columns = set(lookup.columns)
+         for dim in dimensions:
+             column = dim.dimension_type.value
+             if column in lookup_columns:
+                 lookup = lookup.filter(lookup[column].isin(dim.record_ids))
+
+         drop_columns = []
+         for dim in self._config.model.trivial_dimensions:
+             col = dim.value
+             count = lookup.select(col).distinct().count()
+             assert count == 1, f"{dim}: {count}"
+             drop_columns.append(col)
+         lookup = lookup.drop(*drop_columns)
+
+         lookup2 = coalesce(lookup, 1)
+         store.replace_lookup_table(lookup2, self.dataset_id, self._config.model.version)
+         ids = collect_list(lookup2.select("id").distinct(), "id")
+         load_df = self._load_data.filter(self._load_data.id.isin(ids))
+         ld_columns = set(load_df.columns)
+         for dim in dimensions:
+             column = dim.dimension_type.value
+             if column in ld_columns:
+                 load_df = load_df.filter(load_df[column].isin(dim.record_ids))
+
+         store.replace_table(load_df, self.dataset_id, self._config.model.version)
+         logger.info("Rewrote simplified %s", self._config.model.dataset_id)
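The TWO_TABLE layout above splits the same information across two tables joined on an "id" column: load_data carries the time and value columns, while load_data_lookup carries the stacked dimension columns plus an optional scaling_factor. A minimal join sketch follows; every column name other than "id" is an assumption for illustration.

    # Illustrative sketch only: toy tables mirroring
    # TwoTableDatasetSchemaHandler._get_load_data_table(), which joins
    # load_data to load_data_lookup on "id". Column names are assumptions.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    load_data = spark.createDataFrame(
        [(1, "2018-01-01 00:00:00", 1.23), (1, "2018-01-01 01:00:00", 1.19)],
        ["id", "timestamp", "value"],
    )
    load_data_lookup = spark.createDataFrame(
        [(1, "CO", "res_heating", 0.5)],
        ["id", "geography", "metric", "scaling_factor"],
    )
    # One row of dimension metadata fans out across every time step sharing its id.
    load_data.join(load_data_lookup, on="id").show()
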
@@ -0,0 +1,162 @@
+ import logging
+
+ from dsgrid.exceptions import DSGInvalidQuery
+ from dsgrid.query.models import ProjectionDatasetModel
+ from dsgrid.spark.functions import cross_join, join_multiple_columns, sql_from_df
+ from dsgrid.spark.types import DataFrame, F, IntegerType, use_duckdb
+ from dsgrid.utils.spark import get_unique_values
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def apply_exponential_growth_rate(
+     dataset: ProjectionDatasetModel,
+     initial_value_df: DataFrame,
+     growth_rate_df: DataFrame,
+     time_columns,
+     model_year_column,
+     value_columns,
+ ):
+     """Applies exponential growth rate to the initial_value dataframe as follows:
+     P(t) = P0*(1+r)^(t-t0)
+     where:
+         P(t): quantity at t
+         P0: initial quantity at t0, = P(t0)
+         r: growth rate (per time interval)
+         t-t0: number of time intervals
+
+
+     Parameters
+     ----------
+     dataset : ProjectionDatasetModel
+     initial_value_df : pyspark.sql.DataFrame
+     growth_rate_df : pyspark.sql.DataFrame
+     time_columns : set[str]
+     model_year_column : str
+     value_columns : set[str]
+
+     Returns
+     -------
+     pyspark.sql.DataFrame
+
+     """
+
+     initial_value_df, growth_rate_df = _process_exponential_growth_rate(
+         dataset,
+         initial_value_df,
+         growth_rate_df,
+         model_year_column,
+         value_columns,
+     )
+
+     df = apply_annual_multiplier(
+         initial_value_df,
+         growth_rate_df,
+         time_columns,
+         value_columns,
+     )
+
+     return df
+
+
+ def apply_annual_multiplier(
+     initial_value_df: DataFrame,
+     growth_rate_df: DataFrame,
+     time_columns,
+     value_columns,
+ ):
+     """Applies annual growth rate to the initial_value dataframe as follows:
+     P(t) = P0 * r(t)
+     where:
+         P(t): quantity at year t
+         P0: initial quantity
+         r(t): growth rate per year t (relative to P0)
+
+     Parameters
+     ----------
+     initial_value_df : pyspark.sql.DataFrame
+     growth_rate_df : pyspark.sql.DataFrame
+     time_columns : set[str]
+     value_columns : set[str]
+
+     Returns
+     -------
+     pyspark.sql.DataFrame
+
+     """
+
+     def renamed(col):
+         return col + "_gr"
+
+     orig_columns = initial_value_df.columns
+
+     dim_columns = set(initial_value_df.columns) - value_columns - time_columns
+     df = join_multiple_columns(initial_value_df, growth_rate_df, list(dim_columns))
+     for column in df.columns:
+         if column in value_columns:
+             gr_column = renamed(column)
+             df = df.withColumn(column, df[column] * df[gr_column])
+
+     return df.select(*orig_columns)
+
+
+ def _process_exponential_growth_rate(
+     dataset,
+     initial_value_df,
+     growth_rate_df,
+     model_year_column,
+     value_columns,
+ ):
+     def renamed(col):
+         return col + "_gr"
+
+     initial_value_df, base_year = _check_model_years(
+         dataset, initial_value_df, growth_rate_df, model_year_column
+     )
+
+     gr_df = growth_rate_df
+     for column in value_columns:
+         gr_col = renamed(column)
+         cols = ",".join([x for x in gr_df.columns if x not in (column, gr_col)])
+         if use_duckdb():
+             query = f"""
+                 SELECT
+                     {cols}
+                     ,(1 + {column}) ** (CAST({model_year_column} AS INTEGER) - {base_year}) AS {gr_col}
+             """
+             gr_df = sql_from_df(gr_df, query)
+         else:
+             # Spark SQL uses POW instead of **, so keep the DataFrame API method.
+             gr_df = gr_df.withColumn(
+                 gr_col,
+                 F.pow(
+                     (1 + F.col(column)), F.col(model_year_column).cast(IntegerType()) - base_year
+                 ),
+             ).drop(column)
+
+     return initial_value_df, gr_df
+
+
+ def _check_model_years(dataset, initial_value_df, growth_rate_df, model_year_column):
+     iv_years = get_unique_values(initial_value_df, model_year_column)
+     iv_years_sorted = sorted((int(x) for x in iv_years))
+
+     if dataset.base_year is None:
+         base_year = iv_years_sorted[0]
+     elif dataset.base_year in iv_years:
+         base_year = dataset.base_year
+     else:
+         msg = f"ProjectionDatasetModel base_year={dataset.base_year} is not in {iv_years_sorted}"
+         raise DSGInvalidQuery(msg)
+
+     if len(iv_years) > 1:
+         # TODO #198: needs test case
+         initial_value_df = initial_value_df.filter(f"{model_year_column} == '{base_year}'")
+
+     initial_value_df = cross_join(
+         initial_value_df.drop(model_year_column),
+         growth_rate_df.select(model_year_column).distinct(),
+     )
+     return initial_value_df, base_year
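As a quick sanity check of the exponential-growth formula documented in apply_exponential_growth_rate, P(t) = P0*(1+r)^(t-t0), here is a worked example with made-up numbers.

    # Worked example of P(t) = P0 * (1 + r) ** (t - t0); values are illustrative only.
    p0 = 100.0          # initial quantity at the base year t0
    r = 0.02            # growth rate per year
    t0, t = 2020, 2025  # five one-year intervals

    p_t = p0 * (1 + r) ** (t - t0)
    print(round(p_t, 2))  # 110.41, i.e. five years of 2% compound growth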