dsgrid_toolkit-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (152)
  1. dsgrid/__init__.py +22 -0
  2. dsgrid/api/__init__.py +0 -0
  3. dsgrid/api/api_manager.py +179 -0
  4. dsgrid/api/app.py +420 -0
  5. dsgrid/api/models.py +60 -0
  6. dsgrid/api/response_models.py +116 -0
  7. dsgrid/apps/__init__.py +0 -0
  8. dsgrid/apps/project_viewer/app.py +216 -0
  9. dsgrid/apps/registration_gui.py +444 -0
  10. dsgrid/chronify.py +22 -0
  11. dsgrid/cli/__init__.py +0 -0
  12. dsgrid/cli/common.py +120 -0
  13. dsgrid/cli/config.py +177 -0
  14. dsgrid/cli/download.py +13 -0
  15. dsgrid/cli/dsgrid.py +142 -0
  16. dsgrid/cli/dsgrid_admin.py +349 -0
  17. dsgrid/cli/install_notebooks.py +62 -0
  18. dsgrid/cli/query.py +711 -0
  19. dsgrid/cli/registry.py +1773 -0
  20. dsgrid/cloud/__init__.py +0 -0
  21. dsgrid/cloud/cloud_storage_interface.py +140 -0
  22. dsgrid/cloud/factory.py +31 -0
  23. dsgrid/cloud/fake_storage_interface.py +37 -0
  24. dsgrid/cloud/s3_storage_interface.py +156 -0
  25. dsgrid/common.py +35 -0
  26. dsgrid/config/__init__.py +0 -0
  27. dsgrid/config/annual_time_dimension_config.py +187 -0
  28. dsgrid/config/common.py +131 -0
  29. dsgrid/config/config_base.py +148 -0
  30. dsgrid/config/dataset_config.py +684 -0
  31. dsgrid/config/dataset_schema_handler_factory.py +41 -0
  32. dsgrid/config/date_time_dimension_config.py +108 -0
  33. dsgrid/config/dimension_config.py +54 -0
  34. dsgrid/config/dimension_config_factory.py +65 -0
  35. dsgrid/config/dimension_mapping_base.py +349 -0
  36. dsgrid/config/dimension_mappings_config.py +48 -0
  37. dsgrid/config/dimensions.py +775 -0
  38. dsgrid/config/dimensions_config.py +71 -0
  39. dsgrid/config/index_time_dimension_config.py +76 -0
  40. dsgrid/config/input_dataset_requirements.py +31 -0
  41. dsgrid/config/mapping_tables.py +209 -0
  42. dsgrid/config/noop_time_dimension_config.py +42 -0
  43. dsgrid/config/project_config.py +1457 -0
  44. dsgrid/config/registration_models.py +199 -0
  45. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  46. dsgrid/config/simple_models.py +49 -0
  47. dsgrid/config/supplemental_dimension.py +29 -0
  48. dsgrid/config/time_dimension_base_config.py +200 -0
  49. dsgrid/data_models.py +155 -0
  50. dsgrid/dataset/__init__.py +0 -0
  51. dsgrid/dataset/dataset.py +123 -0
  52. dsgrid/dataset/dataset_expression_handler.py +86 -0
  53. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  54. dsgrid/dataset/dataset_schema_handler_base.py +899 -0
  55. dsgrid/dataset/dataset_schema_handler_one_table.py +196 -0
  56. dsgrid/dataset/dataset_schema_handler_standard.py +303 -0
  57. dsgrid/dataset/growth_rates.py +162 -0
  58. dsgrid/dataset/models.py +44 -0
  59. dsgrid/dataset/table_format_handler_base.py +257 -0
  60. dsgrid/dataset/table_format_handler_factory.py +17 -0
  61. dsgrid/dataset/unpivoted_table.py +121 -0
  62. dsgrid/dimension/__init__.py +0 -0
  63. dsgrid/dimension/base_models.py +218 -0
  64. dsgrid/dimension/dimension_filters.py +308 -0
  65. dsgrid/dimension/standard.py +213 -0
  66. dsgrid/dimension/time.py +531 -0
  67. dsgrid/dimension/time_utils.py +88 -0
  68. dsgrid/dsgrid_rc.py +88 -0
  69. dsgrid/exceptions.py +105 -0
  70. dsgrid/filesystem/__init__.py +0 -0
  71. dsgrid/filesystem/cloud_filesystem.py +32 -0
  72. dsgrid/filesystem/factory.py +32 -0
  73. dsgrid/filesystem/filesystem_interface.py +136 -0
  74. dsgrid/filesystem/local_filesystem.py +74 -0
  75. dsgrid/filesystem/s3_filesystem.py +118 -0
  76. dsgrid/loggers.py +132 -0
  77. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +950 -0
  78. dsgrid/notebooks/registration.ipynb +48 -0
  79. dsgrid/notebooks/start_notebook.sh +11 -0
  80. dsgrid/project.py +451 -0
  81. dsgrid/query/__init__.py +0 -0
  82. dsgrid/query/dataset_mapping_plan.py +142 -0
  83. dsgrid/query/derived_dataset.py +384 -0
  84. dsgrid/query/models.py +726 -0
  85. dsgrid/query/query_context.py +287 -0
  86. dsgrid/query/query_submitter.py +847 -0
  87. dsgrid/query/report_factory.py +19 -0
  88. dsgrid/query/report_peak_load.py +70 -0
  89. dsgrid/query/reports_base.py +20 -0
  90. dsgrid/registry/__init__.py +0 -0
  91. dsgrid/registry/bulk_register.py +161 -0
  92. dsgrid/registry/common.py +287 -0
  93. dsgrid/registry/config_update_checker_base.py +63 -0
  94. dsgrid/registry/data_store_factory.py +34 -0
  95. dsgrid/registry/data_store_interface.py +69 -0
  96. dsgrid/registry/dataset_config_generator.py +156 -0
  97. dsgrid/registry/dataset_registry_manager.py +734 -0
  98. dsgrid/registry/dataset_update_checker.py +16 -0
  99. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  100. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  101. dsgrid/registry/dimension_registry_manager.py +413 -0
  102. dsgrid/registry/dimension_update_checker.py +16 -0
  103. dsgrid/registry/duckdb_data_store.py +185 -0
  104. dsgrid/registry/filesystem_data_store.py +141 -0
  105. dsgrid/registry/filter_registry_manager.py +123 -0
  106. dsgrid/registry/project_config_generator.py +57 -0
  107. dsgrid/registry/project_registry_manager.py +1616 -0
  108. dsgrid/registry/project_update_checker.py +48 -0
  109. dsgrid/registry/registration_context.py +223 -0
  110. dsgrid/registry/registry_auto_updater.py +316 -0
  111. dsgrid/registry/registry_database.py +662 -0
  112. dsgrid/registry/registry_interface.py +446 -0
  113. dsgrid/registry/registry_manager.py +544 -0
  114. dsgrid/registry/registry_manager_base.py +367 -0
  115. dsgrid/registry/versioning.py +92 -0
  116. dsgrid/spark/__init__.py +0 -0
  117. dsgrid/spark/functions.py +545 -0
  118. dsgrid/spark/types.py +50 -0
  119. dsgrid/tests/__init__.py +0 -0
  120. dsgrid/tests/common.py +139 -0
  121. dsgrid/tests/make_us_data_registry.py +204 -0
  122. dsgrid/tests/register_derived_datasets.py +103 -0
  123. dsgrid/tests/utils.py +25 -0
  124. dsgrid/time/__init__.py +0 -0
  125. dsgrid/time/time_conversions.py +80 -0
  126. dsgrid/time/types.py +67 -0
  127. dsgrid/units/__init__.py +0 -0
  128. dsgrid/units/constants.py +113 -0
  129. dsgrid/units/convert.py +71 -0
  130. dsgrid/units/energy.py +145 -0
  131. dsgrid/units/power.py +87 -0
  132. dsgrid/utils/__init__.py +0 -0
  133. dsgrid/utils/dataset.py +612 -0
  134. dsgrid/utils/files.py +179 -0
  135. dsgrid/utils/filters.py +125 -0
  136. dsgrid/utils/id_remappings.py +100 -0
  137. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  138. dsgrid/utils/py_expression_eval/README.md +8 -0
  139. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  140. dsgrid/utils/py_expression_eval/tests.py +283 -0
  141. dsgrid/utils/run_command.py +70 -0
  142. dsgrid/utils/scratch_dir_context.py +64 -0
  143. dsgrid/utils/spark.py +918 -0
  144. dsgrid/utils/spark_partition.py +98 -0
  145. dsgrid/utils/timing.py +239 -0
  146. dsgrid/utils/utilities.py +184 -0
  147. dsgrid/utils/versioning.py +36 -0
  148. dsgrid_toolkit-0.2.0.dist-info/METADATA +216 -0
  149. dsgrid_toolkit-0.2.0.dist-info/RECORD +152 -0
  150. dsgrid_toolkit-0.2.0.dist-info/WHEEL +4 -0
  151. dsgrid_toolkit-0.2.0.dist-info/entry_points.txt +4 -0
  152. dsgrid_toolkit-0.2.0.dist-info/licenses/LICENSE +29 -0
dsgrid/dataset/dataset_schema_handler_one_table.py
@@ -0,0 +1,196 @@
+ import logging
+ from typing import Self
+
+ from dsgrid.common import VALUE_COLUMN
+ from dsgrid.config.dataset_config import DatasetConfig
+ from dsgrid.config.project_config import ProjectConfig
+ from dsgrid.config.simple_models import DimensionSimpleModel
+ from dsgrid.config.time_dimension_base_config import TimeDimensionBaseConfig
+ from dsgrid.dataset.models import TableFormatType
+ from dsgrid.query.models import DatasetQueryModel
+ from dsgrid.registry.data_store_interface import DataStoreInterface
+ from dsgrid.spark.types import (
+     DataFrame,
+     StringType,
+ )
+ from dsgrid.utils.dataset import (
+     convert_types_if_necessary,
+ )
+ from dsgrid.utils.spark import (
+     check_for_nulls,
+     read_dataframe,
+ )
+ from dsgrid.utils.timing import timer_stats_collector, track_timing
+ from dsgrid.dataset.dataset_schema_handler_base import DatasetSchemaHandlerBase
+ from dsgrid.dimension.base_models import DimensionType
+ from dsgrid.exceptions import DSGInvalidDataset
+ from dsgrid.query.query_context import QueryContext
+
+ logger = logging.getLogger(__name__)
+
+
+ class OneTableDatasetSchemaHandler(DatasetSchemaHandlerBase):
+     """define interface/required behaviors for ONE_TABLE dataset schema"""
+
+     def __init__(self, load_data_df, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self._load_data = load_data_df
+
+     @classmethod
+     def load(
+         cls,
+         config: DatasetConfig,
+         *args,
+         store: DataStoreInterface | None = None,
+         **kwargs,
+     ) -> Self:
+         if store is None:
+             df = read_dataframe(config.load_data_path)
+         else:
+             df = store.read_table(config.model.dataset_id, config.model.version)
+         load_data_df = config.add_trivial_dimensions(df)
+         load_data_df = convert_types_if_necessary(load_data_df)
+         time_dim = config.get_time_dimension()
+         if time_dim is not None:
+             load_data_df = time_dim.convert_time_format(load_data_df)
+         return cls(load_data_df, config, *args, **kwargs)
+
+     @track_timing(timer_stats_collector)
+     def check_consistency(self, missing_dimension_associations: DataFrame | None) -> None:
+         self._check_one_table_data_consistency()
+         self._check_dimension_associations(missing_dimension_associations)
+
+     @track_timing(timer_stats_collector)
+     def check_time_consistency(self):
+         time_dim = self._config.get_time_dimension()
+         if time_dim is not None:
+             if time_dim.supports_chronify():
+                 self._check_dataset_time_consistency_with_chronify()
+             else:
+                 self._check_dataset_time_consistency(self._load_data)
+
+     @track_timing(timer_stats_collector)
+     def _check_one_table_data_consistency(self):
+         """Dimension check in load_data, excludes time:
+         * check that data matches record for each dimension.
+         * check that all data dimension combinations exist. Time is handled separately.
+         * Check for any NULL values in dimension columns.
+         """
+         logger.info("Check one table dataset consistency.")
+         dimension_types = set()
+         time_dim = self._config.get_time_dimension()
+         time_columns: set[str] = set()
+         if time_dim is not None:
+             time_columns = set(time_dim.get_load_data_time_columns())
+         assert (
+             self._config.get_table_format_type() == TableFormatType.UNPIVOTED
+         ), self._config.get_table_format_type()
+         self._check_load_data_unpivoted_value_column(self._load_data)
+         allowed_columns = DimensionType.get_allowed_dimension_column_names().union(time_columns)
+         allowed_columns.add(VALUE_COLUMN)
+
+         schema = self._load_data.schema
+         for column in self._load_data.columns:
+             if column not in allowed_columns:
+                 msg = f"{column=} is not expected in load_data"
+                 raise DSGInvalidDataset(msg)
+             if not (column in time_columns or column == VALUE_COLUMN):
+                 dim_type = DimensionType.from_column(column)
+                 if schema[column].dataType != StringType():
+                     msg = f"dimension column {column} must have data type = StringType"
+                     raise DSGInvalidDataset(msg)
+                 dimension_types.add(dim_type)
+         check_for_nulls(self._load_data)
+
+     def _get_load_data_table(self) -> DataFrame:
+         return self._load_data
+
+     @track_timing(timer_stats_collector)
+     def filter_data(self, dimensions: list[DimensionSimpleModel], store: DataStoreInterface):
+         assert (
+             self._config.get_table_format_type() == TableFormatType.UNPIVOTED
+         ), self._config.get_table_format_type()
+         load_df = self._load_data
+         df_columns = set(load_df.columns)
+         stacked_columns = set()
+         for dim in dimensions:
+             column = dim.dimension_type.value
+             if column in df_columns:
+                 load_df = load_df.filter(load_df[column].isin(dim.record_ids))
+                 stacked_columns.add(column)
+
+         drop_columns = []
+         for dim in self._config.model.trivial_dimensions:
+             col = dim.value
+             count = load_df.select(col).distinct().count()
+             assert count == 1, f"{dim}: {count}"
+             drop_columns.append(col)
+         load_df = load_df.drop(*drop_columns)
+
+         store.replace_table(load_df, self.dataset_id, self._config.model.version)
+         logger.info("Rewrote simplified %s", self._config.model.dataset_id)
+
+     def make_project_dataframe(
+         self, context: QueryContext, project_config: ProjectConfig
+     ) -> DataFrame:
+         plan = context.model.project.get_dataset_mapping_plan(self.dataset_id)
+         if plan is None:
+             plan = self.build_default_dataset_mapping_plan()
+         with context.dataset_mapping_manager(self.dataset_id, plan) as mapping_manager:
+             ld_df = mapping_manager.try_read_checkpointed_table()
+             if ld_df is None:
+                 ld_df = self._load_data
+                 ld_df = self._prefilter_stacked_dimensions(context, ld_df)
+                 ld_df = self._prefilter_time_dimension(context, ld_df)
+
+             ld_df = self._remap_dimension_columns(
+                 ld_df,
+                 mapping_manager,
+                 filtered_records=context.get_record_ids(),
+             )
+             ld_df = self._apply_fraction(ld_df, {VALUE_COLUMN}, mapping_manager)
+             project_metric_records = self._get_project_metric_records(project_config)
+             ld_df = self._convert_units(ld_df, project_metric_records, mapping_manager)
+             input_dataset = project_config.get_dataset(self._config.model.dataset_id)
+             ld_df = self._convert_time_dimension(
+                 load_data_df=ld_df,
+                 to_time_dim=project_config.get_base_time_dimension(),
+                 value_column=VALUE_COLUMN,
+                 mapping_manager=mapping_manager,
+                 wrap_time_allowed=input_dataset.wrap_time_allowed,
+                 time_based_data_adjustment=input_dataset.time_based_data_adjustment,
+                 to_geo_dim=project_config.get_base_dimension(DimensionType.GEOGRAPHY),
+             )
+             return self._finalize_table(context, ld_df, project_config)
+
+     def make_mapped_dataframe(
+         self, context: QueryContext, time_dimension: TimeDimensionBaseConfig | None = None
+     ) -> DataFrame:
+         query = context.model
+         assert isinstance(query, DatasetQueryModel)
+         plan = query.mapping_plan
+         if plan is None:
+             plan = self.build_default_dataset_mapping_plan()
+         geography_dimension = self._get_mapping_to_dimension(DimensionType.GEOGRAPHY)
+         metric_dimension = self._get_mapping_to_dimension(DimensionType.METRIC)
+         with context.dataset_mapping_manager(self.dataset_id, plan) as mapping_manager:
+             ld_df = mapping_manager.try_read_checkpointed_table()
+             if ld_df is None:
+                 ld_df = self._load_data
+
+             ld_df = self._remap_dimension_columns(ld_df, mapping_manager)
+             ld_df = self._apply_fraction(ld_df, {VALUE_COLUMN}, mapping_manager)
+             if metric_dimension is not None:
+                 metric_records = metric_dimension.get_records_dataframe()
+                 ld_df = self._convert_units(ld_df, metric_records, mapping_manager)
+             if time_dimension is not None:
+                 ld_df = self._convert_time_dimension(
+                     load_data_df=ld_df,
+                     to_time_dim=time_dimension,
+                     value_column=VALUE_COLUMN,
+                     mapping_manager=mapping_manager,
+                     wrap_time_allowed=query.wrap_time_allowed,
+                     time_based_data_adjustment=query.time_based_data_adjustment,
+                     to_geo_dim=geography_dimension,
+                 )
+             return ld_df
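
The ONE_TABLE schema keeps time columns, dimension columns, and the value column in a single unpivoted table. The snippet below is a minimal, Spark-free sketch (not part of the package) of the column rule that _check_one_table_data_consistency enforces; the dimension and time column names are illustrative placeholders, not real dsgrid records.

# Toy illustration of the one-table column check: every column must be a known
# dimension column, a time column, or the value column; dimension values must be
# strings; nulls are rejected. Column names here are hypothetical examples.
ALLOWED_DIMENSION_COLUMNS = {
    "geography", "sector", "subsector", "metric", "model_year", "scenario", "weather_year"
}
TIME_COLUMNS = {"timestamp"}
VALUE_COLUMN = "value"


def check_one_table_columns(rows):
    allowed = ALLOWED_DIMENSION_COLUMNS | TIME_COLUMNS | {VALUE_COLUMN}
    for row in rows:
        for column, value in row.items():
            if column not in allowed:
                raise ValueError(f"{column=} is not expected in load_data")
            if value is None:
                raise ValueError(f"column {column} contains a null value")
            if column in ALLOWED_DIMENSION_COLUMNS and not isinstance(value, str):
                raise ValueError(f"dimension column {column} must be a string")


check_one_table_columns(
    [
        {
            "timestamp": "2018-01-01 00:00:00",
            "geography": "06037",
            "sector": "com",
            "subsector": "large_office",
            "metric": "electricity",
            "value": 1.23,
        }
    ]
)

In the handler itself, the allowed set comes from DimensionType.get_allowed_dimension_column_names() plus the time dimension's load-data columns, and the checks run against the Spark schema rather than row by row.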
dsgrid/dataset/dataset_schema_handler_standard.py
@@ -0,0 +1,303 @@
+ import logging
+ from typing import Self
+
+ from dsgrid.common import SCALING_FACTOR_COLUMN, VALUE_COLUMN
+ from dsgrid.config.dataset_config import DatasetConfig
+ from dsgrid.config.project_config import ProjectConfig
+ from dsgrid.config.simple_models import DimensionSimpleModel
+ from dsgrid.config.time_dimension_base_config import TimeDimensionBaseConfig
+ from dsgrid.dataset.models import TableFormatType
+ from dsgrid.dataset.dataset_schema_handler_base import DatasetSchemaHandlerBase
+ from dsgrid.dimension.base_models import DimensionType
+ from dsgrid.exceptions import DSGInvalidDataset
+ from dsgrid.query.models import DatasetQueryModel
+ from dsgrid.query.query_context import QueryContext
+ from dsgrid.registry.data_store_interface import DataStoreInterface
+ from dsgrid.spark.functions import (
+     cache,
+     coalesce,
+     collect_list,
+     except_all,
+     intersect,
+     unpersist,
+ )
+ from dsgrid.spark.types import (
+     DataFrame,
+     StringType,
+ )
+ from dsgrid.utils.dataset import (
+     apply_scaling_factor,
+     convert_types_if_necessary,
+ )
+ from dsgrid.utils.spark import (
+     check_for_nulls,
+     read_dataframe,
+ )
+ from dsgrid.utils.timing import Timer, timer_stats_collector, track_timing
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class StandardDatasetSchemaHandler(DatasetSchemaHandlerBase):
+     """define interface/required behaviors for STANDARD dataset schema"""
+
+     def __init__(self, load_data_df, load_data_lookup, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self._load_data = load_data_df
+         self._load_data_lookup = load_data_lookup
+
+     @classmethod
+     def load(
+         cls,
+         config: DatasetConfig,
+         *args,
+         store: DataStoreInterface | None = None,
+         **kwargs,
+     ) -> Self:
+         if store is None:
+             load_data_df = read_dataframe(config.load_data_path)
+             load_data_lookup = read_dataframe(config.load_data_lookup_path)
+         else:
+             load_data_df = store.read_table(config.model.dataset_id, config.model.version)
+             load_data_lookup = store.read_lookup_table(
+                 config.model.dataset_id, config.model.version
+             )
+
+         load_data_df = convert_types_if_necessary(load_data_df)
+         time_dim = config.get_time_dimension()
+         if time_dim is not None:
+             load_data_df = time_dim.convert_time_format(load_data_df)
+         load_data_lookup = config.add_trivial_dimensions(load_data_lookup)
+         load_data_lookup = convert_types_if_necessary(load_data_lookup)
+         return cls(load_data_df, load_data_lookup, config, *args, **kwargs)
+
+     @track_timing(timer_stats_collector)
+     def check_consistency(self, missing_dimension_associations: DataFrame | None) -> None:
+         self._check_lookup_data_consistency()
+         self._check_dataset_internal_consistency()
+         self._check_dimension_associations(missing_dimension_associations)
+
+     @track_timing(timer_stats_collector)
+     def check_time_consistency(self):
+         time_dim = self._config.get_time_dimension()
+         if time_dim is None:
+             return None
+
+         if time_dim.supports_chronify():
+             self._check_dataset_time_consistency_with_chronify()
+         else:
+             self._check_dataset_time_consistency(self._get_load_data_table())
+
+     def _get_load_data_table(self) -> DataFrame:
+         return self._load_data.join(self._load_data_lookup, on="id")
+
+     def make_project_dataframe(
+         self, context: QueryContext, project_config: ProjectConfig
+     ) -> DataFrame:
+         lk_df = self._load_data_lookup
+         lk_df = self._prefilter_stacked_dimensions(context, lk_df)
+
+         plan = context.model.project.get_dataset_mapping_plan(self.dataset_id)
+         if plan is None:
+             plan = self.build_default_dataset_mapping_plan()
+         with context.dataset_mapping_manager(self.dataset_id, plan) as mapping_manager:
+             ld_df = mapping_manager.try_read_checkpointed_table()
+             if ld_df is None:
+                 ld_df = self._load_data
+                 ld_df = self._prefilter_stacked_dimensions(context, ld_df)
+                 ld_df = self._prefilter_time_dimension(context, ld_df)
+                 ld_df = ld_df.join(lk_df, on="id").drop("id")
+
+             ld_df = self._remap_dimension_columns(
+                 ld_df,
+                 mapping_manager,
+                 filtered_records=context.get_record_ids(),
+             )
+             if SCALING_FACTOR_COLUMN in ld_df.columns:
+                 ld_df = apply_scaling_factor(ld_df, VALUE_COLUMN, mapping_manager)
+
+             ld_df = self._apply_fraction(ld_df, {VALUE_COLUMN}, mapping_manager)
+             project_metric_records = self._get_project_metric_records(project_config)
+             ld_df = self._convert_units(ld_df, project_metric_records, mapping_manager)
+             input_dataset = project_config.get_dataset(self._config.model.dataset_id)
+             ld_df = self._convert_time_dimension(
+                 load_data_df=ld_df,
+                 to_time_dim=project_config.get_base_time_dimension(),
+                 value_column=VALUE_COLUMN,
+                 mapping_manager=mapping_manager,
+                 wrap_time_allowed=input_dataset.wrap_time_allowed,
+                 time_based_data_adjustment=input_dataset.time_based_data_adjustment,
+                 to_geo_dim=project_config.get_base_dimension(DimensionType.GEOGRAPHY),
+             )
+             return self._finalize_table(context, ld_df, project_config)
+
+     def make_mapped_dataframe(
+         self,
+         context: QueryContext,
+         time_dimension: TimeDimensionBaseConfig | None = None,
+     ) -> DataFrame:
+         query = context.model
+         assert isinstance(query, DatasetQueryModel)
+         plan = query.mapping_plan
+         if plan is None:
+             plan = self.build_default_dataset_mapping_plan()
+         geography_dimension = self._get_mapping_to_dimension(DimensionType.GEOGRAPHY)
+         metric_dimension = self._get_mapping_to_dimension(DimensionType.METRIC)
+         with context.dataset_mapping_manager(self.dataset_id, plan) as mapping_manager:
+             ld_df = mapping_manager.try_read_checkpointed_table()
+             if ld_df is None:
+                 ld_df = self._load_data
+                 lk_df = self._load_data_lookup
+                 ld_df = ld_df.join(lk_df, on="id").drop("id")
+
+             ld_df = self._remap_dimension_columns(
+                 ld_df,
+                 mapping_manager,
+             )
+             if SCALING_FACTOR_COLUMN in ld_df.columns:
+                 ld_df = apply_scaling_factor(ld_df, VALUE_COLUMN, mapping_manager)
+
+             ld_df = self._apply_fraction(ld_df, {VALUE_COLUMN}, mapping_manager)
+             if metric_dimension is not None:
+                 metric_records = metric_dimension.get_records_dataframe()
+                 ld_df = self._convert_units(ld_df, metric_records, mapping_manager)
+             if time_dimension is not None:
+                 ld_df = self._convert_time_dimension(
+                     load_data_df=ld_df,
+                     to_time_dim=time_dimension,
+                     value_column=VALUE_COLUMN,
+                     mapping_manager=mapping_manager,
+                     wrap_time_allowed=query.wrap_time_allowed,
+                     time_based_data_adjustment=query.time_based_data_adjustment,
+                     to_geo_dim=geography_dimension,
+                 )
+             return ld_df
+
+     @track_timing(timer_stats_collector)
+     def _check_lookup_data_consistency(self):
+         """Dimension check in load_data_lookup, excludes time:
+         * check that data matches record for each dimension.
+         * check that all data dimension combinations exist. Time is handled separately.
+         * Check for any NULL values in dimension columns.
+         """
+         logger.info("Check lookup data consistency.")
+         found_id = False
+         dimension_types = set()
+         for col in self._load_data_lookup.columns:
+             if col == "id":
+                 found_id = True
+                 continue
+             if col == SCALING_FACTOR_COLUMN:
+                 continue
+             if self._load_data_lookup.schema[col].dataType != StringType():
+                 msg = f"dimension column {col} must have data type = StringType"
+                 raise DSGInvalidDataset(msg)
+             dimension_types.add(DimensionType.from_column(col))
+
+         if not found_id:
+             msg = "load_data_lookup does not include an 'id' column"
+             raise DSGInvalidDataset(msg)
+
+         check_for_nulls(self._load_data_lookup)
+         load_data_dimensions = set(self._list_dimension_types_in_load_data(self._load_data))
+         expected_dimensions = {
+             d
+             for d in DimensionType.get_dimension_types_allowed_as_columns()
+             if d not in load_data_dimensions
+         }
+         missing_dimensions = expected_dimensions.difference(dimension_types)
+         if missing_dimensions:
+             msg = (
+                 f"load_data_lookup is missing dimensions: {missing_dimensions}. "
+                 "If these are trivial dimensions, make sure to specify them in the Dataset Config."
+             )
+
+     @track_timing(timer_stats_collector)
+     def _check_dataset_internal_consistency(self):
+         """Check load_data dimensions and id series."""
+         logger.info("Check dataset internal consistency.")
+         assert (
+             self._config.get_table_format_type() == TableFormatType.UNPIVOTED
+         ), self._config.get_table_format_type()
+         self._check_load_data_unpivoted_value_column(self._load_data)
+
+         time_dim = self._config.get_time_dimension()
+         time_columns: set[str] = set()
+         if time_dim is not None:
+             time_columns = set(time_dim.get_load_data_time_columns())
+         allowed_columns = (
+             DimensionType.get_allowed_dimension_column_names()
+             .union(time_columns)
+             .union({VALUE_COLUMN, "id", "scaling_factor"})
+         )
+
+         found_id = False
+         for column in self._load_data.columns:
+             if column not in allowed_columns:
+                 msg = f"{column=} is not expected in load_data"
+                 raise DSGInvalidDataset(msg)
+             if column == "id":
+                 found_id = True
+
+         if not found_id:
+             msg = "load_data does not include an 'id' column"
+             raise DSGInvalidDataset(msg)
+
+         check_for_nulls(self._load_data)
+         ld_ids = self._load_data.select("id").distinct()
+         ldl_ids = self._load_data_lookup.select("id").distinct()
+         ldl_id_count = ldl_ids.count()
+         data_id_count = ld_ids.count()
+         joined = ld_ids.join(ldl_ids, on="id")
+         count = joined.count()
+
+         if data_id_count != count or ldl_id_count != count:
+             with Timer(timer_stats_collector, "show load_data and load_data_lookup ID diff"):
+                 diff = except_all(ld_ids.unionAll(ldl_ids), intersect(ld_ids, ldl_ids))
+                 # Only run the query once (with Spark). Number of rows shouldn't be a problem.
+                 cache(diff)
+                 diff_count = diff.count()
+                 limit = 100
+                 diff_list = diff.limit(limit).collect()
+                 unpersist(diff)
+                 logger.error(
+                     "load_data and load_data_lookup have %s different IDs. Limited to %s: %s",
+                     diff_count,
+                     limit,
+                     diff_list,
+                 )
+                 msg = f"Data IDs for {self._config.config_id} data/lookup are inconsistent"
+                 raise DSGInvalidDataset(msg)
+
+     @track_timing(timer_stats_collector)
+     def filter_data(self, dimensions: list[DimensionSimpleModel], store: DataStoreInterface):
+         lookup = self._load_data_lookup
+         cache(lookup)
+         load_df = self._load_data
+         lookup_columns = set(lookup.columns)
+         for dim in dimensions:
+             column = dim.dimension_type.value
+             if column in lookup_columns:
+                 lookup = lookup.filter(lookup[column].isin(dim.record_ids))
+
+         drop_columns = []
+         for dim in self._config.model.trivial_dimensions:
+             col = dim.value
+             count = lookup.select(col).distinct().count()
+             assert count == 1, f"{dim}: count"
+             drop_columns.append(col)
+         lookup = lookup.drop(*drop_columns)
+
+         lookup2 = coalesce(lookup, 1)
+         store.replace_lookup_table(lookup2, self.dataset_id, self._config.model.version)
+         ids = collect_list(lookup2.select("id").distinct(), "id")
+         load_df = self._load_data.filter(self._load_data.id.isin(ids))
+         ld_columns = set(load_df.columns)
+         for dim in dimensions:
+             column = dim.dimension_type.value
+             if column in ld_columns:
+                 load_df = load_df.filter(load_df[column].isin(dim.record_ids))
+
+         store.replace_table(load_df, self.dataset_id, self._config.model.version)
+         logger.info("Rewrote simplified %s", self._config.model.dataset_id)
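
The STANDARD schema splits a dataset into load_data (time series keyed by an "id" column) and load_data_lookup (one row of dimension values per "id"); _get_load_data_table joins the two on "id", and _check_dataset_internal_consistency requires the two sets of ids to match exactly. A minimal set-based sketch of that consistency rule (made-up ids, no Spark):

# Mirrors except_all(union, intersect): ids that appear in one table but not both.
load_data_ids = {1, 2, 3, 4}   # distinct ids found in load_data
lookup_ids = {1, 2, 3, 5}      # distinct ids found in load_data_lookup

mismatched = (load_data_ids | lookup_ids) - (load_data_ids & lookup_ids)
if mismatched:
    raise ValueError(
        f"load_data and load_data_lookup have {len(mismatched)} different IDs: {sorted(mismatched)}"
    )

Because the lookup carries the slowly varying dimension values once per id while load_data carries only the time series and value, the join on "id" reconstitutes the same unpivoted layout that the ONE_TABLE schema stores directly.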
dsgrid/dataset/growth_rates.py
@@ -0,0 +1,162 @@
+ import logging
+
+ from dsgrid.exceptions import DSGInvalidQuery
+ from dsgrid.query.models import ProjectionDatasetModel
+ from dsgrid.spark.functions import cross_join, join_multiple_columns, sql_from_df
+ from dsgrid.spark.types import DataFrame, F, IntegerType, use_duckdb
+ from dsgrid.utils.spark import get_unique_values
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def apply_exponential_growth_rate(
+     dataset: ProjectionDatasetModel,
+     initial_value_df: DataFrame,
+     growth_rate_df: DataFrame,
+     time_columns,
+     model_year_column,
+     value_columns,
+ ):
+     """Applies exponential growth rate to the initial_value dataframe as follows:
+     P(t) = P0*(1+r)^(t-t0)
+     where:
+         P(t): quantity at t
+         P0: initial quantity at t0, = P(t0)
+         r: growth rate (per time interval)
+         t-t0: number of time intervals
+
+
+     Parameters
+     ----------
+     dataset : ProjectionDatasetModel
+     initial_value_df : pyspark.sql.DataFrame
+     growth_rate_df : pyspark.sql.DataFrame
+     time_columns : set[str]
+     model_year_column : str
+     value_columns : set[str]
+
+     Returns
+     -------
+     pyspark.sql.DataFrame
+
+     """
+
+     initial_value_df, growth_rate_df = _process_exponential_growth_rate(
+         dataset,
+         initial_value_df,
+         growth_rate_df,
+         model_year_column,
+         value_columns,
+     )
+
+     df = apply_annual_multiplier(
+         initial_value_df,
+         growth_rate_df,
+         time_columns,
+         value_columns,
+     )
+
+     return df
+
+
+ def apply_annual_multiplier(
+     initial_value_df: DataFrame,
+     growth_rate_df: DataFrame,
+     time_columns,
+     value_columns,
+ ):
+     """Applies annual growth rate to the initial_value dataframe as follows:
+     P(t) = P0 * r(t)
+     where:
+         P(t): quantity at year t
+         P0: initial quantity
+         r(t): growth rate per year t (relative to P0)
+
+     Parameters
+     ----------
+     dataset : ProjectionDatasetModel
+     initial_value_df : pyspark.sql.DataFrame
+     growth_rate_df : pyspark.sql.DataFrame
+     time_columns : set[str]
+     value_columns : set[str]
+
+     Returns
+     -------
+     pyspark.sql.DataFrame
+
+     """
+
+     def renamed(col):
+         return col + "_gr"
+
+     orig_columns = initial_value_df.columns
+
+     dim_columns = set(initial_value_df.columns) - value_columns - time_columns
+     df = join_multiple_columns(initial_value_df, growth_rate_df, list(dim_columns))
+     for column in df.columns:
+         if column in value_columns:
+             gr_column = renamed(column)
+             df = df.withColumn(column, df[column] * df[gr_column])
+
+     return df.select(*orig_columns)
+
+
+ def _process_exponential_growth_rate(
+     dataset,
+     initial_value_df,
+     growth_rate_df,
+     model_year_column,
+     value_columns,
+ ):
+     def renamed(col):
+         return col + "_gr"
+
+     initial_value_df, base_year = _check_model_years(
+         dataset, initial_value_df, growth_rate_df, model_year_column
+     )
+
+     gr_df = growth_rate_df
+     for column in value_columns:
+         gr_col = renamed(column)
+         cols = ",".join([x for x in gr_df.columns if x not in (column, gr_col)])
+         if use_duckdb():
+             query = f"""
+                 SELECT
+                     {cols}
+                     ,(1 + {column}) ** (CAST({model_year_column} AS INTEGER) - {base_year}) AS {gr_col}
+             """
+             gr_df = sql_from_df(gr_df, query)
+         else:
+             # Spark SQL uses POW instead of **, so keep the DataFrame API method.
+             gr_df = gr_df.withColumn(
+                 gr_col,
+                 F.pow(
+                     (1 + F.col(column)), F.col(model_year_column).cast(IntegerType()) - base_year
+                 ),
+             ).drop(column)
+
+     return initial_value_df, gr_df
+
+
+ def _check_model_years(dataset, initial_value_df, growth_rate_df, model_year_column):
+     iv_years = get_unique_values(initial_value_df, model_year_column)
+     iv_years_sorted = sorted((int(x) for x in iv_years))
+
+     if dataset.base_year is None:
+         base_year = iv_years_sorted[0]
+     elif dataset.base_year in iv_years:
+         base_year = dataset.base_year
+     else:
+         msg = f"ProjectionDatasetModel base_year={dataset.base_year} is not in {iv_years_sorted}"
+         raise DSGInvalidQuery(msg)
+
+     if len(iv_years) > 1:
+         # TODO #198: needs test case
+         initial_value_df = initial_value_df.filter(f"{model_year_column} == '{base_year}'")
+
+     initial_value_df = cross_join(
+         initial_value_df.drop(model_year_column),
+         growth_rate_df.select(model_year_column).distinct(),
+     )
+     return initial_value_df, base_year
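
A quick worked example of the exponential form P(t) = P0*(1+r)^(t-t0) applied by apply_exponential_growth_rate, using made-up numbers (P0 = 100 in base year 2020, 5% annual growth):

# P(t) = P0 * (1 + r) ** (t - t0)
p0, r, t0 = 100.0, 0.05, 2020
for t in (2020, 2025, 2030):
    p_t = p0 * (1 + r) ** (t - t0)
    print(t, round(p_t, 2))  # 2020 -> 100.0, 2025 -> 127.63, 2030 -> 162.89

apply_annual_multiplier is the simpler form P(t) = P0 * r(t): the growth-rate table already supplies a multiplier per model year, so _process_exponential_growth_rate first converts the per-interval rate column into those year-specific multipliers ((1 + r)^(year - base_year)) and then reuses the same join-and-multiply path.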