dsgrid-toolkit 0.3.3__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- build_backend.py +93 -0
- dsgrid/__init__.py +22 -0
- dsgrid/api/__init__.py +0 -0
- dsgrid/api/api_manager.py +179 -0
- dsgrid/api/app.py +419 -0
- dsgrid/api/models.py +60 -0
- dsgrid/api/response_models.py +116 -0
- dsgrid/apps/__init__.py +0 -0
- dsgrid/apps/project_viewer/app.py +216 -0
- dsgrid/apps/registration_gui.py +444 -0
- dsgrid/chronify.py +32 -0
- dsgrid/cli/__init__.py +0 -0
- dsgrid/cli/common.py +120 -0
- dsgrid/cli/config.py +176 -0
- dsgrid/cli/download.py +13 -0
- dsgrid/cli/dsgrid.py +157 -0
- dsgrid/cli/dsgrid_admin.py +92 -0
- dsgrid/cli/install_notebooks.py +62 -0
- dsgrid/cli/query.py +729 -0
- dsgrid/cli/registry.py +1862 -0
- dsgrid/cloud/__init__.py +0 -0
- dsgrid/cloud/cloud_storage_interface.py +140 -0
- dsgrid/cloud/factory.py +31 -0
- dsgrid/cloud/fake_storage_interface.py +37 -0
- dsgrid/cloud/s3_storage_interface.py +156 -0
- dsgrid/common.py +36 -0
- dsgrid/config/__init__.py +0 -0
- dsgrid/config/annual_time_dimension_config.py +194 -0
- dsgrid/config/common.py +142 -0
- dsgrid/config/config_base.py +148 -0
- dsgrid/config/dataset_config.py +907 -0
- dsgrid/config/dataset_schema_handler_factory.py +46 -0
- dsgrid/config/date_time_dimension_config.py +136 -0
- dsgrid/config/dimension_config.py +54 -0
- dsgrid/config/dimension_config_factory.py +65 -0
- dsgrid/config/dimension_mapping_base.py +350 -0
- dsgrid/config/dimension_mappings_config.py +48 -0
- dsgrid/config/dimensions.py +1025 -0
- dsgrid/config/dimensions_config.py +71 -0
- dsgrid/config/file_schema.py +190 -0
- dsgrid/config/index_time_dimension_config.py +80 -0
- dsgrid/config/input_dataset_requirements.py +31 -0
- dsgrid/config/mapping_tables.py +209 -0
- dsgrid/config/noop_time_dimension_config.py +42 -0
- dsgrid/config/project_config.py +1462 -0
- dsgrid/config/registration_models.py +188 -0
- dsgrid/config/representative_period_time_dimension_config.py +194 -0
- dsgrid/config/simple_models.py +49 -0
- dsgrid/config/supplemental_dimension.py +29 -0
- dsgrid/config/time_dimension_base_config.py +192 -0
- dsgrid/data_models.py +155 -0
- dsgrid/dataset/__init__.py +0 -0
- dsgrid/dataset/dataset.py +123 -0
- dsgrid/dataset/dataset_expression_handler.py +86 -0
- dsgrid/dataset/dataset_mapping_manager.py +121 -0
- dsgrid/dataset/dataset_schema_handler_base.py +945 -0
- dsgrid/dataset/dataset_schema_handler_one_table.py +209 -0
- dsgrid/dataset/dataset_schema_handler_two_table.py +322 -0
- dsgrid/dataset/growth_rates.py +162 -0
- dsgrid/dataset/models.py +51 -0
- dsgrid/dataset/table_format_handler_base.py +257 -0
- dsgrid/dataset/table_format_handler_factory.py +17 -0
- dsgrid/dataset/unpivoted_table.py +121 -0
- dsgrid/dimension/__init__.py +0 -0
- dsgrid/dimension/base_models.py +230 -0
- dsgrid/dimension/dimension_filters.py +308 -0
- dsgrid/dimension/standard.py +252 -0
- dsgrid/dimension/time.py +352 -0
- dsgrid/dimension/time_utils.py +103 -0
- dsgrid/dsgrid_rc.py +88 -0
- dsgrid/exceptions.py +105 -0
- dsgrid/filesystem/__init__.py +0 -0
- dsgrid/filesystem/cloud_filesystem.py +32 -0
- dsgrid/filesystem/factory.py +32 -0
- dsgrid/filesystem/filesystem_interface.py +136 -0
- dsgrid/filesystem/local_filesystem.py +74 -0
- dsgrid/filesystem/s3_filesystem.py +118 -0
- dsgrid/loggers.py +132 -0
- dsgrid/minimal_patterns.cp313-win_amd64.pyd +0 -0
- dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +949 -0
- dsgrid/notebooks/registration.ipynb +48 -0
- dsgrid/notebooks/start_notebook.sh +11 -0
- dsgrid/project.py +451 -0
- dsgrid/query/__init__.py +0 -0
- dsgrid/query/dataset_mapping_plan.py +142 -0
- dsgrid/query/derived_dataset.py +388 -0
- dsgrid/query/models.py +728 -0
- dsgrid/query/query_context.py +287 -0
- dsgrid/query/query_submitter.py +994 -0
- dsgrid/query/report_factory.py +19 -0
- dsgrid/query/report_peak_load.py +70 -0
- dsgrid/query/reports_base.py +20 -0
- dsgrid/registry/__init__.py +0 -0
- dsgrid/registry/bulk_register.py +165 -0
- dsgrid/registry/common.py +287 -0
- dsgrid/registry/config_update_checker_base.py +63 -0
- dsgrid/registry/data_store_factory.py +34 -0
- dsgrid/registry/data_store_interface.py +74 -0
- dsgrid/registry/dataset_config_generator.py +158 -0
- dsgrid/registry/dataset_registry_manager.py +950 -0
- dsgrid/registry/dataset_update_checker.py +16 -0
- dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
- dsgrid/registry/dimension_mapping_update_checker.py +16 -0
- dsgrid/registry/dimension_registry_manager.py +413 -0
- dsgrid/registry/dimension_update_checker.py +16 -0
- dsgrid/registry/duckdb_data_store.py +207 -0
- dsgrid/registry/filesystem_data_store.py +150 -0
- dsgrid/registry/filter_registry_manager.py +123 -0
- dsgrid/registry/project_config_generator.py +57 -0
- dsgrid/registry/project_registry_manager.py +1623 -0
- dsgrid/registry/project_update_checker.py +48 -0
- dsgrid/registry/registration_context.py +223 -0
- dsgrid/registry/registry_auto_updater.py +316 -0
- dsgrid/registry/registry_database.py +667 -0
- dsgrid/registry/registry_interface.py +446 -0
- dsgrid/registry/registry_manager.py +558 -0
- dsgrid/registry/registry_manager_base.py +367 -0
- dsgrid/registry/versioning.py +92 -0
- dsgrid/rust_ext/__init__.py +14 -0
- dsgrid/rust_ext/find_minimal_patterns.py +129 -0
- dsgrid/spark/__init__.py +0 -0
- dsgrid/spark/functions.py +589 -0
- dsgrid/spark/types.py +110 -0
- dsgrid/tests/__init__.py +0 -0
- dsgrid/tests/common.py +140 -0
- dsgrid/tests/make_us_data_registry.py +265 -0
- dsgrid/tests/register_derived_datasets.py +103 -0
- dsgrid/tests/utils.py +25 -0
- dsgrid/time/__init__.py +0 -0
- dsgrid/time/time_conversions.py +80 -0
- dsgrid/time/types.py +67 -0
- dsgrid/units/__init__.py +0 -0
- dsgrid/units/constants.py +113 -0
- dsgrid/units/convert.py +71 -0
- dsgrid/units/energy.py +145 -0
- dsgrid/units/power.py +87 -0
- dsgrid/utils/__init__.py +0 -0
- dsgrid/utils/dataset.py +830 -0
- dsgrid/utils/files.py +179 -0
- dsgrid/utils/filters.py +125 -0
- dsgrid/utils/id_remappings.py +100 -0
- dsgrid/utils/py_expression_eval/LICENSE +19 -0
- dsgrid/utils/py_expression_eval/README.md +8 -0
- dsgrid/utils/py_expression_eval/__init__.py +847 -0
- dsgrid/utils/py_expression_eval/tests.py +283 -0
- dsgrid/utils/run_command.py +70 -0
- dsgrid/utils/scratch_dir_context.py +65 -0
- dsgrid/utils/spark.py +918 -0
- dsgrid/utils/spark_partition.py +98 -0
- dsgrid/utils/timing.py +239 -0
- dsgrid/utils/utilities.py +221 -0
- dsgrid/utils/versioning.py +36 -0
- dsgrid_toolkit-0.3.3.dist-info/METADATA +193 -0
- dsgrid_toolkit-0.3.3.dist-info/RECORD +157 -0
- dsgrid_toolkit-0.3.3.dist-info/WHEEL +4 -0
- dsgrid_toolkit-0.3.3.dist-info/entry_points.txt +4 -0
- dsgrid_toolkit-0.3.3.dist-info/licenses/LICENSE +29 -0
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Self
|
|
3
|
+
|
|
4
|
+
from dsgrid.common import TIME_ZONE_COLUMN, VALUE_COLUMN
|
|
5
|
+
from dsgrid.config.dataset_config import DatasetConfig
|
|
6
|
+
from dsgrid.config.project_config import ProjectConfig
|
|
7
|
+
from dsgrid.config.simple_models import DimensionSimpleModel
|
|
8
|
+
from dsgrid.config.time_dimension_base_config import TimeDimensionBaseConfig
|
|
9
|
+
from dsgrid.dataset.models import ValueFormat
|
|
10
|
+
from dsgrid.query.models import DatasetQueryModel
|
|
11
|
+
from dsgrid.registry.data_store_interface import DataStoreInterface
|
|
12
|
+
from dsgrid.spark.types import (
|
|
13
|
+
DataFrame,
|
|
14
|
+
StringType,
|
|
15
|
+
)
|
|
16
|
+
from dsgrid.utils.dataset import (
|
|
17
|
+
convert_types_if_necessary,
|
|
18
|
+
)
|
|
19
|
+
from dsgrid.config.file_schema import read_data_file
|
|
20
|
+
from dsgrid.utils.scratch_dir_context import ScratchDirContext
|
|
21
|
+
from dsgrid.utils.spark import check_for_nulls
|
|
22
|
+
from dsgrid.utils.timing import timer_stats_collector, track_timing
|
|
23
|
+
from dsgrid.dataset.dataset_schema_handler_base import DatasetSchemaHandlerBase
|
|
24
|
+
from dsgrid.dimension.base_models import DatasetDimensionRequirements, DimensionType
|
|
25
|
+
from dsgrid.exceptions import DSGInvalidDataset
|
|
26
|
+
from dsgrid.query.query_context import QueryContext
|
|
27
|
+
|
|
28
|
+
# Module-level logger named after this module (via __name__).
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class OneTableDatasetSchemaHandler(DatasetSchemaHandlerBase):
    """Schema handler for ONE_TABLE datasets.

    In this format a single load_data table carries the dimension columns,
    the time column(s), the value column, and optionally a time-zone column.
    """

    def __init__(self, load_data_df, *args, **kwargs):
        """Store the load_data table; all other arguments go to the base class."""
        super().__init__(*args, **kwargs)
        self._load_data = load_data_df

    @classmethod
    def load(
        cls,
        config: DatasetConfig,
        *args,
        store: DataStoreInterface | None = None,
        scratch_dir_context: ScratchDirContext | None = None,
        **kwargs,
    ) -> Self:
        """Construct a handler by reading the dataset's load_data table.

        Parameters
        ----------
        config : DatasetConfig
        store : DataStoreInterface | None
            When given, the table is read from the store by dataset ID and
            version. Otherwise it is read from ``config.data_file_schema``.
        scratch_dir_context : ScratchDirContext | None
            Forwarded to ``read_data_file`` when reading from files.

        Raises
        ------
        DSGInvalidDataset
            If there is no store and the config has no data file schema.
        """
        if store is None:
            if config.data_file_schema is None:
                msg = "Cannot load dataset without data file schema or store"
                raise DSGInvalidDataset(msg)
            df = read_data_file(config.data_file_schema, scratch_dir_context=scratch_dir_context)
        else:
            df = store.read_table(config.model.dataset_id, config.model.version)
        # Add the trivial (single-record) dimension columns, then normalize column types.
        load_data_df = config.add_trivial_dimensions(df)
        load_data_df = convert_types_if_necessary(load_data_df)
        return cls(load_data_df, config, *args, **kwargs)

    @track_timing(timer_stats_collector)
    def check_consistency(
        self,
        missing_dimension_associations: dict[str, DataFrame],
        scratch_dir_context: ScratchDirContext,
        requirements: DatasetDimensionRequirements,
    ) -> None:
        """Validate the load_data columns, then the dataset's dimension associations."""
        self._check_one_table_data_consistency()
        self._check_dimension_associations(
            missing_dimension_associations, scratch_dir_context, requirements
        )

    @track_timing(timer_stats_collector)
    def check_time_consistency(self):
        """Validate the time data in load_data.

        Uses chronify when the time dimension supports it. Does nothing when
        the dataset has no time dimension.
        """
        time_dim = self._config.get_time_dimension()
        if time_dim is not None:
            if time_dim.supports_chronify():
                self._check_dataset_time_consistency_with_chronify()
            else:
                self._check_dataset_time_consistency(self._load_data)

    @track_timing(timer_stats_collector)
    def _check_one_table_data_consistency(self):
        """Validate the columns of load_data (time values are checked separately).

        Checks:
        - The dataset's value format is STACKED.
        - The value column passes ``_check_load_data_unpivoted_value_column``.
        - Every column is a dimension column, a time column, the value column,
          or the time-zone column.
        - Every dimension column has StringType.
        - ``check_for_nulls`` passes on load_data.

        Dimension associations are checked separately in ``check_consistency``.

        Raises
        ------
        DSGInvalidDataset
            On an unexpected column or a non-string dimension column.
        """
        logger.info("Check one table dataset consistency.")
        time_dim = self._config.get_time_dimension()
        time_columns: set[str] = set()
        if time_dim is not None:
            time_columns = set(time_dim.get_load_data_time_columns())
        assert (
            self._config.get_value_format() == ValueFormat.STACKED
        ), self._config.get_value_format()
        self._check_load_data_unpivoted_value_column(self._load_data)

        # Columns that are allowed but are not dimension columns (no dtype check).
        non_dimension_columns = time_columns | {VALUE_COLUMN, TIME_ZONE_COLUMN}
        allowed_columns = DimensionType.get_allowed_dimension_column_names().union(time_columns)
        allowed_columns.add(VALUE_COLUMN)
        allowed_columns.add(TIME_ZONE_COLUMN)

        schema = self._load_data.schema
        for column in self._load_data.columns:
            if column not in allowed_columns:
                msg = f"{column=} is not expected in load_data"
                raise DSGInvalidDataset(msg)
            if column in non_dimension_columns:
                continue
            # The resolved type is not needed; the call is kept for any validation
            # DimensionType.from_column performs (TODO confirm it raises on bad names).
            DimensionType.from_column(column)
            if schema[column].dataType != StringType():
                msg = f"dimension column {column} must have data type = StringType"
                raise DSGInvalidDataset(msg)
        check_for_nulls(self._load_data)

    def get_base_load_data_table(self) -> DataFrame:
        """Return the load_data table as loaded."""
        return self._load_data

    def _get_load_data_table(self) -> DataFrame:
        """Return the table used for time checks (load_data itself for ONE_TABLE)."""
        return self._load_data

    @track_timing(timer_stats_collector)
    def filter_data(self, dimensions: list[DimensionSimpleModel], store: DataStoreInterface):
        """Filter load_data to the given records and rewrite it in ``store``.

        Rows are kept only where each dimension column present in load_data
        holds one of that dimension's ``record_ids``. Trivial dimension columns
        are then dropped; each must contain exactly one distinct value.
        """
        assert (
            self._config.get_value_format() == ValueFormat.STACKED
        ), self._config.get_value_format()
        load_df = self._load_data
        df_columns = set(load_df.columns)
        for dim in dimensions:
            column = dim.dimension_type.value
            if column in df_columns:
                load_df = load_df.filter(load_df[column].isin(dim.record_ids))

        drop_columns = []
        for dim in self._config.model.trivial_dimensions:
            col = dim.value
            count = load_df.select(col).distinct().count()
            assert count == 1, f"{dim}: {count}"
            drop_columns.append(col)
        load_df = load_df.drop(*drop_columns)

        store.replace_table(load_df, self.dataset_id, self._config.model.version)
        logger.info("Rewrote simplified %s", self._config.model.dataset_id)

    def make_project_dataframe(
        self, context: QueryContext, project_config: ProjectConfig
    ) -> DataFrame:
        """Map load_data to the project's base dimensions for a project query.

        The mapping plan comes from the query's project model, or the default
        plan when none is specified. A checkpointed table from the mapping
        manager is used when available; otherwise load_data is prefiltered by
        the query's stacked-dimension and time filters first.
        """
        plan = context.model.project.get_dataset_mapping_plan(self.dataset_id)
        if plan is None:
            plan = self.build_default_dataset_mapping_plan()
        with context.dataset_mapping_manager(self.dataset_id, plan) as mapping_manager:
            ld_df = mapping_manager.try_read_checkpointed_table()
            if ld_df is None:
                ld_df = self._load_data
                ld_df = self._prefilter_stacked_dimensions(context, ld_df)
                ld_df = self._prefilter_time_dimension(context, ld_df)

            ld_df = self._remap_dimension_columns(
                ld_df,
                mapping_manager,
                filtered_records=context.get_record_ids(),
            )
            ld_df = self._apply_fraction(ld_df, {VALUE_COLUMN}, mapping_manager)
            project_metric_records = self._get_project_metric_records(project_config)
            ld_df = self._convert_units(ld_df, project_metric_records, mapping_manager)
            input_dataset = project_config.get_dataset(self._config.model.dataset_id)
            ld_df = self._convert_time_dimension(
                load_data_df=ld_df,
                to_time_dim=project_config.get_base_time_dimension(),
                value_column=VALUE_COLUMN,
                mapping_manager=mapping_manager,
                wrap_time_allowed=input_dataset.wrap_time_allowed,
                time_based_data_adjustment=input_dataset.time_based_data_adjustment,
                to_geo_dim=project_config.get_base_dimension(DimensionType.GEOGRAPHY),
            )
            return self._finalize_table(context, ld_df, project_config)

    def make_mapped_dataframe(
        self, context: QueryContext, time_dimension: TimeDimensionBaseConfig | None = None
    ) -> DataFrame:
        """Map load_data for a dataset-only query (``DatasetQueryModel``).

        Units are converted only when a metric mapping target exists. Time is
        converted only when ``time_dimension`` is given.
        """
        query = context.model
        assert isinstance(query, DatasetQueryModel)
        plan = query.mapping_plan
        if plan is None:
            plan = self.build_default_dataset_mapping_plan()
        geography_dimension = self._get_mapping_to_dimension(DimensionType.GEOGRAPHY)
        metric_dimension = self._get_mapping_to_dimension(DimensionType.METRIC)
        with context.dataset_mapping_manager(self.dataset_id, plan) as mapping_manager:
            ld_df = mapping_manager.try_read_checkpointed_table()
            if ld_df is None:
                ld_df = self._load_data

            ld_df = self._remap_dimension_columns(ld_df, mapping_manager)
            ld_df = self._apply_fraction(ld_df, {VALUE_COLUMN}, mapping_manager)
            if metric_dimension is not None:
                metric_records = metric_dimension.get_records_dataframe()
                ld_df = self._convert_units(ld_df, metric_records, mapping_manager)
            if time_dimension is not None:
                ld_df = self._convert_time_dimension(
                    load_data_df=ld_df,
                    to_time_dim=time_dimension,
                    value_column=VALUE_COLUMN,
                    mapping_manager=mapping_manager,
                    wrap_time_allowed=query.wrap_time_allowed,
                    time_based_data_adjustment=query.time_based_data_adjustment,
                    to_geo_dim=geography_dimension,
                )
            return ld_df
