dsgrid_toolkit-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (152)
  1. dsgrid/__init__.py +22 -0
  2. dsgrid/api/__init__.py +0 -0
  3. dsgrid/api/api_manager.py +179 -0
  4. dsgrid/api/app.py +420 -0
  5. dsgrid/api/models.py +60 -0
  6. dsgrid/api/response_models.py +116 -0
  7. dsgrid/apps/__init__.py +0 -0
  8. dsgrid/apps/project_viewer/app.py +216 -0
  9. dsgrid/apps/registration_gui.py +444 -0
  10. dsgrid/chronify.py +22 -0
  11. dsgrid/cli/__init__.py +0 -0
  12. dsgrid/cli/common.py +120 -0
  13. dsgrid/cli/config.py +177 -0
  14. dsgrid/cli/download.py +13 -0
  15. dsgrid/cli/dsgrid.py +142 -0
  16. dsgrid/cli/dsgrid_admin.py +349 -0
  17. dsgrid/cli/install_notebooks.py +62 -0
  18. dsgrid/cli/query.py +711 -0
  19. dsgrid/cli/registry.py +1773 -0
  20. dsgrid/cloud/__init__.py +0 -0
  21. dsgrid/cloud/cloud_storage_interface.py +140 -0
  22. dsgrid/cloud/factory.py +31 -0
  23. dsgrid/cloud/fake_storage_interface.py +37 -0
  24. dsgrid/cloud/s3_storage_interface.py +156 -0
  25. dsgrid/common.py +35 -0
  26. dsgrid/config/__init__.py +0 -0
  27. dsgrid/config/annual_time_dimension_config.py +187 -0
  28. dsgrid/config/common.py +131 -0
  29. dsgrid/config/config_base.py +148 -0
  30. dsgrid/config/dataset_config.py +684 -0
  31. dsgrid/config/dataset_schema_handler_factory.py +41 -0
  32. dsgrid/config/date_time_dimension_config.py +108 -0
  33. dsgrid/config/dimension_config.py +54 -0
  34. dsgrid/config/dimension_config_factory.py +65 -0
  35. dsgrid/config/dimension_mapping_base.py +349 -0
  36. dsgrid/config/dimension_mappings_config.py +48 -0
  37. dsgrid/config/dimensions.py +775 -0
  38. dsgrid/config/dimensions_config.py +71 -0
  39. dsgrid/config/index_time_dimension_config.py +76 -0
  40. dsgrid/config/input_dataset_requirements.py +31 -0
  41. dsgrid/config/mapping_tables.py +209 -0
  42. dsgrid/config/noop_time_dimension_config.py +42 -0
  43. dsgrid/config/project_config.py +1457 -0
  44. dsgrid/config/registration_models.py +199 -0
  45. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  46. dsgrid/config/simple_models.py +49 -0
  47. dsgrid/config/supplemental_dimension.py +29 -0
  48. dsgrid/config/time_dimension_base_config.py +200 -0
  49. dsgrid/data_models.py +155 -0
  50. dsgrid/dataset/__init__.py +0 -0
  51. dsgrid/dataset/dataset.py +123 -0
  52. dsgrid/dataset/dataset_expression_handler.py +86 -0
  53. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  54. dsgrid/dataset/dataset_schema_handler_base.py +899 -0
  55. dsgrid/dataset/dataset_schema_handler_one_table.py +196 -0
  56. dsgrid/dataset/dataset_schema_handler_standard.py +303 -0
  57. dsgrid/dataset/growth_rates.py +162 -0
  58. dsgrid/dataset/models.py +44 -0
  59. dsgrid/dataset/table_format_handler_base.py +257 -0
  60. dsgrid/dataset/table_format_handler_factory.py +17 -0
  61. dsgrid/dataset/unpivoted_table.py +121 -0
  62. dsgrid/dimension/__init__.py +0 -0
  63. dsgrid/dimension/base_models.py +218 -0
  64. dsgrid/dimension/dimension_filters.py +308 -0
  65. dsgrid/dimension/standard.py +213 -0
  66. dsgrid/dimension/time.py +531 -0
  67. dsgrid/dimension/time_utils.py +88 -0
  68. dsgrid/dsgrid_rc.py +88 -0
  69. dsgrid/exceptions.py +105 -0
  70. dsgrid/filesystem/__init__.py +0 -0
  71. dsgrid/filesystem/cloud_filesystem.py +32 -0
  72. dsgrid/filesystem/factory.py +32 -0
  73. dsgrid/filesystem/filesystem_interface.py +136 -0
  74. dsgrid/filesystem/local_filesystem.py +74 -0
  75. dsgrid/filesystem/s3_filesystem.py +118 -0
  76. dsgrid/loggers.py +132 -0
  77. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +950 -0
  78. dsgrid/notebooks/registration.ipynb +48 -0
  79. dsgrid/notebooks/start_notebook.sh +11 -0
  80. dsgrid/project.py +451 -0
  81. dsgrid/query/__init__.py +0 -0
  82. dsgrid/query/dataset_mapping_plan.py +142 -0
  83. dsgrid/query/derived_dataset.py +384 -0
  84. dsgrid/query/models.py +726 -0
  85. dsgrid/query/query_context.py +287 -0
  86. dsgrid/query/query_submitter.py +847 -0
  87. dsgrid/query/report_factory.py +19 -0
  88. dsgrid/query/report_peak_load.py +70 -0
  89. dsgrid/query/reports_base.py +20 -0
  90. dsgrid/registry/__init__.py +0 -0
  91. dsgrid/registry/bulk_register.py +161 -0
  92. dsgrid/registry/common.py +287 -0
  93. dsgrid/registry/config_update_checker_base.py +63 -0
  94. dsgrid/registry/data_store_factory.py +34 -0
  95. dsgrid/registry/data_store_interface.py +69 -0
  96. dsgrid/registry/dataset_config_generator.py +156 -0
  97. dsgrid/registry/dataset_registry_manager.py +734 -0
  98. dsgrid/registry/dataset_update_checker.py +16 -0
  99. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  100. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  101. dsgrid/registry/dimension_registry_manager.py +413 -0
  102. dsgrid/registry/dimension_update_checker.py +16 -0
  103. dsgrid/registry/duckdb_data_store.py +185 -0
  104. dsgrid/registry/filesystem_data_store.py +141 -0
  105. dsgrid/registry/filter_registry_manager.py +123 -0
  106. dsgrid/registry/project_config_generator.py +57 -0
  107. dsgrid/registry/project_registry_manager.py +1616 -0
  108. dsgrid/registry/project_update_checker.py +48 -0
  109. dsgrid/registry/registration_context.py +223 -0
  110. dsgrid/registry/registry_auto_updater.py +316 -0
  111. dsgrid/registry/registry_database.py +662 -0
  112. dsgrid/registry/registry_interface.py +446 -0
  113. dsgrid/registry/registry_manager.py +544 -0
  114. dsgrid/registry/registry_manager_base.py +367 -0
  115. dsgrid/registry/versioning.py +92 -0
  116. dsgrid/spark/__init__.py +0 -0
  117. dsgrid/spark/functions.py +545 -0
  118. dsgrid/spark/types.py +50 -0
  119. dsgrid/tests/__init__.py +0 -0
  120. dsgrid/tests/common.py +139 -0
  121. dsgrid/tests/make_us_data_registry.py +204 -0
  122. dsgrid/tests/register_derived_datasets.py +103 -0
  123. dsgrid/tests/utils.py +25 -0
  124. dsgrid/time/__init__.py +0 -0
  125. dsgrid/time/time_conversions.py +80 -0
  126. dsgrid/time/types.py +67 -0
  127. dsgrid/units/__init__.py +0 -0
  128. dsgrid/units/constants.py +113 -0
  129. dsgrid/units/convert.py +71 -0
  130. dsgrid/units/energy.py +145 -0
  131. dsgrid/units/power.py +87 -0
  132. dsgrid/utils/__init__.py +0 -0
  133. dsgrid/utils/dataset.py +612 -0
  134. dsgrid/utils/files.py +179 -0
  135. dsgrid/utils/filters.py +125 -0
  136. dsgrid/utils/id_remappings.py +100 -0
  137. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  138. dsgrid/utils/py_expression_eval/README.md +8 -0
  139. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  140. dsgrid/utils/py_expression_eval/tests.py +283 -0
  141. dsgrid/utils/run_command.py +70 -0
  142. dsgrid/utils/scratch_dir_context.py +64 -0
  143. dsgrid/utils/spark.py +918 -0
  144. dsgrid/utils/spark_partition.py +98 -0
  145. dsgrid/utils/timing.py +239 -0
  146. dsgrid/utils/utilities.py +184 -0
  147. dsgrid/utils/versioning.py +36 -0
  148. dsgrid_toolkit-0.2.0.dist-info/METADATA +216 -0
  149. dsgrid_toolkit-0.2.0.dist-info/RECORD +152 -0
  150. dsgrid_toolkit-0.2.0.dist-info/WHEEL +4 -0
  151. dsgrid_toolkit-0.2.0.dist-info/entry_points.txt +4 -0
  152. dsgrid_toolkit-0.2.0.dist-info/licenses/LICENSE +29 -0
dsgrid/query/query_submitter.py
@@ -0,0 +1,847 @@
+import abc
+import json
+import logging
+import shutil
+from pathlib import Path
+
+from zipfile import ZipFile
+
+from chronify.utils.path_utils import check_overwrite
+from semver import VersionInfo
+from sqlalchemy import Connection
+
+from dsgrid.common import VALUE_COLUMN
+from dsgrid.config.dataset_config import DatasetConfig
+from dsgrid.config.dimension_config import DimensionBaseConfig
+from dsgrid.config.project_config import DatasetBaseDimensionNamesModel
+from dsgrid.config.dimension_mapping_base import DimensionMappingReferenceModel
+from dsgrid.config.time_dimension_base_config import TimeDimensionBaseConfig
+from dsgrid.dataset.dataset_expression_handler import (
+    DatasetExpressionHandler,
+    evaluate_expression,
+)
+from dsgrid.utils.scratch_dir_context import ScratchDirContext
+from dsgrid.dataset.models import TableFormatType, PivotedTableFormatModel
+from dsgrid.dataset.dataset_schema_handler_base import DatasetSchemaHandlerBase
+from dsgrid.dataset.table_format_handler_factory import make_table_format_handler
+from dsgrid.dimension.base_models import DimensionCategory, DimensionType
+from dsgrid.dimension.dimension_filters import SubsetDimensionFilterModel
+from dsgrid.exceptions import DSGInvalidDataset, DSGInvalidParameter, DSGInvalidQuery
+from dsgrid.dsgrid_rc import DsgridRuntimeConfig
+from dsgrid.query.dataset_mapping_plan import MapOperationCheckpoint
+from dsgrid.query.query_context import QueryContext
+from dsgrid.query.report_factory import make_report
+from dsgrid.registry.registry_manager import RegistryManager
+from dsgrid.spark.functions import pivot
+from dsgrid.spark.types import DataFrame
+from dsgrid.project import Project
+from dsgrid.utils.spark import (
+    custom_time_zone,
+    read_dataframe,
+    try_read_dataframe,
+    write_dataframe,
+    write_dataframe_and_auto_partition,
+)
+from dsgrid.utils.timing import timer_stats_collector, track_timing
+from dsgrid.utils.files import delete_if_exists, compute_hash, load_data
+from dsgrid.query.models import (
+    DatasetQueryModel,
+    ProjectQueryModel,
+    ColumnType,
+    CreateCompositeDatasetQueryModel,
+    CompositeDatasetQueryModel,
+    DatasetMetadataModel,
+    ProjectionDatasetModel,
+    StandaloneDatasetModel,
+)
+from dsgrid.config.dataset_schema_handler_factory import make_dataset_schema_handler
+
+
+logger = logging.getLogger(__name__)
+
+
+class QuerySubmitterBase:
+    """Handles query submission"""
+
+    def __init__(self, output_dir: Path):
+        self._output_dir = output_dir
+        self._cached_tables_dir().mkdir(exist_ok=True, parents=True)
+        self._composite_datasets_dir().mkdir(exist_ok=True, parents=True)
+
+        # TODO #186: This location will need more consideration.
+        # We might want to store cached datasets in the spark-warehouse and let Spark manage it
+        # for us. However, would we share them on the HPC? What happens on HPC walltime timeouts
+        # where the tables are left in intermediate states?
+        # This is even more of a problem on AWS.
+        self._cached_project_mapped_datasets_dir().mkdir(exist_ok=True, parents=True)
+
+    @abc.abstractmethod
+    def submit(self, *args, **kwargs) -> DataFrame:
+        """Submit a query for execution"""
+
+    def _composite_datasets_dir(self):
+        return self._output_dir / "composite_datasets"
+
+    def _cached_tables_dir(self):
+        """Directory for intermediate tables made up of multiple project-mapped datasets."""
+        return self._output_dir / "cached_tables"
+
+    def _cached_project_mapped_datasets_dir(self):
+        """Directory for intermediate project-mapped datasets.
+        Data could be filtered.
+        """
+        return self._output_dir / "cached_project_mapped_datasets"
+
+    @staticmethod
+    def metadata_filename(path: Path):
+        return path / "metadata.json"
+
+    @staticmethod
+    def query_filename(path: Path):
+        return path / "query.json5"
+
+    @staticmethod
+    def table_filename(path: Path):
+        return path / "table.parquet"
+
+    @staticmethod
+    def _cached_table_filename(path: Path):
+        return path / "table.parquet"
+
+
+class ProjectBasedQuerySubmitter(QuerySubmitterBase):
+    def __init__(self, project: Project, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._project = project
+
+    @property
+    def project(self):
+        return self._project
+
+    def _create_table_hash(self, context: QueryContext) -> tuple[str, str]:
+        """Create a hash that can be used to identify whether the following sequence
+        can be skipped based on a previous query:
+        - Apply expression across all datasets in the query.
+        - Apply filters.
+        - Apply aggregations.
+
+        Examples of changes that will invalidate the query:
+        - Change to the project section of the query
+        - Bump to project major version number
+        - Change to a dataset version
+        - Change to a project's dimension requirements for a dataset
+        - Change to a dataset dimension mapping
+        """
+        assert isinstance(context.model, ProjectQueryModel) or isinstance(
+            context.model, CreateCompositeDatasetQueryModel
+        )
+        data = {
+            "project_major_version": VersionInfo.parse(self._project.config.model.version).major,
+            "project_query": context.model.serialize_cached_content(),
+            "datasets": [
+                self._project.config.get_dataset(x.dataset_id).model_dump(mode="json")
+                for x in context.model.project.dataset.source_datasets
+            ],
+        }
+        text = json.dumps(data, indent=2)
+        hash_value = compute_hash(text.encode())
+        return text, hash_value
+
+    def _try_read_cache(self, context: QueryContext):
+        _, hash_value = self._create_table_hash(context)
+        cached_dir = self._cached_tables_dir() / hash_value
+        filename = self._cached_table_filename(cached_dir)
+        df = try_read_dataframe(filename)
+        if df is not None:
+            logger.info("Load intermediate table from cache: %s", filename)
+            metadata_file = self.metadata_filename(cached_dir)
+            return df, DatasetMetadataModel.from_file(metadata_file)
+        return None, None
+
+    def _run_checks(self, model: ProjectQueryModel) -> DatasetBaseDimensionNamesModel:
+        subsets = set(self.project.config.list_dimension_names(DimensionCategory.SUBSET))
+        for agg in model.result.aggregations:
+            for _, column in agg.iter_dimensions_to_keep():
+                dimension_name = column.dimension_name
+                if dimension_name in subsets:
+                    subset_dim = self._project.config.get_dimension(dimension_name)
+                    dim_type = subset_dim.model.dimension_type
+                    supp_names = " ".join(
+                        self._project.config.get_supplemental_dimension_to_name_mapping()[dim_type]
+                    )
+                    base_names = [
+                        x.model.name
+                        for x in self._project.config.list_base_dimensions(dimension_type=dim_type)
+                    ]
+                    msg = (
+                        f"Subset dimensions cannot be used in aggregations: "
+                        f"{dimension_name=}. Only base and supplemental dimensions are "
+                        f"allowed. base={base_names} supplemental={supp_names}"
+                    )
+                    raise DSGInvalidQuery(msg)
+
+        for report_inputs in model.result.reports:
+            report = make_report(report_inputs.report_type)
+            report.check_query(model)
+
+        with self._project.dimension_mapping_manager.db.engine.connect() as conn:
+            return self._check_datasets(model, conn)
+
+    def _check_datasets(
+        self, query_model: ProjectQueryModel, conn: Connection
+    ) -> DatasetBaseDimensionNamesModel:
+        base_dimension_names: DatasetBaseDimensionNamesModel | None = None
+        dataset_ids: list[str] = []
+        query_names: list[DatasetBaseDimensionNamesModel] = []
+        for dataset in query_model.project.dataset.source_datasets:
+            src_dataset_ids = dataset.list_source_dataset_ids()
+            dataset_ids += src_dataset_ids
+            if isinstance(dataset, StandaloneDatasetModel):
+                query_names.append(
+                    self._project.config.get_dataset_base_dimension_names(dataset.dataset_id)
+                )
+            elif isinstance(dataset, ProjectionDatasetModel):
+                query_names += [
+                    self._project.config.get_dataset_base_dimension_names(
+                        dataset.initial_value_dataset_id
+                    ),
+                    self._project.config.get_dataset_base_dimension_names(
+                        dataset.growth_rate_dataset_id
+                    ),
+                ]
+            else:
+                msg = f"Unhandled dataset type: {dataset=}"
+                raise NotImplementedError(msg)
+
+            for dataset_id in src_dataset_ids:
+                dataset = self._project.load_dataset(dataset_id, conn=conn)
+                plan = query_model.project.get_dataset_mapping_plan(dataset_id)
+                if plan is None:
+                    plan = dataset.handler.build_default_dataset_mapping_plan()
+                    query_model.project.set_dataset_mapper(plan)
+                else:
+                    dataset.handler.check_dataset_mapping_plan(plan, self._project.config)
+
+        for dataset_id, names in zip(dataset_ids, query_names):
+            self._fix_legacy_base_dimension_names(names, dataset_id)
+            if base_dimension_names is None:
+                base_dimension_names = names
+            elif base_dimension_names != names:
+                msg = (
+                    "Datasets in a query must have the same base dimension query names: "
+                    f"{dataset=} {base_dimension_names} {names}"
+                )
+                raise DSGInvalidQuery(msg)
+
+        assert base_dimension_names is not None
+        return base_dimension_names
+
+    def _fix_legacy_base_dimension_names(
+        self, names: DatasetBaseDimensionNamesModel, dataset_id: str
+    ) -> None:
+        for dim_type in DimensionType:
+            val = getattr(names, dim_type.value)
+            if val is None:
+                # This is a workaround for dsgrid projects created before the field
+                # base_dimension_names was added to InputDatasetModel.
+                dims = self._project.config.list_base_dimensions(dimension_type=dim_type)
+                if len(dims) > 1:
+                    msg = (
+                        "The dataset's base_dimension_names value is not set and "
+                        f"there are multiple base dimensions of type {dim_type} in the project. "
+                        f"Please re-register the dataset with {dataset_id=}."
+                    )
+                    raise DSGInvalidDataset(msg)
+                setattr(names, dim_type.value, dims[0].model.name)
+
+    def _run_query(
+        self,
+        scratch_dir_context: ScratchDirContext,
+        model: ProjectQueryModel,
+        load_cached_table: bool,
+        checkpoint_file: Path | None,
+        persist_intermediate_table: bool,
+        zip_file: bool = False,
+        overwrite: bool = False,
+    ):
+        base_dimension_names = self._run_checks(model)
+        checkpoint = self._check_checkpoint_file(checkpoint_file, model)
+        context = QueryContext(
+            model,
+            base_dimension_names,
+            scratch_dir_context=scratch_dir_context,
+            checkpoint=checkpoint,
+        )
+        assert isinstance(context.model, ProjectQueryModel) or isinstance(
+            context.model, CreateCompositeDatasetQueryModel
+        )
+        context.model.project.version = str(self._project.version)
+        output_dir = self._output_dir / context.model.name
+        if output_dir.exists() and not overwrite:
+            msg = (
+                f"output directory {self._output_dir} and query name={context.model.name} will "
+                "overwrite an existing query results directory. "
+                "Choose a different path or pass force=True."
+            )
+            raise DSGInvalidParameter(msg)
+
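+        # Either load a previously cached intermediate table or map each source dataset
+        # to project dimensions and combine the results with the dataset expression.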
+        df = None
+        if load_cached_table:
+            df, metadata = self._try_read_cache(context)
+        if df is None:
+            df_filenames = self._project.process_query(
+                context, self._cached_project_mapped_datasets_dir()
+            )
+            df = self._postprocess_datasets(context, scratch_dir_context, df_filenames)
+            is_cached = False
+        else:
+            context.metadata = metadata
+            is_cached = True
+
+        if context.model.result.aggregate_each_dataset:
+            # This wouldn't save any time.
+            persist_intermediate_table = False
+
+        if persist_intermediate_table and not is_cached:
+            df = self._persist_intermediate_result(context, df)
+
+        if not context.model.result.aggregate_each_dataset:
+            if context.model.result.dimension_filters:
+                df = self._apply_filters(df, context)
+            df = self._process_aggregations(df, context)
+
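+        # The persisted intermediate table is already auto-partitioned, so only
+        # repartition the final result when the intermediate table was not persisted.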
+        repartition = not persist_intermediate_table
+        table_filename = self._save_query_results(context, df, repartition, zip_file=zip_file)
+
+        for report_inputs in context.model.result.reports:
+            report = make_report(report_inputs.report_type)
+            output_dir = self._output_dir / context.model.name
+            report.generate(table_filename, output_dir, context, report_inputs.inputs)
+
+        return df, context
+
+    def _check_checkpoint_file(
+        self, checkpoint_file: Path | None, model: ProjectQueryModel
+    ) -> MapOperationCheckpoint | None:
+        if checkpoint_file is None:
+            return None
+
+        checkpoint = MapOperationCheckpoint.from_file(checkpoint_file)
+        confirmed_checkpoint = False
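+        # A checkpoint can only be used if its dataset appears in the query and the
+        # recorded mapping plan hash still matches the plan in this query.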
+        for dataset in model.project.dataset.source_datasets:
+            if dataset.get_dataset_id() == checkpoint.dataset_id:
+                for plan in model.project.mapping_plans:
+                    if plan.dataset_id == checkpoint.dataset_id:
+                        if plan.compute_hash() == checkpoint.mapping_plan_hash:
+                            confirmed_checkpoint = True
+                        else:
+                            msg = (
+                                f"The hash of the mapping plan for dataset {checkpoint.dataset_id} "
+                                "does not match the checkpoint file. Cannot use the checkpoint."
+                            )
+                            raise DSGInvalidParameter(msg)
+        if not confirmed_checkpoint:
+            msg = f"Checkpoint {checkpoint_file} does not match any dataset in the query."
+            raise DSGInvalidParameter(msg)
+
+        return checkpoint
+
+    @track_timing(timer_stats_collector)
+    def _persist_intermediate_result(self, context: QueryContext, df):
+        text, hash_value = self._create_table_hash(context)
+        cached_dir = self._cached_tables_dir() / hash_value
+        if cached_dir.exists():
+            shutil.rmtree(cached_dir)
+        cached_dir.mkdir()
+        filename = self._cached_table_filename(cached_dir)
+        df = write_dataframe_and_auto_partition(df, filename)
+
+        self.metadata_filename(cached_dir).write_text(
+            context.metadata.model_dump_json(indent=2), encoding="utf-8"
+        )
+        self.query_filename(cached_dir).write_text(text, encoding="utf-8")
+        logger.debug("Persisted intermediate table to %s", filename)
+        return df
+
+    def _postprocess_datasets(
+        self,
+        context: QueryContext,
+        scratch_dir_context: ScratchDirContext,
+        df_filenames: dict[str, Path],
+    ) -> DataFrame:
+        if context.model.result.aggregate_each_dataset:
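+            # Filter and aggregate each project-mapped dataset separately, writing the
+            # result to a scratch file before the datasets are combined.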
+            for dataset_id, path in df_filenames.items():
+                df = read_dataframe(path)
+                if context.model.result.dimension_filters:
+                    df = self._apply_filters(df, context)
+                df = self._process_aggregations(df, context, dataset_id=dataset_id)
+                path = scratch_dir_context.get_temp_filename(suffix=".parquet")
+                write_dataframe(df, path)
+                df_filenames[dataset_id] = path
+
+        # All dataset columns need to be in the same order.
+        context.consolidate_dataset_metadata()
+        datasets = self._convert_datasets(context, df_filenames)
+        assert isinstance(context.model, ProjectQueryModel) or isinstance(
+            context.model, CreateCompositeDatasetQueryModel
+        )
+        assert context.model.project.dataset.expression is not None
+        return evaluate_expression(context.model.project.dataset.expression, datasets).df
+
+    def _convert_datasets(self, context: QueryContext, filenames: dict[str, Path]):
+        dim_columns, time_columns = self._get_dimension_columns(context)
+        expected_columns = time_columns + dim_columns
+        expected_columns.append(VALUE_COLUMN)
+
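+        # Wrap each table in a DatasetExpressionHandler with a consistent column order
+        # so the query's dataset expression can be evaluated across datasets.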
+        datasets = {}
+        for dataset_id, path in filenames.items():
+            df = read_dataframe(path)
+            unexpected = sorted(set(df.columns).difference(expected_columns))
+            if unexpected:
+                msg = f"Unexpected columns are present in {dataset_id=} {unexpected=}"
+                raise Exception(msg)
+            datasets[dataset_id] = DatasetExpressionHandler(
+                df.select(*expected_columns), time_columns + dim_columns, [VALUE_COLUMN]
+            )
+        return datasets
+
+    def _get_dimension_columns(self, context: QueryContext) -> tuple[list[str], list[str]]:
+        match context.model.result.column_type:
+            case ColumnType.DIMENSION_NAMES:
+                dim_columns = context.get_all_dimension_column_names(exclude={DimensionType.TIME})
+                time_columns = context.get_dimension_column_names(DimensionType.TIME)
+            case ColumnType.DIMENSION_TYPES:
+                dim_columns = {x.value for x in DimensionType if x != DimensionType.TIME}
+                time_columns = context.get_dimension_column_names(DimensionType.TIME)
+            case _:
+                msg = f"BUG: unhandled {context.model.result.column_type=}"
+                raise NotImplementedError(msg)
+
+        return sorted(dim_columns), sorted(time_columns)
+
+    def _process_aggregations(
+        self, df: DataFrame, context: QueryContext, dataset_id: str | None = None
+    ) -> DataFrame:
+        handler = make_table_format_handler(
+            TableFormatType.UNPIVOTED, self._project.config, dataset_id=dataset_id
+        )
+        df = handler.process_aggregations(df, context.model.result.aggregations, context)
+
+        if context.model.result.replace_ids_with_names:
+            df = handler.replace_ids_with_names(df)
+
+        if context.model.result.sort_columns:
+            df = df.sort(*context.model.result.sort_columns)
+
+        if isinstance(context.model.result.table_format, PivotedTableFormatModel):
+            df = _pivot_table(df, context)
+
+        return df
+
+    def _process_aggregations_and_save(
+        self,
+        df: DataFrame,
+        context: QueryContext,
+        repartition: bool,
+        zip_file: bool = False,
+    ) -> DataFrame:
+        df = self._process_aggregations(df, context)
+
+        self._save_query_results(context, df, repartition, zip_file=zip_file)
+        return df
+
+    def _apply_filters(self, df, context: QueryContext):
+        for dim_filter in context.model.result.dimension_filters:
+            column_names = context.get_dimension_column_names(dim_filter.dimension_type)
+            if len(column_names) > 1:
+                msg = f"Cannot filter {dim_filter} when there are multiple {column_names=}"
+                raise NotImplementedError(msg)
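+            # Subset filters are applied by joining on the subset's record ids;
+            # other filter types operate directly on the dimension's column.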
+            if isinstance(dim_filter, SubsetDimensionFilterModel):
+                records = dim_filter.get_filtered_records_dataframe(
+                    self.project.config.get_dimension
+                )
+                column = next(iter(column_names))
+                df = df.join(
+                    records.select("id"),
+                    on=getattr(df, column) == getattr(records, "id"),
+                ).drop("id")
+            else:
+                query_name = dim_filter.dimension_name
+                if query_name not in df.columns:
+                    # Consider catching this exception and still write to a file.
+                    # It could mean writing a lot of data the user doesn't want.
+                    msg = f"filter column {query_name} is not in the dataframe: {df.columns}"
+                    raise DSGInvalidParameter(msg)
+                df = dim_filter.apply_filter(df, column=query_name)
+        return df
+
+    @track_timing(timer_stats_collector)
+    def _save_query_results(
+        self,
+        context: QueryContext,
+        df,
+        repartition,
+        aggregation_name=None,
+        zip_file=False,
+    ):
+        output_dir = self._output_dir / context.model.name
+        output_dir.mkdir(exist_ok=True)
+        if aggregation_name is not None:
+            output_dir /= aggregation_name
+            output_dir.mkdir(exist_ok=True)
+        filename = output_dir / f"table.{context.model.result.output_format}"
+        self._save_result(context, df, filename, repartition)
+        if zip_file:
+            zip_name = Path(str(output_dir) + ".zip")
+            with ZipFile(zip_name, "w") as zipf:
+                for path in output_dir.rglob("*"):
+                    zipf.write(path)
+        return filename
+
+    def _save_result(self, context: QueryContext, df, filename, repartition):
+        output_dir = filename.parent
+        suffix = filename.suffix
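+        # CSV output is collected to pandas and written locally; Parquet is written
+        # via the Spark utilities and optionally repartitioned.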
+        if suffix == ".csv":
+            df.toPandas().to_csv(filename, header=True, index=False)
+        elif suffix == ".parquet":
+            if repartition:
+                df = write_dataframe_and_auto_partition(df, filename)
+            else:
+                delete_if_exists(filename)
+                write_dataframe(df, filename, overwrite=True)
+        else:
+            msg = f"Unsupported output_format={suffix}"
+            raise NotImplementedError(msg)
+        self.query_filename(output_dir).write_text(context.model.serialize_with_hash()[1])
+        self.metadata_filename(output_dir).write_text(context.metadata.model_dump_json(indent=2))
+        logger.info("Wrote query=%s output table to %s", context.model.name, filename)
+
+
+class ProjectQuerySubmitter(ProjectBasedQuerySubmitter):
+    """Submits queries for a project."""
+
+    @track_timing(timer_stats_collector)
+    def submit(
+        self,
+        model: ProjectQueryModel,
+        scratch_dir: Path | None = None,
+        checkpoint_file: Path | None = None,
+        persist_intermediate_table: bool = True,
+        load_cached_table: bool = True,
+        zip_file: bool = False,
+        overwrite: bool = False,
+    ) -> DataFrame:
+        """Submits a project query to consolidate datasets and produce result tables.
+
+        Parameters
+        ----------
+        model : ProjectQueryModel
+        checkpoint_file : Path, optional
+            Optional checkpoint file from which to resume the operation.
+        persist_intermediate_table : bool, optional
+            Persist the intermediate consolidated table.
+        load_cached_table : bool, optional
+            Load a cached consolidated table if the query matches an existing query.
+        zip_file : bool, optional
+            Create a zip file with all output files.
+        overwrite : bool
+            If True, overwrite any existing output directory.
+
+        Returns
+        -------
+        pyspark.sql.DataFrame
+
+        Raises
+        ------
+        DSGInvalidParameter
+            Raised if the model defines a project version
+        DSGInvalidQuery
+            Raised if the query is invalid
+        """
+        tz = self._project.config.get_base_time_dimension().get_time_zone()
+        assert tz is not None, "Project base time dimension must have a time zone"
+        scratch_dir = scratch_dir or DsgridRuntimeConfig.load().get_scratch_dir()
+        with ScratchDirContext(scratch_dir) as scratch_dir_context:
+            # Ensure that queries that aggregate time reflect the project's time zone instead
+            # of the local computer.
+            # If any other settings get customized here, handle them in restart_spark()
+            # as well. This change won't persist Spark session restarts.
+            with custom_time_zone(tz.tz_name):
+                df, context = self._run_query(
+                    scratch_dir_context,
+                    model,
+                    load_cached_table,
+                    checkpoint_file=checkpoint_file,
+                    persist_intermediate_table=persist_intermediate_table,
+                    zip_file=zip_file,
+                    overwrite=overwrite,
+                )
+                context.finalize()
+        return df
+
+
+class CompositeDatasetQuerySubmitter(ProjectBasedQuerySubmitter):
+    """Submits queries for a composite dataset."""
+
+    @track_timing(timer_stats_collector)
+    def create_dataset(
+        self,
+        model: CreateCompositeDatasetQueryModel,
+        scratch_dir: Path | None = None,
+        persist_intermediate_table=False,
+        load_cached_table=True,
+        force=False,
+    ):
+        """Create a composite dataset from a project.
+
+        Parameters
+        ----------
+        model : CreateCompositeDatasetQueryModel
+        persist_intermediate_table : bool, optional
+            Persist the intermediate consolidated table.
+        load_cached_table : bool, optional
+            Load a cached consolidated table if the query matches an existing query.
+        force : bool
+            If True, overwrite any existing output directory.
+
+        """
+        tz = self._project.config.get_base_time_dimension().get_time_zone()
+        scratch_dir = scratch_dir or DsgridRuntimeConfig.load().get_scratch_dir()
+        with ScratchDirContext(scratch_dir) as scratch_dir_context:
+            # Ensure that queries that aggregate time reflect the project's time zone instead
+            # of the local computer.
+            # If any other settings get customized here, handle them in restart_spark()
+            # as well. This change won't persist Spark session restarts.
+            with custom_time_zone(tz.tz_name):  # type: ignore
+                df, context = self._run_query(
+                    scratch_dir_context,
+                    model,
+                    load_cached_table,
+                    None,
+                    persist_intermediate_table,
+                    overwrite=force,
+                )
+                self._save_composite_dataset(context, df, not persist_intermediate_table)
+                context.finalize()
+
+    @track_timing(timer_stats_collector)
+    def submit(
+        self,
+        query: CompositeDatasetQueryModel,
+        scratch_dir: Path | None = None,
+    ) -> DataFrame:
+        """Submit a query to a composite dataset and produce result tables.
+
+        Parameters
+        ----------
+        query : CompositeDatasetQueryModel
+        scratch_dir : Path | None
+        """
+        tz = self._project.config.get_base_time_dimension().get_time_zone()
+        assert tz is not None
+        scratch_dir = scratch_dir or DsgridRuntimeConfig.load().get_scratch_dir()
+        # orig_query = self._load_composite_dataset_query(query.dataset_id)
+        with ScratchDirContext(scratch_dir) as scratch_dir_context:
+            df, metadata = self._read_dataset(query.dataset_id)
+            base_dimension_names = DatasetBaseDimensionNamesModel()
+            for dim_type in DimensionType:
+                field = dim_type.value
+                query_names = getattr(metadata.dimensions, field)
+                if len(query_names) > 1:
+                    msg = (
+                        "Composite datasets must have a single query name for each dimension: "
+                        f"{dim_type} {query_names}"
+                    )
+                    raise DSGInvalidQuery(msg)
+                setattr(base_dimension_names, field, query_names[0].dimension_name)
+            context = QueryContext(query, base_dimension_names, scratch_dir_context)
+            context.metadata = metadata
+            # Refer to the comment in ProjectQuerySubmitter.submit for an explanation or if
+            # you add a new customization.
+            with custom_time_zone(tz.tz_name):  # type: ignore
+                df = self._process_aggregations_and_save(df, context, repartition=False)
+                context.finalize()
+        return df
+
+    def _load_composite_dataset_query(self, dataset_id):
+        filename = self._composite_datasets_dir() / dataset_id / "query.json5"
+        return CreateCompositeDatasetQueryModel.from_file(filename)
+
+    def _read_dataset(self, dataset_id) -> tuple[DataFrame, DatasetMetadataModel]:
+        filename = self._composite_datasets_dir() / dataset_id / "table.parquet"
+        if not filename.exists():
+            msg = f"There is no composite dataset with dataset_id={dataset_id}"
+            raise DSGInvalidParameter(msg)
+        metadata_file = self.metadata_filename(self._composite_datasets_dir() / dataset_id)
+        return (
+            read_dataframe(filename),
+            DatasetMetadataModel(**load_data(metadata_file)),
+        )
+
+    @track_timing(timer_stats_collector)
+    def _save_composite_dataset(self, context: QueryContext, df, repartition):
+        output_dir = self._composite_datasets_dir() / context.model.dataset_id
+        output_dir.mkdir(exist_ok=True)
+        filename = output_dir / "table.parquet"
+        self._save_result(context, df, filename, repartition)
+        self.metadata_filename(output_dir).write_text(context.metadata.model_dump_json(indent=2))
+
+
+class DatasetQuerySubmitter(QuerySubmitterBase):
+    """Submits queries for a dataset."""
+
+    @track_timing(timer_stats_collector)
+    def submit(
+        self,
+        query: DatasetQueryModel,
+        mgr: RegistryManager,
+        scratch_dir: Path | None = None,
+        checkpoint_file: Path | None = None,
+        overwrite: bool = False,
+    ) -> DataFrame:
+        """Submits a dataset query to produce a result table."""
+        if not query.to_dimension_references:
+            msg = "A dataset query must specify at least one dimension to map."
+            raise DSGInvalidQuery(msg)
+
+        dataset_config = mgr.dataset_manager.get_by_id(query.dataset_id)
+        to_dimension_mapping_refs, dims = self._build_mappings(query, dataset_config, mgr)
+        handler = make_dataset_schema_handler(
+            conn=None,
+            config=dataset_config,
+            dimension_mgr=mgr.dimension_manager,
+            dimension_mapping_mgr=mgr.dimension_mapping_manager,
+            store=mgr.dataset_manager.store,
+            mapping_references=to_dimension_mapping_refs,
+        )
+
+        base_dim_names = DatasetBaseDimensionNamesModel()
+        scratch_dir = scratch_dir or DsgridRuntimeConfig.load().get_scratch_dir()
+        time_dim = dims.get(DimensionType.TIME) or dataset_config.get_time_dimension()
+        time_zone = None if time_dim is None else time_dim.get_time_zone()
+        checkpoint = self._check_checkpoint_file(checkpoint_file, query)
+        with ScratchDirContext(scratch_dir) as scratch_dir_context:
+            context = QueryContext(
+                query, base_dim_names, scratch_dir_context, checkpoint=checkpoint
+            )
+            output_dir = self._query_output_dir(context)
+            check_overwrite(output_dir, overwrite)
+            args = (context, handler)
+            kwargs = {"time_dimension": dims.get(DimensionType.TIME)}
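+            # Run the query in the dataset's time zone, when one is defined, so that
+            # time-based operations do not depend on the local machine's time zone.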
+            if time_dim is not None and time_zone is not None:
+                with custom_time_zone(time_zone.tz_name):
+                    df = self._run_query(*args, **kwargs)
+            else:
+                df = self._run_query(*args, **kwargs)
+        return df
+
+    def _build_mappings(
+        self, query: DatasetQueryModel, config: DatasetConfig, mgr: RegistryManager
+    ) -> tuple[list[DimensionMappingReferenceModel], dict[DimensionType, DimensionBaseConfig]]:
+        config = mgr.dataset_manager.get_by_id(query.dataset_id)
+        to_dimension_mapping_refs: list[DimensionMappingReferenceModel] = []
+        mapped_dimension_types: set[DimensionType] = set()
+        dimensions: dict[DimensionType, DimensionBaseConfig] = {}
+        with mgr.dimension_mapping_manager.db.engine.connect() as conn:
+            graph = mgr.dimension_mapping_manager.build_graph(conn=conn)
+            for to_dim_ref in query.to_dimension_references:
+                to_dim = mgr.dimension_manager.get_by_id(
+                    to_dim_ref.dimension_id, version=to_dim_ref.version
+                )
+                if to_dim.model.dimension_type in mapped_dimension_types:
+                    msg = f"A dataset query cannot map multiple dimensions of the same type: {to_dim.model.dimension_type}"
+                    raise DSGInvalidQuery(msg)
+                dataset_dim = config.get_dimension(to_dim.model.dimension_type)
+                assert dataset_dim is not None
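+                # If the requested dimension is the dataset's own dimension, no mapping
+                # is needed. Time dimensions are handled by the schema handler rather
+                # than through a registered dimension mapping.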
+                if to_dim.model.dimension_id == dataset_dim.model.dimension_id:
+                    if to_dim.model.version != dataset_dim.model.version:
+                        msg = (
+                            f"A to_dimension_reference cannot point to a different version of a "
+                            f"dataset's dimension: dataset version = {dataset_dim.model.version}, "
+                            f"dimension version = {to_dim.model.version}"
+                        )
+                        raise DSGInvalidQuery(msg)
+                    # No mapping is required.
+                    continue
+                if to_dim.model.dimension_type != DimensionType.TIME:
+                    refs = mgr.dimension_mapping_manager.list_mappings_between_dimensions(
+                        graph,
+                        dataset_dim.model.dimension_id,
+                        to_dim.model.dimension_id,
+                    )
+                    to_dimension_mapping_refs += refs
+                mapped_dimension_types.add(to_dim.model.dimension_type)
+                dimensions[to_dim.model.dimension_type] = to_dim
+        return to_dimension_mapping_refs, dimensions
+
+    def _check_checkpoint_file(
+        self, checkpoint_file: Path | None, query: DatasetQueryModel
+    ) -> MapOperationCheckpoint | None:
+        if checkpoint_file is None:
+            return None
+
+        if query.mapping_plan is None:
+            msg = f"Query {query.name} does not have a mapping plan. A checkpoint file cannot be used."
+            raise DSGInvalidQuery(msg)
+
+        checkpoint = MapOperationCheckpoint.from_file(checkpoint_file)
+        if query.dataset_id != checkpoint.dataset_id:
+            msg = (
+                f"The dataset_id in the checkpoint file {checkpoint.dataset_id} does not match "
+                f"the query dataset_id {query.dataset_id}."
+            )
+            raise DSGInvalidQuery(msg)
+
+        if query.mapping_plan.compute_hash() != checkpoint.mapping_plan_hash:
+            msg = (
+                f"The hash of the mapping plan for dataset {checkpoint.dataset_id} "
+                "does not match the checkpoint file. Cannot use the checkpoint."
+            )
+            raise DSGInvalidParameter(msg)
+
+        return checkpoint
+
+    def _run_query(
+        self,
+        context: QueryContext,
+        handler: DatasetSchemaHandlerBase,
+        time_dimension: TimeDimensionBaseConfig | None,
+    ) -> DataFrame:
+        df = handler.make_mapped_dataframe(context, time_dimension=time_dimension)
+        df = self._postprocess(context, df)
+        self._save_results(context, df)
+        return df
+
+    def _postprocess(self, context: QueryContext, df: DataFrame) -> DataFrame:
+        if context.model.result.sort_columns:
+            df = df.sort(*context.model.result.sort_columns)
+
+        if isinstance(context.model.result.table_format, PivotedTableFormatModel):
+            df = _pivot_table(df, context)
+
+        return df
+
+    def _query_output_dir(self, context: QueryContext) -> Path:
+        return self._output_dir / context.model.name
+
+    @track_timing(timer_stats_collector)
+    def _save_results(self, context: QueryContext, df) -> Path:
+        output_dir = self._query_output_dir(context)
+        output_dir.mkdir(exist_ok=True)
+        filename = output_dir / f"table.{context.model.result.output_format}"
+        suffix = filename.suffix
+        if suffix == ".csv":
+            df.toPandas().to_csv(filename, header=True, index=False)
+        elif suffix == ".parquet":
+            df = write_dataframe_and_auto_partition(df, filename)
+        else:
+            msg = f"Unsupported output_format={suffix}"
+            raise NotImplementedError(msg)
+
+        logger.info("Wrote query=%s output table to %s", context.model.name, filename)
+        return filename
+
+
+def _pivot_table(df: DataFrame, context: QueryContext):
+    pivoted_column = context.convert_to_pivoted()
+    return pivot(df, pivoted_column, VALUE_COLUMN)
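
For orientation, a minimal sketch of how the submitters added in this file might be driven follows. It is not taken from the package: the registry opener (RegistryManager.load), the project accessor (manager.project_manager.get_project), and ProjectQueryModel.from_file are assumptions; only ProjectQuerySubmitter(project, output_dir).submit(...) mirrors signatures visible in the diff above.

# Hypothetical driver script; see the caveats above.
from pathlib import Path

from dsgrid.query.models import ProjectQueryModel
from dsgrid.query.query_submitter import ProjectQuerySubmitter
from dsgrid.registry.registry_manager import RegistryManager

# Assumption: RegistryManager.load() and this project accessor exist with these names.
manager = RegistryManager.load("sqlite:///registry.db")
project = manager.project_manager.get_project("my_project_id")  # hypothetical accessor

# Assumption: the query model is loaded from a JSON5 file, as other models in the diff are.
model = ProjectQueryModel.from_file("my_query.json5")

# Matches the constructor and submit() signatures shown in the diff.
submitter = ProjectQuerySubmitter(project, Path("query_output"))
df = submitter.submit(
    model,
    persist_intermediate_table=True,  # cache the consolidated table for reuse
    load_cached_table=True,           # reuse a matching cached table if present
    overwrite=True,                   # replace any existing results directory
)
print(df.columns)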