dsgrid-toolkit 0.3.3__cp313-cp313-win_amd64.whl
This diff shows the content of publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the package contents as they appear in that registry.
- build_backend.py +93 -0
- dsgrid/__init__.py +22 -0
- dsgrid/api/__init__.py +0 -0
- dsgrid/api/api_manager.py +179 -0
- dsgrid/api/app.py +419 -0
- dsgrid/api/models.py +60 -0
- dsgrid/api/response_models.py +116 -0
- dsgrid/apps/__init__.py +0 -0
- dsgrid/apps/project_viewer/app.py +216 -0
- dsgrid/apps/registration_gui.py +444 -0
- dsgrid/chronify.py +32 -0
- dsgrid/cli/__init__.py +0 -0
- dsgrid/cli/common.py +120 -0
- dsgrid/cli/config.py +176 -0
- dsgrid/cli/download.py +13 -0
- dsgrid/cli/dsgrid.py +157 -0
- dsgrid/cli/dsgrid_admin.py +92 -0
- dsgrid/cli/install_notebooks.py +62 -0
- dsgrid/cli/query.py +729 -0
- dsgrid/cli/registry.py +1862 -0
- dsgrid/cloud/__init__.py +0 -0
- dsgrid/cloud/cloud_storage_interface.py +140 -0
- dsgrid/cloud/factory.py +31 -0
- dsgrid/cloud/fake_storage_interface.py +37 -0
- dsgrid/cloud/s3_storage_interface.py +156 -0
- dsgrid/common.py +36 -0
- dsgrid/config/__init__.py +0 -0
- dsgrid/config/annual_time_dimension_config.py +194 -0
- dsgrid/config/common.py +142 -0
- dsgrid/config/config_base.py +148 -0
- dsgrid/config/dataset_config.py +907 -0
- dsgrid/config/dataset_schema_handler_factory.py +46 -0
- dsgrid/config/date_time_dimension_config.py +136 -0
- dsgrid/config/dimension_config.py +54 -0
- dsgrid/config/dimension_config_factory.py +65 -0
- dsgrid/config/dimension_mapping_base.py +350 -0
- dsgrid/config/dimension_mappings_config.py +48 -0
- dsgrid/config/dimensions.py +1025 -0
- dsgrid/config/dimensions_config.py +71 -0
- dsgrid/config/file_schema.py +190 -0
- dsgrid/config/index_time_dimension_config.py +80 -0
- dsgrid/config/input_dataset_requirements.py +31 -0
- dsgrid/config/mapping_tables.py +209 -0
- dsgrid/config/noop_time_dimension_config.py +42 -0
- dsgrid/config/project_config.py +1462 -0
- dsgrid/config/registration_models.py +188 -0
- dsgrid/config/representative_period_time_dimension_config.py +194 -0
- dsgrid/config/simple_models.py +49 -0
- dsgrid/config/supplemental_dimension.py +29 -0
- dsgrid/config/time_dimension_base_config.py +192 -0
- dsgrid/data_models.py +155 -0
- dsgrid/dataset/__init__.py +0 -0
- dsgrid/dataset/dataset.py +123 -0
- dsgrid/dataset/dataset_expression_handler.py +86 -0
- dsgrid/dataset/dataset_mapping_manager.py +121 -0
- dsgrid/dataset/dataset_schema_handler_base.py +945 -0
- dsgrid/dataset/dataset_schema_handler_one_table.py +209 -0
- dsgrid/dataset/dataset_schema_handler_two_table.py +322 -0
- dsgrid/dataset/growth_rates.py +162 -0
- dsgrid/dataset/models.py +51 -0
- dsgrid/dataset/table_format_handler_base.py +257 -0
- dsgrid/dataset/table_format_handler_factory.py +17 -0
- dsgrid/dataset/unpivoted_table.py +121 -0
- dsgrid/dimension/__init__.py +0 -0
- dsgrid/dimension/base_models.py +230 -0
- dsgrid/dimension/dimension_filters.py +308 -0
- dsgrid/dimension/standard.py +252 -0
- dsgrid/dimension/time.py +352 -0
- dsgrid/dimension/time_utils.py +103 -0
- dsgrid/dsgrid_rc.py +88 -0
- dsgrid/exceptions.py +105 -0
- dsgrid/filesystem/__init__.py +0 -0
- dsgrid/filesystem/cloud_filesystem.py +32 -0
- dsgrid/filesystem/factory.py +32 -0
- dsgrid/filesystem/filesystem_interface.py +136 -0
- dsgrid/filesystem/local_filesystem.py +74 -0
- dsgrid/filesystem/s3_filesystem.py +118 -0
- dsgrid/loggers.py +132 -0
- dsgrid/minimal_patterns.cp313-win_amd64.pyd +0 -0
- dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +949 -0
- dsgrid/notebooks/registration.ipynb +48 -0
- dsgrid/notebooks/start_notebook.sh +11 -0
- dsgrid/project.py +451 -0
- dsgrid/query/__init__.py +0 -0
- dsgrid/query/dataset_mapping_plan.py +142 -0
- dsgrid/query/derived_dataset.py +388 -0
- dsgrid/query/models.py +728 -0
- dsgrid/query/query_context.py +287 -0
- dsgrid/query/query_submitter.py +994 -0
- dsgrid/query/report_factory.py +19 -0
- dsgrid/query/report_peak_load.py +70 -0
- dsgrid/query/reports_base.py +20 -0
- dsgrid/registry/__init__.py +0 -0
- dsgrid/registry/bulk_register.py +165 -0
- dsgrid/registry/common.py +287 -0
- dsgrid/registry/config_update_checker_base.py +63 -0
- dsgrid/registry/data_store_factory.py +34 -0
- dsgrid/registry/data_store_interface.py +74 -0
- dsgrid/registry/dataset_config_generator.py +158 -0
- dsgrid/registry/dataset_registry_manager.py +950 -0
- dsgrid/registry/dataset_update_checker.py +16 -0
- dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
- dsgrid/registry/dimension_mapping_update_checker.py +16 -0
- dsgrid/registry/dimension_registry_manager.py +413 -0
- dsgrid/registry/dimension_update_checker.py +16 -0
- dsgrid/registry/duckdb_data_store.py +207 -0
- dsgrid/registry/filesystem_data_store.py +150 -0
- dsgrid/registry/filter_registry_manager.py +123 -0
- dsgrid/registry/project_config_generator.py +57 -0
- dsgrid/registry/project_registry_manager.py +1623 -0
- dsgrid/registry/project_update_checker.py +48 -0
- dsgrid/registry/registration_context.py +223 -0
- dsgrid/registry/registry_auto_updater.py +316 -0
- dsgrid/registry/registry_database.py +667 -0
- dsgrid/registry/registry_interface.py +446 -0
- dsgrid/registry/registry_manager.py +558 -0
- dsgrid/registry/registry_manager_base.py +367 -0
- dsgrid/registry/versioning.py +92 -0
- dsgrid/rust_ext/__init__.py +14 -0
- dsgrid/rust_ext/find_minimal_patterns.py +129 -0
- dsgrid/spark/__init__.py +0 -0
- dsgrid/spark/functions.py +589 -0
- dsgrid/spark/types.py +110 -0
- dsgrid/tests/__init__.py +0 -0
- dsgrid/tests/common.py +140 -0
- dsgrid/tests/make_us_data_registry.py +265 -0
- dsgrid/tests/register_derived_datasets.py +103 -0
- dsgrid/tests/utils.py +25 -0
- dsgrid/time/__init__.py +0 -0
- dsgrid/time/time_conversions.py +80 -0
- dsgrid/time/types.py +67 -0
- dsgrid/units/__init__.py +0 -0
- dsgrid/units/constants.py +113 -0
- dsgrid/units/convert.py +71 -0
- dsgrid/units/energy.py +145 -0
- dsgrid/units/power.py +87 -0
- dsgrid/utils/__init__.py +0 -0
- dsgrid/utils/dataset.py +830 -0
- dsgrid/utils/files.py +179 -0
- dsgrid/utils/filters.py +125 -0
- dsgrid/utils/id_remappings.py +100 -0
- dsgrid/utils/py_expression_eval/LICENSE +19 -0
- dsgrid/utils/py_expression_eval/README.md +8 -0
- dsgrid/utils/py_expression_eval/__init__.py +847 -0
- dsgrid/utils/py_expression_eval/tests.py +283 -0
- dsgrid/utils/run_command.py +70 -0
- dsgrid/utils/scratch_dir_context.py +65 -0
- dsgrid/utils/spark.py +918 -0
- dsgrid/utils/spark_partition.py +98 -0
- dsgrid/utils/timing.py +239 -0
- dsgrid/utils/utilities.py +221 -0
- dsgrid/utils/versioning.py +36 -0
- dsgrid_toolkit-0.3.3.dist-info/METADATA +193 -0
- dsgrid_toolkit-0.3.3.dist-info/RECORD +157 -0
- dsgrid_toolkit-0.3.3.dist-info/WHEEL +4 -0
- dsgrid_toolkit-0.3.3.dist-info/entry_points.txt +4 -0
- dsgrid_toolkit-0.3.3.dist-info/licenses/LICENSE +29 -0
dsgrid/query/query_submitter.py
@@ -0,0 +1,994 @@
+import abc
+import json
+import logging
+import shutil
+from pathlib import Path
+import copy
+from zipfile import ZipFile
+
+from chronify.utils.path_utils import check_overwrite
+from semver import VersionInfo
+from sqlalchemy import Connection
+
+import dsgrid
+from dsgrid.common import VALUE_COLUMN, BackendEngine
+from dsgrid.config.dataset_config import DatasetConfig
+from dsgrid.config.dimension_config import DimensionBaseConfig
+from dsgrid.config.project_config import DatasetBaseDimensionNamesModel
+from dsgrid.config.dimension_mapping_base import DimensionMappingReferenceModel
+from dsgrid.config.time_dimension_base_config import TimeDimensionBaseConfig
+from dsgrid.dataset.dataset_expression_handler import (
+    DatasetExpressionHandler,
+    evaluate_expression,
+)
+from dsgrid.utils.scratch_dir_context import ScratchDirContext
+from dsgrid.dataset.models import ValueFormat, PivotedTableFormatModel
+from dsgrid.dataset.dataset_schema_handler_base import DatasetSchemaHandlerBase
+from dsgrid.dataset.table_format_handler_factory import make_table_format_handler
+from dsgrid.dimension.base_models import DimensionCategory, DimensionType
+from dsgrid.dimension.dimension_filters import SubsetDimensionFilterModel
+from dsgrid.exceptions import DSGInvalidDataset, DSGInvalidParameter, DSGInvalidQuery
+from dsgrid.dsgrid_rc import DsgridRuntimeConfig
+from dsgrid.query.dataset_mapping_plan import MapOperationCheckpoint
+from dsgrid.query.query_context import QueryContext
+from dsgrid.query.report_factory import make_report
+from dsgrid.registry.registry_manager import RegistryManager
+from dsgrid.spark.functions import pivot
+from dsgrid.spark.types import DataFrame
+from dsgrid.project import Project
+from dsgrid.utils.spark import (
+    custom_time_zone,
+    read_dataframe,
+    try_read_dataframe,
+    write_dataframe,
+    write_dataframe_and_auto_partition,
+    persist_table,
+)
+from dsgrid.utils.timing import timer_stats_collector, track_timing
+from dsgrid.utils.files import delete_if_exists, compute_hash, load_data
+from dsgrid.query.models import (
+    DatasetQueryModel,
+    ProjectQueryModel,
+    ColumnType,
+    CreateCompositeDatasetQueryModel,
+    CompositeDatasetQueryModel,
+    DatasetMetadataModel,
+    ProjectionDatasetModel,
+    StandaloneDatasetModel,
+)
+from dsgrid.utils.dataset import (
+    add_time_zone,
+    convert_time_zone_with_chronify_spark_hive,
+    convert_time_zone_with_chronify_spark_path,
+    convert_time_zone_with_chronify_duckdb,
+    convert_time_zone_by_column_with_chronify_spark_hive,
+    convert_time_zone_by_column_with_chronify_spark_path,
+    convert_time_zone_by_column_with_chronify_duckdb,
+)
+from dsgrid.config.dataset_schema_handler_factory import make_dataset_schema_handler
+from dsgrid.config.date_time_dimension_config import DateTimeDimensionConfig
+from dsgrid.exceptions import DSGInvalidOperation
+
+logger = logging.getLogger(__name__)
+
+
+class QuerySubmitterBase:
+    """Handles query submission"""
+
+    def __init__(self, output_dir: Path):
+        self._output_dir = output_dir
+        self._cached_tables_dir().mkdir(exist_ok=True, parents=True)
+        self._composite_datasets_dir().mkdir(exist_ok=True, parents=True)
+
+        # TODO #186: This location will need more consideration.
+        # We might want to store cached datasets in the spark-warehouse and let Spark manage it
+        # for us. However, would we share them on the HPC? What happens on HPC walltime timeouts
+        # where the tables are left in intermediate states?
+        # This is even more of a problem on AWS.
+        self._cached_project_mapped_datasets_dir().mkdir(exist_ok=True, parents=True)
+
+    @abc.abstractmethod
+    def submit(self, *args, **kwargs) -> DataFrame:
+        """Submit a query for execution"""
+
+    def _composite_datasets_dir(self):
+        return self._output_dir / "composite_datasets"
+
+    def _cached_tables_dir(self):
+        """Directory for intermediate tables made up of multiple project-mapped datasets."""
+        return self._output_dir / "cached_tables"
+
+    def _cached_project_mapped_datasets_dir(self):
+        """Directory for intermediate project-mapped datasets.
+        Data could be filtered.
+        """
+        return self._output_dir / "cached_project_mapped_datasets"
+
+    @staticmethod
+    def metadata_filename(path: Path):
+        return path / "metadata.json"
+
+    @staticmethod
+    def query_filename(path: Path):
+        return path / "query.json5"
+
+    @staticmethod
+    def table_filename(path: Path):
+        return path / "table.parquet"
+
+    @staticmethod
+    def _cached_table_filename(path: Path):
+        return path / "table.parquet"
+
+
+class ProjectBasedQuerySubmitter(QuerySubmitterBase):
+    def __init__(self, project: Project, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._project = project
+
+    @property
+    def project(self):
+        return self._project
+
+    def _create_table_hash(self, context: QueryContext) -> tuple[str, str]:
+        """Create a hash that can be used to identify whether the following sequence
+        can be skipped based on a previous query:
+        - Apply expression across all datasets in the query.
+        - Apply filters.
+        - Apply aggregations.
+
+        Examples of changes that will invalidate the query:
+        - Change to the project section of the query
+        - Bump to project major version number
+        - Change to a dataset version
+        - Change to a project's dimension requirements for a dataset
+        - Change to a dataset dimension mapping
+        """
+        assert isinstance(context.model, ProjectQueryModel) or isinstance(
+            context.model, CreateCompositeDatasetQueryModel
+        )
+        data = {
+            "project_major_version": VersionInfo.parse(self._project.config.model.version).major,
+            "project_query": context.model.serialize_cached_content(),
+            "datasets": [
+                self._project.config.get_dataset(x.dataset_id).model_dump(mode="json")
+                for x in context.model.project.dataset.source_datasets
+            ],
+        }
+        text = json.dumps(data, indent=2)
+        hash_value = compute_hash(text.encode())
+        return text, hash_value
+
+    def _try_read_cache(self, context: QueryContext):
+        _, hash_value = self._create_table_hash(context)
+        cached_dir = self._cached_tables_dir() / hash_value
+        filename = self._cached_table_filename(cached_dir)
+        df = try_read_dataframe(filename)
+        if df is not None:
+            logger.info("Load intermediate table from cache: %s", filename)
+            metadata_file = self.metadata_filename(cached_dir)
+            return df, DatasetMetadataModel.from_file(metadata_file)
+        return None, None
+
+    def _run_checks(self, model: ProjectQueryModel) -> DatasetBaseDimensionNamesModel:
+        subsets = set(self.project.config.list_dimension_names(DimensionCategory.SUBSET))
+        for agg in model.result.aggregations:
+            for _, column in agg.iter_dimensions_to_keep():
+                dimension_name = column.dimension_name
+                if dimension_name in subsets:
+                    subset_dim = self._project.config.get_dimension(dimension_name)
+                    dim_type = subset_dim.model.dimension_type
+                    supp_names = " ".join(
+                        self._project.config.get_supplemental_dimension_to_name_mapping()[dim_type]
+                    )
+                    base_names = [
+                        x.model.name
+                        for x in self._project.config.list_base_dimensions(dimension_type=dim_type)
+                    ]
+                    msg = (
+                        f"Subset dimensions cannot be used in aggregations: "
+                        f"{dimension_name=}. Only base and supplemental dimensions are "
+                        f"allowed. base={base_names} supplemental={supp_names}"
+                    )
+                    raise DSGInvalidQuery(msg)
+
+        for report_inputs in model.result.reports:
+            report = make_report(report_inputs.report_type)
+            report.check_query(model)
+
+        with self._project.dimension_mapping_manager.db.engine.connect() as conn:
+            return self._check_datasets(model, conn)
+
+    def _check_datasets(
+        self, query_model: ProjectQueryModel, conn: Connection
+    ) -> DatasetBaseDimensionNamesModel:
+        base_dimension_names: DatasetBaseDimensionNamesModel | None = None
+        dataset_ids: list[str] = []
+        query_names: list[DatasetBaseDimensionNamesModel] = []
+        for dataset in query_model.project.dataset.source_datasets:
+            src_dataset_ids = dataset.list_source_dataset_ids()
+            dataset_ids += src_dataset_ids
+            if isinstance(dataset, StandaloneDatasetModel):
+                query_names.append(
+                    self._project.config.get_dataset_base_dimension_names(dataset.dataset_id)
+                )
+            elif isinstance(dataset, ProjectionDatasetModel):
+                query_names += [
+                    self._project.config.get_dataset_base_dimension_names(
+                        dataset.initial_value_dataset_id
+                    ),
+                    self._project.config.get_dataset_base_dimension_names(
+                        dataset.growth_rate_dataset_id
+                    ),
+                ]
+            else:
+                msg = f"Unhandled dataset type: {dataset=}"
+                raise NotImplementedError(msg)
+
+            for dataset_id in src_dataset_ids:
+                dataset = self._project.load_dataset(dataset_id, conn=conn)
+                plan = query_model.project.get_dataset_mapping_plan(dataset_id)
+                if plan is None:
+                    plan = dataset.handler.build_default_dataset_mapping_plan()
+                    query_model.project.set_dataset_mapper(plan)
+                else:
+                    dataset.handler.check_dataset_mapping_plan(plan, self._project.config)
+
+        for dataset_id, names in zip(dataset_ids, query_names):
+            self._fix_legacy_base_dimension_names(names, dataset_id)
+            if base_dimension_names is None:
+                base_dimension_names = names
+            elif base_dimension_names != names:
+                msg = (
+                    "Datasets in a query must have the same base dimension query names: "
+                    f"{dataset=} {base_dimension_names} {names}"
+                )
+                raise DSGInvalidQuery(msg)
+
+        assert base_dimension_names is not None
+        return base_dimension_names
+
+    def _fix_legacy_base_dimension_names(
+        self, names: DatasetBaseDimensionNamesModel, dataset_id: str
+    ) -> None:
+        for dim_type in DimensionType:
+            val = getattr(names, dim_type.value)
+            if val is None:
+                # This is a workaround for dsgrid projects created before the field
+                # base_dimension_names was added to InputDatasetModel.
+                dims = self._project.config.list_base_dimensions(dimension_type=dim_type)
+                if len(dims) > 1:
+                    msg = (
+                        "The dataset's base_dimension_names value is not set and "
+                        f"there are multiple base dimensions of type {dim_type} in the project. "
+                        f"Please re-register the dataset with {dataset_id=}."
+                    )
+                    raise DSGInvalidDataset(msg)
+                setattr(names, dim_type.value, dims[0].model.name)
+
+    def _run_query(
+        self,
+        scratch_dir_context: ScratchDirContext,
+        model: ProjectQueryModel,
+        load_cached_table: bool,
+        checkpoint_file: Path | None,
+        persist_intermediate_table: bool,
+        zip_file: bool = False,
+        overwrite: bool = False,
+    ):
+        base_dimension_names = self._run_checks(model)
+        checkpoint = self._check_checkpoint_file(checkpoint_file, model)
+        context = QueryContext(
+            model,
+            base_dimension_names,
+            scratch_dir_context=scratch_dir_context,
+            checkpoint=checkpoint,
+        )
+        assert isinstance(context.model, ProjectQueryModel) or isinstance(
+            context.model, CreateCompositeDatasetQueryModel
+        )
+        context.model.project.version = str(self._project.version)
+        output_dir = self._output_dir / context.model.name
+        if output_dir.exists() and not overwrite:
+            msg = (
+                f"output directory {self._output_dir} and query name={context.model.name} will "
+                "overwrite an existing query results directory. "
+                "Choose a different path or pass force=True."
+            )
+            raise DSGInvalidParameter(msg)
+
+        df = None
+        if load_cached_table:
+            df, metadata = self._try_read_cache(context)
+        if df is None:
+            df_filenames = self._project.process_query(
+                context, self._cached_project_mapped_datasets_dir()
+            )
+            df = self._postprocess_datasets(context, scratch_dir_context, df_filenames)
+            is_cached = False
+        else:
+            context.metadata = metadata
+            is_cached = True
+
+        if context.model.result.aggregate_each_dataset:
+            # This wouldn't save any time.
+            persist_intermediate_table = False
+
+        if persist_intermediate_table and not is_cached:
+            df = self._persist_intermediate_result(context, df)
+
+        if not context.model.result.aggregate_each_dataset:
+            if context.model.result.dimension_filters:
+                df = self._apply_filters(df, context)
+            df = self._process_aggregations(df, context)
+
+        repartition = not persist_intermediate_table
+        table_filename = self._save_query_results(context, df, repartition, zip_file=zip_file)
+
+        for report_inputs in context.model.result.reports:
+            report = make_report(report_inputs.report_type)
+            output_dir = self._output_dir / context.model.name
+            report.generate(table_filename, output_dir, context, report_inputs.inputs)
+
+        return df, context
+
+    def _convert_time_zone(
+        self,
+        scratch_dir_context: ScratchDirContext,
+        model: ProjectQueryModel,
+        df,
+        context,
+        persist_intermediate_table: bool,
+        zip_file: bool = False,
+    ):
+        time_dim = copy.deepcopy(self._project.config.get_base_time_dimension())
+        if not isinstance(time_dim, DateTimeDimensionConfig):
+            msg = f"Only DateTimeDimensionConfig allowed for time zone conversion. {time_dim.__class__.__name__}"
+            raise DSGInvalidOperation(msg)
+        time_cols = list(context.get_dimension_column_names(DimensionType.TIME))
+        assert len(time_cols) == 1
+        time_col = next(iter(time_cols))
+        time_dim.model.time_column = time_col
+
+        config = dsgrid.runtime_config
+        if isinstance(model.result.time_zone, str) and model.result.time_zone != "geography":
+            if time_dim.supports_chronify():
+                match (config.backend_engine, config.use_hive_metastore):
+                    case (BackendEngine.SPARK, True):
+                        df = convert_time_zone_with_chronify_spark_hive(
+                            df=df,
+                            value_column=VALUE_COLUMN,
+                            from_time_dim=time_dim,
+                            time_zone=model.result.time_zone,
+                            scratch_dir_context=scratch_dir_context,
+                        )
+
+                    case (BackendEngine.SPARK, False):
+                        filename = persist_table(
+                            df,
+                            scratch_dir_context,
+                            tag="project query before time mapping",
+                        )
+                        df = convert_time_zone_with_chronify_spark_path(
+                            df=df,
+                            filename=filename,
+                            value_column=VALUE_COLUMN,
+                            from_time_dim=time_dim,
+                            time_zone=model.result.time_zone,
+                            scratch_dir_context=scratch_dir_context,
+                        )
+                    case (BackendEngine.DUCKDB, _):
+                        df = convert_time_zone_with_chronify_duckdb(
+                            df=df,
+                            value_column=VALUE_COLUMN,
+                            from_time_dim=time_dim,
+                            time_zone=model.result.time_zone,
+                            scratch_dir_context=scratch_dir_context,
+                        )
+
+            else:
+                msg = "time_dim must support Chronify"
+                raise DSGInvalidParameter(msg)
+
+        elif model.result.time_zone == "geography":
+            if "time_zone" not in df.columns:
+                geo_cols = list(context.get_dimension_column_names(DimensionType.GEOGRAPHY))
+                assert len(geo_cols) == 1
+                geo_col = next(iter(geo_cols))
+                geo_dim = self._project.config.get_base_dimension(DimensionType.GEOGRAPHY)
+                if model.result.replace_ids_with_names:
+                    dim_key = "name"
+                else:
+                    dim_key = "id"
+                df = add_time_zone(df, geo_dim, df_key=geo_col, dim_key=dim_key)
+
+            if time_dim.supports_chronify():
+                # use chronify
+                match (config.backend_engine, config.use_hive_metastore):
+                    case (BackendEngine.SPARK, True):
+                        df = convert_time_zone_by_column_with_chronify_spark_hive(
+                            df=df,
+                            value_column=VALUE_COLUMN,
+                            from_time_dim=time_dim,
+                            time_zone_column="time_zone",
+                            scratch_dir_context=scratch_dir_context,
+                            wrap_time_allowed=False,
+                        )
+                    case (BackendEngine.SPARK, False):
+                        filename = persist_table(
+                            df,
+                            scratch_dir_context,
+                            tag="project query before time mapping",
+                        )
+                        df = convert_time_zone_by_column_with_chronify_spark_path(
+                            df=df,
+                            filename=filename,
+                            value_column=VALUE_COLUMN,
+                            from_time_dim=time_dim,
+                            time_zone_column="time_zone",
+                            scratch_dir_context=scratch_dir_context,
+                            wrap_time_allowed=False,
+                        )
+                    case (BackendEngine.DUCKDB, _):
+                        df = convert_time_zone_by_column_with_chronify_duckdb(
+                            df=df,
+                            value_column=VALUE_COLUMN,
+                            from_time_dim=time_dim,
+                            time_zone_column="time_zone",
+                            scratch_dir_context=scratch_dir_context,
+                            wrap_time_allowed=False,
+                        )
+
+            else:
+                msg = "time_dim must support Chronify"
+                raise DSGInvalidParameter(msg)
+        else:
+            msg = f"Unknown input {model.result.time_zone=}"
+            raise DSGInvalidParameter(msg)
+
+        repartition = not persist_intermediate_table
+        table_filename = self._save_query_results(context, df, repartition, zip_file=zip_file)
+
+        for report_inputs in context.model.result.reports:
+            report = make_report(report_inputs.report_type)
+            output_dir = self._output_dir / context.model.name
+            report.generate(table_filename, output_dir, context, report_inputs.inputs)
+
+        return df, context
+
+    def _check_checkpoint_file(
+        self, checkpoint_file: Path | None, model: ProjectQueryModel
+    ) -> MapOperationCheckpoint | None:
+        if checkpoint_file is None:
+            return None
+
+        checkpoint = MapOperationCheckpoint.from_file(checkpoint_file)
+        confirmed_checkpoint = False
+        for dataset in model.project.dataset.source_datasets:
+            if dataset.get_dataset_id() == checkpoint.dataset_id:
+                for plan in model.project.mapping_plans:
+                    if plan.dataset_id == checkpoint.dataset_id:
+                        if plan.compute_hash() == checkpoint.mapping_plan_hash:
+                            confirmed_checkpoint = True
+                        else:
+                            msg = (
+                                f"The hash of the mapping plan for dataset {checkpoint.dataset_id} "
+                                "does not match the checkpoint file. Cannot use the checkpoint."
+                            )
+                            raise DSGInvalidParameter(msg)
+        if not confirmed_checkpoint:
+            msg = f"Checkpoint {checkpoint_file} does not match any dataset in the query."
+            raise DSGInvalidParameter(msg)
+
+        return checkpoint
+
+    @track_timing(timer_stats_collector)
+    def _persist_intermediate_result(self, context: QueryContext, df):
+        text, hash_value = self._create_table_hash(context)
+        cached_dir = self._cached_tables_dir() / hash_value
+        if cached_dir.exists():
+            shutil.rmtree(cached_dir)
+        cached_dir.mkdir()
+        filename = self._cached_table_filename(cached_dir)
+        df = write_dataframe_and_auto_partition(df, filename)
+
+        self.metadata_filename(cached_dir).write_text(
+            context.metadata.model_dump_json(indent=2), encoding="utf-8"
+        )
+        self.query_filename(cached_dir).write_text(text, encoding="utf-8")
+        logger.debug("Persisted intermediate table to %s", filename)
+        return df
+
+    def _postprocess_datasets(
+        self,
+        context: QueryContext,
+        scratch_dir_context: ScratchDirContext,
+        df_filenames: dict[str, Path],
+    ) -> DataFrame:
+        if context.model.result.aggregate_each_dataset:
+            for dataset_id, path in df_filenames.items():
+                df = read_dataframe(path)
+                if context.model.result.dimension_filters:
+                    df = self._apply_filters(df, context)
+                df = self._process_aggregations(df, context, dataset_id=dataset_id)
+                path = scratch_dir_context.get_temp_filename(suffix=".parquet")
+                write_dataframe(df, path)
+                df_filenames[dataset_id] = path
+
+        # All dataset columns need to be in the same order.
+        context.consolidate_dataset_metadata()
+        datasets = self._convert_datasets(context, df_filenames)
+        assert isinstance(context.model, ProjectQueryModel) or isinstance(
+            context.model, CreateCompositeDatasetQueryModel
+        )
+        assert context.model.project.dataset.expression is not None
+        return evaluate_expression(context.model.project.dataset.expression, datasets).df
+
+    def _convert_datasets(self, context: QueryContext, filenames: dict[str, Path]):
+        dim_columns, time_columns = self._get_dimension_columns(context)
+        expected_columns = time_columns + dim_columns
+        expected_columns.append(VALUE_COLUMN)
+
+        datasets = {}
+        for dataset_id, path in filenames.items():
+            df = read_dataframe(path)
+            unexpected = sorted(set(df.columns).difference(expected_columns))
+            if unexpected:
+                msg = f"Unexpected columns are present in {dataset_id=} {unexpected=}"
+                raise Exception(msg)
+            datasets[dataset_id] = DatasetExpressionHandler(
+                df.select(*expected_columns), time_columns + dim_columns, [VALUE_COLUMN]
+            )
+        return datasets
+
+    def _get_dimension_columns(self, context: QueryContext) -> tuple[list[str], list[str]]:
+        match context.model.result.column_type:
+            case ColumnType.DIMENSION_NAMES:
+                dim_columns = context.get_all_dimension_column_names(exclude={DimensionType.TIME})
+                time_columns = context.get_dimension_column_names(DimensionType.TIME)
+            case ColumnType.DIMENSION_TYPES:
+                dim_columns = {x.value for x in DimensionType if x != DimensionType.TIME}
+                time_columns = context.get_dimension_column_names(DimensionType.TIME)
+            case _:
+                msg = f"BUG: unhandled {context.model.result.column_type=}"
+                raise NotImplementedError(msg)
+
+        return sorted(dim_columns), sorted(time_columns)
+
+    def _process_aggregations(
+        self, df: DataFrame, context: QueryContext, dataset_id: str | None = None
+    ) -> DataFrame:
+        handler = make_table_format_handler(
+            ValueFormat.STACKED, self._project.config, dataset_id=dataset_id
+        )
+        df = handler.process_aggregations(df, context.model.result.aggregations, context)
+
+        if context.model.result.replace_ids_with_names:
+            df = handler.replace_ids_with_names(df)
+
+        if context.model.result.sort_columns:
+            df = df.sort(*context.model.result.sort_columns)
+
+        if isinstance(context.model.result.table_format, PivotedTableFormatModel):
+            df = _pivot_table(df, context)
+
+        return df
+
+    def _process_aggregations_and_save(
+        self,
+        df: DataFrame,
+        context: QueryContext,
+        repartition: bool,
+        zip_file: bool = False,
+    ) -> DataFrame:
+        df = self._process_aggregations(df, context)
+
+        self._save_query_results(context, df, repartition, zip_file=zip_file)
+        return df
+
+    def _apply_filters(self, df, context: QueryContext):
+        for dim_filter in context.model.result.dimension_filters:
+            column_names = context.get_dimension_column_names(dim_filter.dimension_type)
+            if len(column_names) > 1:
+                msg = f"Cannot filter {dim_filter} when there are multiple {column_names=}"
+                raise NotImplementedError(msg)
+            if isinstance(dim_filter, SubsetDimensionFilterModel):
+                records = dim_filter.get_filtered_records_dataframe(
+                    self.project.config.get_dimension
+                )
+                column = next(iter(column_names))
+                df = df.join(
+                    records.select("id"),
+                    on=getattr(df, column) == getattr(records, "id"),
+                ).drop("id")
+            else:
+                query_name = dim_filter.dimension_name
+                if query_name not in df.columns:
+                    # Consider catching this exception and still write to a file.
+                    # It could mean writing a lot of data the user doesn't want.
+                    msg = f"filter column {query_name} is not in the dataframe: {df.columns}"
+                    raise DSGInvalidParameter(msg)
+                df = dim_filter.apply_filter(df, column=query_name)
+        return df
+
+    @track_timing(timer_stats_collector)
+    def _save_query_results(
+        self,
+        context: QueryContext,
+        df,
+        repartition,
+        aggregation_name=None,
+        zip_file=False,
+    ):
+        output_dir = self._output_dir / context.model.name
+        output_dir.mkdir(exist_ok=True)
+        if aggregation_name is not None:
+            output_dir /= aggregation_name
+            output_dir.mkdir(exist_ok=True)
+        filename = output_dir / f"table.{context.model.result.output_format}"
+        self._save_result(context, df, filename, repartition)
+        if zip_file:
+            zip_name = Path(str(output_dir) + ".zip")
+            with ZipFile(zip_name, "w") as zipf:
+                for path in output_dir.rglob("*"):
+                    zipf.write(path)
+        return filename
+
+    def _save_result(self, context: QueryContext, df, filename, repartition):
+        output_dir = filename.parent
+        suffix = filename.suffix
+        if suffix == ".csv":
+            df.toPandas().to_csv(filename, header=True, index=False)
+        elif suffix == ".parquet":
+            if repartition:
+                df = write_dataframe_and_auto_partition(df, filename)
+            else:
+                delete_if_exists(filename)
+                write_dataframe(df, filename, overwrite=True)
+        else:
+            msg = f"Unsupported output_format={suffix}"
+            raise NotImplementedError(msg)
+        self.query_filename(output_dir).write_text(context.model.serialize_with_hash()[1])
+        self.metadata_filename(output_dir).write_text(context.metadata.model_dump_json(indent=2))
+        logger.info("Wrote query=%s output table to %s", context.model.name, filename)
+
+
+class ProjectQuerySubmitter(ProjectBasedQuerySubmitter):
+    """Submits queries for a project."""
+
+    @track_timing(timer_stats_collector)
+    def submit(
+        self,
+        model: ProjectQueryModel,
+        scratch_dir: Path | None = None,
+        checkpoint_file: Path | None = None,
+        persist_intermediate_table: bool = True,
+        load_cached_table: bool = True,
+        zip_file: bool = False,
+        overwrite: bool = False,
+    ) -> DataFrame:
+        """Submits a project query to consolidate datasets and produce result tables.
+
+        Parameters
+        ----------
+        model : ProjectQueryResultModel
+        checkpoint_file : bool, optional
+            Optional checkpoint file from which to resume the operation.
+        persist_intermediate_table : bool, optional
+            Persist the intermediate consolidated table.
+        load_cached_table : bool, optional
+            Load a cached consolidated table if the query matches an existing query.
+        zip_file : bool, optional
+            Create a zip file with all output files.
+        overwrite : bool
+            If True, overwrite any existing output directory.
+
+        Returns
+        -------
+        pyspark.sql.DataFrame
+
+        Raises
+        ------
+        DSGInvalidParameter
+            Raised if the model defines a project version
+        DSGInvalidQuery
+            Raised if the query is invalid
+        """
+        tz = self._project.config.get_base_time_dimension().get_time_zone()
+        assert tz is not None, "Project base time dimension must have a time zone"
+
+        scratch_dir = scratch_dir or DsgridRuntimeConfig.load().get_scratch_dir()
+        with ScratchDirContext(scratch_dir) as scratch_dir_context:
+            # Ensure that queries that aggregate time reflect the project's time zone instead
+            # of the local computer.
+            # If any other settings get customized here, handle them in restart_spark()
+            # as well. This change won't persist Spark session restarts.
+            with custom_time_zone(tz):
+                df, context = self._run_query(
+                    scratch_dir_context,
+                    model,
+                    load_cached_table,
+                    checkpoint_file=checkpoint_file,
+                    persist_intermediate_table=persist_intermediate_table,
+                    zip_file=zip_file,
+                    overwrite=overwrite,
+                )
+                if model.result.time_zone:
+                    df, context = self._convert_time_zone(
+                        scratch_dir_context,
+                        model,
+                        df,
+                        context,
+                        persist_intermediate_table=persist_intermediate_table,
+                        zip_file=zip_file,
+                    )
+                context.finalize()
+
+        return df
+
+
+class CompositeDatasetQuerySubmitter(ProjectBasedQuerySubmitter):
+    """Submits queries for a composite dataset."""
+
+    @track_timing(timer_stats_collector)
+    def create_dataset(
+        self,
+        model: CreateCompositeDatasetQueryModel,
+        scratch_dir: Path | None = None,
+        persist_intermediate_table=False,
+        load_cached_table=True,
+        force=False,
+    ):
+        """Create a composite dataset from a project.
+
+        Parameters
+        ----------
+        model : CreateCompositeDatasetQueryModel
+        persist_intermediate_table : bool, optional
+            Persist the intermediate consolidated table.
+        load_cached_table : bool, optional
+            Load a cached consolidated table if the query matches an existing query.
+        force : bool
+            If True, overwrite any existing output directory.
+
+        """
+        tz = self._project.config.get_base_time_dimension().get_time_zone()
+        scratch_dir = scratch_dir or DsgridRuntimeConfig.load().get_scratch_dir()
+        with ScratchDirContext(scratch_dir) as scratch_dir_context:
+            # Ensure that queries that aggregate time reflect the project's time zone instead
+            # of the local computer.
+            # If any other settings get customized here, handle them in restart_spark()
+            # as well. This change won't persist Spark session restarts.
+            with custom_time_zone(tz):  # type: ignore
+                df, context = self._run_query(
+                    scratch_dir_context,
+                    model,
+                    load_cached_table,
+                    None,
+                    persist_intermediate_table,
+                    overwrite=force,
+                )
+                self._save_composite_dataset(context, df, not persist_intermediate_table)
+                context.finalize()
+
+    @track_timing(timer_stats_collector)
+    def submit(
+        self,
+        query: CompositeDatasetQueryModel,
+        scratch_dir: Path | None = None,
+    ) -> DataFrame:
+        """Submit a query to a composite dataset and produce result tables.
+
+        Parameters
+        ----------
+        query : CompositeDatasetQueryModel
+        scratch_dir : Path | None
+        """
+        tz = self._project.config.get_base_time_dimension().get_time_zone()
+        assert tz is not None
+        scratch_dir = DsgridRuntimeConfig.load().get_scratch_dir()
+        # orig_query = self._load_composite_dataset_query(query.dataset_id)
+        with ScratchDirContext(scratch_dir) as scratch_dir_context:
+            df, metadata = self._read_dataset(query.dataset_id)
+            base_dimension_names = DatasetBaseDimensionNamesModel()
+            for dim_type in DimensionType:
+                field = dim_type.value
+                query_names = getattr(metadata.dimensions, field)
+                if len(query_names) > 1:
+                    msg = (
+                        "Composite datasets must have a single query name for each dimension: "
+                        f"{dim_type} {query_names}"
+                    )
+                    raise DSGInvalidQuery(msg)
+                setattr(base_dimension_names, field, query_names[0].dimension_name)
+            context = QueryContext(query, base_dimension_names, scratch_dir_context)
+            context.metadata = metadata
+            # Refer to the comment in ProjectQuerySubmitter.submit for an explanation or if
+            # you add a new customization.
+            with custom_time_zone(tz):  # type: ignore
+                df = self._process_aggregations_and_save(df, context, repartition=False)
+            context.finalize()
+        return df
+
+    def _load_composite_dataset_query(self, dataset_id):
+        filename = self._composite_datasets_dir() / dataset_id / "query.json5"
+        return CreateCompositeDatasetQueryModel.from_file(filename)
+
+    def _read_dataset(self, dataset_id) -> tuple[DataFrame, DatasetMetadataModel]:
+        filename = self._composite_datasets_dir() / dataset_id / "table.parquet"
+        if not filename.exists():
+            msg = f"There is no composite dataset with dataset_id={dataset_id}"
+            raise DSGInvalidParameter(msg)
+        metadata_file = self.metadata_filename(self._composite_datasets_dir() / dataset_id)
+        return (
+            read_dataframe(filename),
+            DatasetMetadataModel(**load_data(metadata_file)),
+        )
+
+    @track_timing(timer_stats_collector)
+    def _save_composite_dataset(self, context: QueryContext, df, repartition):
+        output_dir = self._composite_datasets_dir() / context.model.dataset_id
+        output_dir.mkdir(exist_ok=True)
+        filename = output_dir / "table.parquet"
+        self._save_result(context, df, filename, repartition)
+        self.metadata_filename(output_dir).write_text(context.metadata.model_dump_json(indent=2))
+
+
+class DatasetQuerySubmitter(QuerySubmitterBase):
+    """Submits queries for a project."""
+
+    @track_timing(timer_stats_collector)
+    def submit(
+        self,
+        query: DatasetQueryModel,
+        mgr: RegistryManager,
+        scratch_dir: Path | None = None,
+        checkpoint_file: Path | None = None,
+        overwrite: bool = False,
+    ) -> DataFrame:
+        """Submits a dataset query to produce a result table."""
+        if not query.to_dimension_references:
+            msg = "A dataset query must specify at least one dimension to map."
+            raise DSGInvalidQuery(msg)
+
+        dataset_config = mgr.dataset_manager.get_by_id(query.dataset_id)
+        to_dimension_mapping_refs, dims = self._build_mappings(query, dataset_config, mgr)
+        handler = make_dataset_schema_handler(
+            conn=None,
+            config=dataset_config,
+            dimension_mgr=mgr.dimension_manager,
+            dimension_mapping_mgr=mgr.dimension_mapping_manager,
+            store=mgr.dataset_manager.store,
+            mapping_references=to_dimension_mapping_refs,
+        )
+
+        base_dim_names = DatasetBaseDimensionNamesModel()
+        scratch_dir = scratch_dir or DsgridRuntimeConfig.load().get_scratch_dir()
+        time_dim = dims.get(DimensionType.TIME) or dataset_config.get_time_dimension()
+        time_zone = None if time_dim is None else time_dim.get_time_zone()
+        checkpoint = self._check_checkpoint_file(checkpoint_file, query)
+        with ScratchDirContext(scratch_dir) as scratch_dir_context:
+            context = QueryContext(
+                query, base_dim_names, scratch_dir_context, checkpoint=checkpoint
+            )
+            output_dir = self._query_output_dir(context)
+            check_overwrite(output_dir, overwrite)
+            args = (context, handler)
+            kwargs = {"time_dimension": dims.get(DimensionType.TIME)}
+            if time_dim is not None and time_zone is not None:
+                with custom_time_zone(time_zone):
+                    df = self._run_query(*args, **kwargs)
+            else:
+                df = self._run_query(*args, **kwargs)
+            return df
+
+    def _build_mappings(
+        self, query: DatasetQueryModel, config: DatasetConfig, mgr: RegistryManager
+    ) -> tuple[list[DimensionMappingReferenceModel], dict[DimensionType, DimensionBaseConfig]]:
+        config = mgr.dataset_manager.get_by_id(query.dataset_id)
+        to_dimension_mapping_refs: list[DimensionMappingReferenceModel] = []
+        mapped_dimension_types: set[DimensionType] = set()
+        dimensions: dict[DimensionType, DimensionBaseConfig] = {}
+        with mgr.dimension_mapping_manager.db.engine.connect() as conn:
+            graph = mgr.dimension_mapping_manager.build_graph(conn=conn)
+            for to_dim_ref in query.to_dimension_references:
+                to_dim = mgr.dimension_manager.get_by_id(
+                    to_dim_ref.dimension_id, version=to_dim_ref.version
+                )
+                if to_dim.model.dimension_type in mapped_dimension_types:
+                    msg = f"A dataset query cannot map multiple dimensions of the same type: {to_dim.model.dimension_type}"
+                    raise DSGInvalidQuery(msg)
+                dataset_dim = config.get_dimension(to_dim.model.dimension_type)
+                assert dataset_dim is not None
+                if to_dim.model.dimension_id == dataset_dim.model.dimension_id:
+                    if to_dim.model.version != dataset_dim.model.version:
+                        msg = (
+                            f"A to_dimension_reference cannot point to a different version of a "
+                            f"dataset's dimension: dataset version = {dataset_dim.model.version}, "
+                            f"dimension version = {to_dim.model.version}"
+                        )
+                        raise DSGInvalidQuery(msg)
+                    # No mapping is required.
+                    continue
+                if to_dim.model.dimension_type != DimensionType.TIME:
+                    refs = mgr.dimension_mapping_manager.list_mappings_between_dimensions(
+                        graph,
+                        dataset_dim.model.dimension_id,
+                        to_dim.model.dimension_id,
+                    )
+                    to_dimension_mapping_refs += refs
+                mapped_dimension_types.add(to_dim.model.dimension_type)
+                dimensions[to_dim.model.dimension_type] = to_dim
+        return to_dimension_mapping_refs, dimensions
+
+    def _check_checkpoint_file(
+        self, checkpoint_file: Path | None, query: DatasetQueryModel
+    ) -> MapOperationCheckpoint | None:
+        if checkpoint_file is None:
+            return None
+
+        if query.mapping_plan is None:
+            msg = f"Query {query.name} does not have a mapping plan. A checkpoint file cannot be used."
+            raise DSGInvalidQuery(msg)
+
+        checkpoint = MapOperationCheckpoint.from_file(checkpoint_file)
+        if query.dataset_id != checkpoint.dataset_id:
+            msg = (
+                f"The dataset_id in the checkpoint file {checkpoint.dataset_id} does not match "
+                f"the query dataset_id {query.dataset_id}."
+            )
+            raise DSGInvalidQuery(msg)
+
+        if query.mapping_plan.compute_hash() != checkpoint.mapping_plan_hash:
+            msg = (
+                f"The hash of the mapping plan for dataset {checkpoint.dataset_id} "
+                "does not match the checkpoint file. Cannot use the checkpoint."
+            )
+            raise DSGInvalidParameter(msg)
+
+        return checkpoint
+
+    def _run_query(
+        self,
+        context: QueryContext,
+        handler: DatasetSchemaHandlerBase,
+        time_dimension: TimeDimensionBaseConfig | None,
+    ) -> DataFrame:
+        df = handler.make_mapped_dataframe(context, time_dimension=time_dimension)
+        df = self._postprocess(context, df)
+        self._save_results(context, df)
+        return df
+
+    def _postprocess(self, context: QueryContext, df: DataFrame) -> DataFrame:
+        if context.model.result.sort_columns:
+            df = df.sort(*context.model.result.sort_columns)
+
+        if isinstance(context.model.result.table_format, PivotedTableFormatModel):
+            df = _pivot_table(df, context)
+
+        return df
+
+    def _query_output_dir(self, context: QueryContext) -> Path:
+        return self._output_dir / context.model.name
+
+    @track_timing(timer_stats_collector)
+    def _save_results(self, context: QueryContext, df) -> Path:
+        output_dir = self._query_output_dir(context)
+        output_dir.mkdir(exist_ok=True)
+        filename = output_dir / f"table.{context.model.result.output_format}"
+        suffix = filename.suffix
+        if suffix == ".csv":
+            df.toPandas().to_csv(filename, header=True, index=False)
+        elif suffix == ".parquet":
+            df = write_dataframe_and_auto_partition(df, filename)
+        else:
+            msg = f"Unsupported output_format={suffix}"
+            raise NotImplementedError(msg)
+
+        logger.info("Wrote query=%s output table to %s", context.model.name, filename)
+        return filename
+
+
+def _pivot_table(df: DataFrame, context: QueryContext):
+    pivoted_column = context.convert_to_pivoted()
+    return pivot(df, pivoted_column, VALUE_COLUMN)
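The module above is the query entry point for the package: QuerySubmitterBase manages the output and cache directories, ProjectQuerySubmitter runs project queries, CompositeDatasetQuerySubmitter creates and queries composite datasets, and DatasetQuerySubmitter maps a single registered dataset to requested dimensions. The sketch below is an illustration only and is not part of the wheel; it assumes a Project instance has already been obtained elsewhere (for example through RegistryManager, which is not shown in this file) and that ProjectQueryModel exposes a from_file constructor analogous to the from_file calls used above.

from pathlib import Path

from dsgrid.query.models import ProjectQueryModel
from dsgrid.query.query_submitter import ProjectQuerySubmitter


def run_project_query(project, query_file: Path, output_dir: Path):
    # Hypothetical: load the query definition; from_file is assumed here by analogy
    # with CreateCompositeDatasetQueryModel.from_file and MapOperationCheckpoint.from_file.
    model = ProjectQueryModel.from_file(query_file)
    # The first argument is consumed by ProjectBasedQuerySubmitter, the second by
    # QuerySubmitterBase.__init__ (the results/cache root directory).
    submitter = ProjectQuerySubmitter(project, output_dir)
    # submit() checks the query, maps and combines the datasets, applies filters and
    # aggregations, and writes table.parquet (or .csv) plus query.json5 and
    # metadata.json under output_dir / <query name>.
    return submitter.submit(
        model,
        persist_intermediate_table=True,
        load_cached_table=True,
        overwrite=False,
    )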