|
|
@@ -0,0 +1,322 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Self
|
|
3
|
+
|
|
4
|
+
from dsgrid.common import SCALING_FACTOR_COLUMN, VALUE_COLUMN
|
|
5
|
+
from dsgrid.config.dataset_config import DatasetConfig
|
|
6
|
+
from dsgrid.config.project_config import ProjectConfig
|
|
7
|
+
from dsgrid.config.simple_models import DimensionSimpleModel
|
|
8
|
+
from dsgrid.config.time_dimension_base_config import TimeDimensionBaseConfig
|
|
9
|
+
from dsgrid.dataset.models import ValueFormat
|
|
10
|
+
from dsgrid.dataset.dataset_schema_handler_base import DatasetSchemaHandlerBase
|
|
11
|
+
from dsgrid.dimension.base_models import DatasetDimensionRequirements, DimensionType
|
|
12
|
+
from dsgrid.exceptions import DSGInvalidDataset
|
|
13
|
+
from dsgrid.query.models import DatasetQueryModel
|
|
14
|
+
from dsgrid.query.query_context import QueryContext
|
|
15
|
+
from dsgrid.registry.data_store_interface import DataStoreInterface
|
|
16
|
+
from dsgrid.spark.functions import (
|
|
17
|
+
cache,
|
|
18
|
+
coalesce,
|
|
19
|
+
collect_list,
|
|
20
|
+
except_all,
|
|
21
|
+
intersect,
|
|
22
|
+
unpersist,
|
|
23
|
+
)
|
|
24
|
+
from dsgrid.spark.types import (
|
|
25
|
+
DataFrame,
|
|
26
|
+
StringType,
|
|
27
|
+
)
|
|
28
|
+
from dsgrid.utils.dataset import (
|
|
29
|
+
apply_scaling_factor,
|
|
30
|
+
convert_types_if_necessary,
|
|
31
|
+
)
|
|
32
|
+
from dsgrid.config.file_schema import read_data_file
|
|
33
|
+
from dsgrid.utils.scratch_dir_context import ScratchDirContext
|
|
34
|
+
from dsgrid.utils.spark import check_for_nulls
|
|
35
|
+
from dsgrid.utils.timing import Timer, timer_stats_collector, track_timing
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Module-level logger named after this module (via __name__).
logger = logging.getLogger(__name__)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class TwoTableDatasetSchemaHandler(DatasetSchemaHandlerBase):
    """Handler for TWO_TABLE dataset format (load_data + load_data_lookup tables).

    load_data holds time and value columns keyed by an ``id`` column.
    load_data_lookup maps each ``id`` to dimension records, and may include
    a scaling-factor column.
    """

    def __init__(self, load_data_df, load_data_lookup, *args, **kwargs):
        """Store both tables; all other arguments go to the base class."""
        super().__init__(*args, **kwargs)
        self._load_data = load_data_df
        self._load_data_lookup = load_data_lookup

    @classmethod
    def load(
        cls,
        config: DatasetConfig,
        *args,
        store: DataStoreInterface | None = None,
        scratch_dir_context: ScratchDirContext | None = None,
        **kwargs,
    ) -> Self:
        """Construct a handler by reading load_data and load_data_lookup.

        Parameters
        ----------
        config : DatasetConfig
        store : DataStoreInterface | None
            When given, both tables are read from the store by dataset ID and
            version. Otherwise they are read from ``config.data_file_schema``
            and ``config.lookup_file_schema``.
        scratch_dir_context : ScratchDirContext | None
            Forwarded to ``read_data_file`` when reading from files.

        Raises
        ------
        DSGInvalidDataset
            If there is no store and either file schema is missing.
        """
        if store is None:
            if config.data_file_schema is None:
                msg = "Cannot load dataset without data file schema or store"
                raise DSGInvalidDataset(msg)
            if config.lookup_file_schema is None:
                msg = "TWO_TABLE format requires lookup_data_file"
                raise DSGInvalidDataset(msg)
            load_data_df = read_data_file(
                config.data_file_schema, scratch_dir_context=scratch_dir_context
            )
            load_data_lookup = read_data_file(
                config.lookup_file_schema, scratch_dir_context=scratch_dir_context
            )
        else:
            load_data_df = store.read_table(config.model.dataset_id, config.model.version)
            load_data_lookup = store.read_lookup_table(
                config.model.dataset_id, config.model.version
            )

        load_data_df = convert_types_if_necessary(load_data_df)
        # Trivial dimension columns are added to the lookup table, not load_data.
        load_data_lookup = config.add_trivial_dimensions(load_data_lookup)
        load_data_lookup = convert_types_if_necessary(load_data_lookup)
        return cls(load_data_df, load_data_lookup, config, *args, **kwargs)

    @track_timing(timer_stats_collector)
    def check_consistency(
        self,
        missing_dimension_associations: dict[str, DataFrame],
        scratch_dir_context: ScratchDirContext,
        requirements: DatasetDimensionRequirements,
    ) -> None:
        """Validate the lookup table, load_data, their IDs, and dimension associations."""
        self._check_lookup_data_consistency()
        self._check_dataset_internal_consistency()
        self._check_dimension_associations(
            missing_dimension_associations, scratch_dir_context, requirements
        )

    @track_timing(timer_stats_collector)
    def check_time_consistency(self):
        """Validate time data in the joined load_data/lookup table.

        Uses chronify when the time dimension supports it. Does nothing when
        the dataset has no time dimension.
        """
        time_dim = self._config.get_time_dimension()
        if time_dim is None:
            return None

        if time_dim.supports_chronify():
            self._check_dataset_time_consistency_with_chronify()
        else:
            self._check_dataset_time_consistency(self._get_load_data_table())

    def get_base_load_data_table(self) -> DataFrame:
        """Return the load_data table (without the lookup join)."""
        return self._load_data

    def _get_load_data_table(self) -> DataFrame:
        """Return load_data joined to load_data_lookup on ``id``."""
        return self._load_data.join(self._load_data_lookup, on="id")

    def make_project_dataframe(
        self, context: QueryContext, project_config: ProjectConfig
    ) -> DataFrame:
        """Map the dataset to the project's base dimensions for a project query.

        Both tables are prefiltered by the query's stacked-dimension filters.
        load_data is also prefiltered by time, then joined to the lookup and the
        ``id`` column is dropped. A checkpointed table from the mapping manager
        replaces those steps when available. If a scaling-factor column is
        present after remapping, it is applied to the value column.
        """
        lk_df = self._load_data_lookup
        lk_df = self._prefilter_stacked_dimensions(context, lk_df)

        plan = context.model.project.get_dataset_mapping_plan(self.dataset_id)
        if plan is None:
            plan = self.build_default_dataset_mapping_plan()
        with context.dataset_mapping_manager(self.dataset_id, plan) as mapping_manager:
            ld_df = mapping_manager.try_read_checkpointed_table()
            if ld_df is None:
                ld_df = self._load_data
                ld_df = self._prefilter_stacked_dimensions(context, ld_df)
                ld_df = self._prefilter_time_dimension(context, ld_df)
                ld_df = ld_df.join(lk_df, on="id").drop("id")

            ld_df = self._remap_dimension_columns(
                ld_df,
                mapping_manager,
                filtered_records=context.get_record_ids(),
            )
            if SCALING_FACTOR_COLUMN in ld_df.columns:
                ld_df = apply_scaling_factor(ld_df, VALUE_COLUMN, mapping_manager)

            ld_df = self._apply_fraction(ld_df, {VALUE_COLUMN}, mapping_manager)
            project_metric_records = self._get_project_metric_records(project_config)
            ld_df = self._convert_units(ld_df, project_metric_records, mapping_manager)
            input_dataset = project_config.get_dataset(self._config.model.dataset_id)
            ld_df = self._convert_time_dimension(
                load_data_df=ld_df,
                to_time_dim=project_config.get_base_time_dimension(),
                value_column=VALUE_COLUMN,
                mapping_manager=mapping_manager,
                wrap_time_allowed=input_dataset.wrap_time_allowed,
                time_based_data_adjustment=input_dataset.time_based_data_adjustment,
                to_geo_dim=project_config.get_base_dimension(DimensionType.GEOGRAPHY),
            )
            return self._finalize_table(context, ld_df, project_config)

    def make_mapped_dataframe(
        self,
        context: QueryContext,
        time_dimension: TimeDimensionBaseConfig | None = None,
    ) -> DataFrame:
        """Map the dataset for a dataset-only query (``DatasetQueryModel``).

        Units are converted only when a metric mapping target exists. Time is
        converted only when ``time_dimension`` is given.
        """
        query = context.model
        assert isinstance(query, DatasetQueryModel)
        plan = query.mapping_plan
        if plan is None:
            plan = self.build_default_dataset_mapping_plan()
        geography_dimension = self._get_mapping_to_dimension(DimensionType.GEOGRAPHY)
        metric_dimension = self._get_mapping_to_dimension(DimensionType.METRIC)
        with context.dataset_mapping_manager(self.dataset_id, plan) as mapping_manager:
            ld_df = mapping_manager.try_read_checkpointed_table()
            if ld_df is None:
                ld_df = self._load_data
                lk_df = self._load_data_lookup
                ld_df = ld_df.join(lk_df, on="id").drop("id")

            ld_df = self._remap_dimension_columns(
                ld_df,
                mapping_manager,
            )
            if SCALING_FACTOR_COLUMN in ld_df.columns:
                ld_df = apply_scaling_factor(ld_df, VALUE_COLUMN, mapping_manager)

            ld_df = self._apply_fraction(ld_df, {VALUE_COLUMN}, mapping_manager)
            if metric_dimension is not None:
                metric_records = metric_dimension.get_records_dataframe()
                ld_df = self._convert_units(ld_df, metric_records, mapping_manager)
            if time_dimension is not None:
                ld_df = self._convert_time_dimension(
                    load_data_df=ld_df,
                    to_time_dim=time_dimension,
                    value_column=VALUE_COLUMN,
                    mapping_manager=mapping_manager,
                    wrap_time_allowed=query.wrap_time_allowed,
                    time_based_data_adjustment=query.time_based_data_adjustment,
                    to_geo_dim=geography_dimension,
                )
            return ld_df

    @track_timing(timer_stats_collector)
    def _check_lookup_data_consistency(self):
        """Validate the columns of load_data_lookup (time is handled separately).

        Checks:
        - An ``id`` column exists.
        - Every column other than ``id`` and the scaling-factor column is a
          StringType dimension column.
        - ``check_for_nulls`` passes on the lookup table.
        - Every dimension type allowed as a column that load_data does not
          carry is present in the lookup table.

        Raises
        ------
        DSGInvalidDataset
            If any of the checks above fails.
        """
        logger.info("Check lookup data consistency.")
        found_id = False
        dimension_types = set()
        schema = self._load_data_lookup.schema
        for col in self._load_data_lookup.columns:
            if col == "id":
                found_id = True
                continue
            if col == SCALING_FACTOR_COLUMN:
                continue
            if schema[col].dataType != StringType():
                msg = f"dimension column {col} must have data type = StringType"
                raise DSGInvalidDataset(msg)
            dimension_types.add(DimensionType.from_column(col))

        if not found_id:
            msg = "load_data_lookup does not include an 'id' column"
            raise DSGInvalidDataset(msg)

        check_for_nulls(self._load_data_lookup)
        load_data_dimensions = set(self._list_dimension_types_in_load_data(self._load_data))
        # Dimensions not carried by load_data must be provided by the lookup table.
        expected_dimensions = {
            d
            for d in DimensionType.get_dimension_types_allowed_as_columns()
            if d not in load_data_dimensions
        }
        missing_dimensions = expected_dimensions.difference(dimension_types)
        if missing_dimensions:
            msg = (
                f"load_data_lookup is missing dimensions: {missing_dimensions}. "
                "If these are trivial dimensions, make sure to specify them in the Dataset Config."
            )
            raise DSGInvalidDataset(msg)

    @track_timing(timer_stats_collector)
    def _check_dataset_internal_consistency(self):
        """Check load_data columns and that its IDs match the lookup table's IDs.

        Raises
        ------
        DSGInvalidDataset
            On an unexpected column, a missing ``id`` column, or differing ID
            sets. Up to 100 differing IDs are logged before raising.
        """
        logger.info("Check dataset internal consistency.")
        assert (
            self._config.get_value_format() == ValueFormat.STACKED
        ), self._config.get_value_format()
        self._check_load_data_unpivoted_value_column(self._load_data)

        time_dim = self._config.get_time_dimension()
        time_columns: set[str] = set()
        if time_dim is not None:
            time_columns = set(time_dim.get_load_data_time_columns())
        allowed_columns = (
            DimensionType.get_allowed_dimension_column_names()
            .union(time_columns)
            .union({VALUE_COLUMN, "id", SCALING_FACTOR_COLUMN})
        )

        found_id = False
        for column in self._load_data.columns:
            if column not in allowed_columns:
                msg = f"{column=} is not expected in load_data"
                raise DSGInvalidDataset(msg)
            if column == "id":
                found_id = True

        if not found_id:
            msg = "load_data does not include an 'id' column"
            raise DSGInvalidDataset(msg)

        check_for_nulls(self._load_data)
        ld_ids = self._load_data.select("id").distinct()
        ldl_ids = self._load_data_lookup.select("id").distinct()
        ldl_id_count = ldl_ids.count()
        data_id_count = ld_ids.count()
        joined = ld_ids.join(ldl_ids, on="id")
        count = joined.count()

        # The ID sets are equal only if both distinct counts equal the join count.
        if data_id_count != count or ldl_id_count != count:
            with Timer(timer_stats_collector, "show load_data and load_data_lookup ID diff"):
                # Symmetric difference: IDs that appear in only one of the tables.
                diff = except_all(ld_ids.unionAll(ldl_ids), intersect(ld_ids, ldl_ids))
                # Only run the query once (with Spark). Number of rows shouldn't be a problem.
                cache(diff)
                limit = 100
                try:
                    diff_count = diff.count()
                    diff_list = diff.limit(limit).collect()
                finally:
                    unpersist(diff)
                logger.error(
                    "load_data and load_data_lookup have %s different IDs. Limited to %s: %s",
                    diff_count,
                    limit,
                    diff_list,
                )
            msg = f"Data IDs for {self._config.config_id} data/lookup are inconsistent"
            raise DSGInvalidDataset(msg)

    @track_timing(timer_stats_collector)
    def filter_data(self, dimensions: list[DimensionSimpleModel], store: DataStoreInterface):
        """Filter both tables to the given records and rewrite them in ``store``.

        The steps are:
        1. Filter the lookup table to each dimension's ``record_ids``.
        2. Drop the trivial dimension columns; each must hold exactly one
           distinct value.
        3. Write the lookup table back with ``store.replace_lookup_table``.
        4. Keep the load_data rows whose ``id`` survives in the lookup and that
           match any dimension columns load_data itself carries.
        5. Write load_data back with ``store.replace_table``.
        """
        lookup = self._load_data_lookup
        cache(lookup)
        lookup_columns = set(lookup.columns)
        for dim in dimensions:
            column = dim.dimension_type.value
            if column in lookup_columns:
                lookup = lookup.filter(lookup[column].isin(dim.record_ids))

        drop_columns = []
        for dim in self._config.model.trivial_dimensions:
            col = dim.value
            count = lookup.select(col).distinct().count()
            assert count == 1, f"{dim}: {count}"
            drop_columns.append(col)
        lookup = lookup.drop(*drop_columns)

        # Presumably reduces the (small) lookup table to one partition before writing.
        lookup2 = coalesce(lookup, 1)
        store.replace_lookup_table(lookup2, self.dataset_id, self._config.model.version)
        ids = collect_list(lookup2.select("id").distinct(), "id")
        load_df = self._load_data.filter(self._load_data.id.isin(ids))
        ld_columns = set(load_df.columns)
        for dim in dimensions:
            column = dim.dimension_type.value
            if column in ld_columns:
                load_df = load_df.filter(load_df[column].isin(dim.record_ids))

        store.replace_table(load_df, self.dataset_id, self._config.model.version)
        logger.info("Rewrote simplified %s", self._config.model.dataset_id)
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from dsgrid.exceptions import DSGInvalidQuery
|
|
4
|
+
from dsgrid.query.models import ProjectionDatasetModel
|
|
5
|
+
from dsgrid.spark.functions import cross_join, join_multiple_columns, sql_from_df
|
|
6
|
+
from dsgrid.spark.types import DataFrame, F, IntegerType, use_duckdb
|
|
7
|
+
from dsgrid.utils.spark import get_unique_values
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def apply_exponential_growth_rate(
    dataset: ProjectionDatasetModel,
    initial_value_df: DataFrame,
    growth_rate_df: DataFrame,
    time_columns,
    model_year_column,
    value_columns,
):
    """Applies exponential growth rate to the initial_value dataframe as follows:
    P(t) = P0*(1+r)^(t-t0)
    where:
    P(t): quantity at t
    P0: initial quantity at t0, = P(t0)
    r: growth rate (per time interval)
    t-t0: number of time intervals

    Parameters
    ----------
    dataset : ProjectionDatasetModel
        Supplies the optional base year (t0).
    initial_value_df : pyspark.sql.DataFrame
        Initial quantities (P0).
    growth_rate_df : pyspark.sql.DataFrame
        Per-interval growth rates (r), one column per value column.
    time_columns : set[str]
        Time columns; they are excluded from the join keys.
    model_year_column : str
        Name of the model-year column used to compute t - t0.
    value_columns : set[str]
        Columns holding the quantities to grow.

    Returns
    -------
    pyspark.sql.DataFrame
        Projected quantities with the same columns as initial_value_df.
    """
    base_df, multiplier_df = _process_exponential_growth_rate(
        dataset, initial_value_df, growth_rate_df, model_year_column, value_columns
    )
    return apply_annual_multiplier(base_df, multiplier_df, time_columns, value_columns)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def apply_annual_multiplier(
    initial_value_df: DataFrame,
    growth_rate_df: DataFrame,
    time_columns,
    value_columns,
):
    """Applies annual growth rate to the initial_value dataframe as follows:
    P(t) = P0 * r(t)
    where:
    P(t): quantity at year t
    P0: initial quantity
    r(t): growth rate per year t (relative to P0)

    Parameters
    ----------
    initial_value_df : pyspark.sql.DataFrame
        Initial quantities. Every column that is neither a time column nor a value column
        is used as a join key against growth_rate_df.
    growth_rate_df : pyspark.sql.DataFrame
        Multipliers. For each value column ``<col>`` it must provide ``<col>_gr``, plus
        the join-key columns.
    time_columns : Iterable[str]
        Time columns; excluded from the join keys.
    value_columns : Iterable[str]
        Columns to multiply by their matching ``_gr`` column.

    Returns
    -------
    pyspark.sql.DataFrame
        Same columns, in the same order, as initial_value_df.
    """

    def renamed(col):
        return col + "_gr"

    # Accept any iterable (not only sets) for the column collections.
    value_set = set(value_columns)
    time_set = set(time_columns)

    orig_columns = initial_value_df.columns
    # Sorted so the join-key order is deterministic across runs.
    dim_columns = sorted(set(orig_columns) - value_set - time_set)
    df = join_multiple_columns(initial_value_df, growth_rate_df, dim_columns)
    for column in df.columns:
        if column in value_set:
            df = df.withColumn(column, df[column] * df[renamed(column)])

    return df.select(*orig_columns)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _process_exponential_growth_rate(
    dataset,
    initial_value_df,
    growth_rate_df,
    model_year_column,
    value_columns,
):
    """Turn per-interval growth rates into cumulative multipliers relative to the base year.

    For every value column ``c``, the returned growth-rate frame holds
    ``c_gr = (1 + c) ** (int(model_year) - base_year)`` and no longer holds ``c``.
    The initial values are first reduced and expanded by ``_check_model_years``.

    Returns
    -------
    tuple
        (initial_value_df, growth-rate DataFrame with ``_gr`` columns)
    """
    initial_value_df, base_year = _check_model_years(
        dataset, initial_value_df, growth_rate_df, model_year_column
    )

    multipliers = growth_rate_df
    for value_col in value_columns:
        out_col = value_col + "_gr"
        if not use_duckdb():
            # Spark SQL uses POW instead of **, so keep the DataFrame API method.
            exponent = F.col(model_year_column).cast(IntegerType()) - base_year
            multipliers = multipliers.withColumn(
                out_col,
                F.pow((1 + F.col(value_col)), exponent),
            ).drop(value_col)
            continue

        keep = ",".join(c for c in multipliers.columns if c not in (value_col, out_col))
        query = f"""
            SELECT
                {keep}
                ,(1 + {value_col}) ** (CAST({model_year_column} AS INTEGER) - {base_year}) AS {out_col}
        """
        multipliers = sql_from_df(multipliers, query)

    return initial_value_df, multipliers
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _check_model_years(dataset, initial_value_df, growth_rate_df, model_year_column):
    """Pick the base year and expand the initial values across all growth-rate model years.

    Parameters
    ----------
    dataset : ProjectionDatasetModel
        Its ``base_year`` (or None) selects which initial-value model year to use.
    initial_value_df : pyspark.sql.DataFrame
    growth_rate_df : pyspark.sql.DataFrame
    model_year_column : str

    Returns
    -------
    tuple
        (initial values restricted to the base year and cross-joined with the distinct
        model years of growth_rate_df, base year as an int)

    Raises
    ------
    DSGInvalidQuery
        If the initial values have no model years, or base_year is not among them.
    """
    iv_years = get_unique_values(initial_value_df, model_year_column)
    # Model years may be stored as strings; normalize to ints for comparison.
    iv_years_sorted = sorted(int(x) for x in iv_years)
    if not iv_years_sorted:
        msg = f"The initial value dataset has no values in column {model_year_column}"
        raise DSGInvalidQuery(msg)

    if dataset.base_year is None:
        base_year = iv_years_sorted[0]
    elif int(dataset.base_year) in iv_years_sorted:
        # Compare as ints: the raw unique values may be strings, which would never
        # match an integer base_year.
        base_year = int(dataset.base_year)
    else:
        msg = f"ProjectionDatasetModel base_year={dataset.base_year} is not in {iv_years_sorted}"
        raise DSGInvalidQuery(msg)

    if len(iv_years) > 1:
        # TODO #198: needs test case
        initial_value_df = initial_value_df.filter(f"{model_year_column} == '{base_year}'")

    initial_value_df = cross_join(
        initial_value_df.drop(model_year_column),
        growth_rate_df.select(model_year_column).distinct(),
    )
    return initial_value_df, base_year
|