dsgrid-toolkit 0.3.3__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. build_backend.py +93 -0
  2. dsgrid/__init__.py +22 -0
  3. dsgrid/api/__init__.py +0 -0
  4. dsgrid/api/api_manager.py +179 -0
  5. dsgrid/api/app.py +419 -0
  6. dsgrid/api/models.py +60 -0
  7. dsgrid/api/response_models.py +116 -0
  8. dsgrid/apps/__init__.py +0 -0
  9. dsgrid/apps/project_viewer/app.py +216 -0
  10. dsgrid/apps/registration_gui.py +444 -0
  11. dsgrid/chronify.py +32 -0
  12. dsgrid/cli/__init__.py +0 -0
  13. dsgrid/cli/common.py +120 -0
  14. dsgrid/cli/config.py +176 -0
  15. dsgrid/cli/download.py +13 -0
  16. dsgrid/cli/dsgrid.py +157 -0
  17. dsgrid/cli/dsgrid_admin.py +92 -0
  18. dsgrid/cli/install_notebooks.py +62 -0
  19. dsgrid/cli/query.py +729 -0
  20. dsgrid/cli/registry.py +1862 -0
  21. dsgrid/cloud/__init__.py +0 -0
  22. dsgrid/cloud/cloud_storage_interface.py +140 -0
  23. dsgrid/cloud/factory.py +31 -0
  24. dsgrid/cloud/fake_storage_interface.py +37 -0
  25. dsgrid/cloud/s3_storage_interface.py +156 -0
  26. dsgrid/common.py +36 -0
  27. dsgrid/config/__init__.py +0 -0
  28. dsgrid/config/annual_time_dimension_config.py +194 -0
  29. dsgrid/config/common.py +142 -0
  30. dsgrid/config/config_base.py +148 -0
  31. dsgrid/config/dataset_config.py +907 -0
  32. dsgrid/config/dataset_schema_handler_factory.py +46 -0
  33. dsgrid/config/date_time_dimension_config.py +136 -0
  34. dsgrid/config/dimension_config.py +54 -0
  35. dsgrid/config/dimension_config_factory.py +65 -0
  36. dsgrid/config/dimension_mapping_base.py +350 -0
  37. dsgrid/config/dimension_mappings_config.py +48 -0
  38. dsgrid/config/dimensions.py +1025 -0
  39. dsgrid/config/dimensions_config.py +71 -0
  40. dsgrid/config/file_schema.py +190 -0
  41. dsgrid/config/index_time_dimension_config.py +80 -0
  42. dsgrid/config/input_dataset_requirements.py +31 -0
  43. dsgrid/config/mapping_tables.py +209 -0
  44. dsgrid/config/noop_time_dimension_config.py +42 -0
  45. dsgrid/config/project_config.py +1462 -0
  46. dsgrid/config/registration_models.py +188 -0
  47. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  48. dsgrid/config/simple_models.py +49 -0
  49. dsgrid/config/supplemental_dimension.py +29 -0
  50. dsgrid/config/time_dimension_base_config.py +192 -0
  51. dsgrid/data_models.py +155 -0
  52. dsgrid/dataset/__init__.py +0 -0
  53. dsgrid/dataset/dataset.py +123 -0
  54. dsgrid/dataset/dataset_expression_handler.py +86 -0
  55. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  56. dsgrid/dataset/dataset_schema_handler_base.py +945 -0
  57. dsgrid/dataset/dataset_schema_handler_one_table.py +209 -0
  58. dsgrid/dataset/dataset_schema_handler_two_table.py +322 -0
  59. dsgrid/dataset/growth_rates.py +162 -0
  60. dsgrid/dataset/models.py +51 -0
  61. dsgrid/dataset/table_format_handler_base.py +257 -0
  62. dsgrid/dataset/table_format_handler_factory.py +17 -0
  63. dsgrid/dataset/unpivoted_table.py +121 -0
  64. dsgrid/dimension/__init__.py +0 -0
  65. dsgrid/dimension/base_models.py +230 -0
  66. dsgrid/dimension/dimension_filters.py +308 -0
  67. dsgrid/dimension/standard.py +252 -0
  68. dsgrid/dimension/time.py +352 -0
  69. dsgrid/dimension/time_utils.py +103 -0
  70. dsgrid/dsgrid_rc.py +88 -0
  71. dsgrid/exceptions.py +105 -0
  72. dsgrid/filesystem/__init__.py +0 -0
  73. dsgrid/filesystem/cloud_filesystem.py +32 -0
  74. dsgrid/filesystem/factory.py +32 -0
  75. dsgrid/filesystem/filesystem_interface.py +136 -0
  76. dsgrid/filesystem/local_filesystem.py +74 -0
  77. dsgrid/filesystem/s3_filesystem.py +118 -0
  78. dsgrid/loggers.py +132 -0
  79. dsgrid/minimal_patterns.cp313-win_amd64.pyd +0 -0
  80. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +949 -0
  81. dsgrid/notebooks/registration.ipynb +48 -0
  82. dsgrid/notebooks/start_notebook.sh +11 -0
  83. dsgrid/project.py +451 -0
  84. dsgrid/query/__init__.py +0 -0
  85. dsgrid/query/dataset_mapping_plan.py +142 -0
  86. dsgrid/query/derived_dataset.py +388 -0
  87. dsgrid/query/models.py +728 -0
  88. dsgrid/query/query_context.py +287 -0
  89. dsgrid/query/query_submitter.py +994 -0
  90. dsgrid/query/report_factory.py +19 -0
  91. dsgrid/query/report_peak_load.py +70 -0
  92. dsgrid/query/reports_base.py +20 -0
  93. dsgrid/registry/__init__.py +0 -0
  94. dsgrid/registry/bulk_register.py +165 -0
  95. dsgrid/registry/common.py +287 -0
  96. dsgrid/registry/config_update_checker_base.py +63 -0
  97. dsgrid/registry/data_store_factory.py +34 -0
  98. dsgrid/registry/data_store_interface.py +74 -0
  99. dsgrid/registry/dataset_config_generator.py +158 -0
  100. dsgrid/registry/dataset_registry_manager.py +950 -0
  101. dsgrid/registry/dataset_update_checker.py +16 -0
  102. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  103. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  104. dsgrid/registry/dimension_registry_manager.py +413 -0
  105. dsgrid/registry/dimension_update_checker.py +16 -0
  106. dsgrid/registry/duckdb_data_store.py +207 -0
  107. dsgrid/registry/filesystem_data_store.py +150 -0
  108. dsgrid/registry/filter_registry_manager.py +123 -0
  109. dsgrid/registry/project_config_generator.py +57 -0
  110. dsgrid/registry/project_registry_manager.py +1623 -0
  111. dsgrid/registry/project_update_checker.py +48 -0
  112. dsgrid/registry/registration_context.py +223 -0
  113. dsgrid/registry/registry_auto_updater.py +316 -0
  114. dsgrid/registry/registry_database.py +667 -0
  115. dsgrid/registry/registry_interface.py +446 -0
  116. dsgrid/registry/registry_manager.py +558 -0
  117. dsgrid/registry/registry_manager_base.py +367 -0
  118. dsgrid/registry/versioning.py +92 -0
  119. dsgrid/rust_ext/__init__.py +14 -0
  120. dsgrid/rust_ext/find_minimal_patterns.py +129 -0
  121. dsgrid/spark/__init__.py +0 -0
  122. dsgrid/spark/functions.py +589 -0
  123. dsgrid/spark/types.py +110 -0
  124. dsgrid/tests/__init__.py +0 -0
  125. dsgrid/tests/common.py +140 -0
  126. dsgrid/tests/make_us_data_registry.py +265 -0
  127. dsgrid/tests/register_derived_datasets.py +103 -0
  128. dsgrid/tests/utils.py +25 -0
  129. dsgrid/time/__init__.py +0 -0
  130. dsgrid/time/time_conversions.py +80 -0
  131. dsgrid/time/types.py +67 -0
  132. dsgrid/units/__init__.py +0 -0
  133. dsgrid/units/constants.py +113 -0
  134. dsgrid/units/convert.py +71 -0
  135. dsgrid/units/energy.py +145 -0
  136. dsgrid/units/power.py +87 -0
  137. dsgrid/utils/__init__.py +0 -0
  138. dsgrid/utils/dataset.py +830 -0
  139. dsgrid/utils/files.py +179 -0
  140. dsgrid/utils/filters.py +125 -0
  141. dsgrid/utils/id_remappings.py +100 -0
  142. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  143. dsgrid/utils/py_expression_eval/README.md +8 -0
  144. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  145. dsgrid/utils/py_expression_eval/tests.py +283 -0
  146. dsgrid/utils/run_command.py +70 -0
  147. dsgrid/utils/scratch_dir_context.py +65 -0
  148. dsgrid/utils/spark.py +918 -0
  149. dsgrid/utils/spark_partition.py +98 -0
  150. dsgrid/utils/timing.py +239 -0
  151. dsgrid/utils/utilities.py +221 -0
  152. dsgrid/utils/versioning.py +36 -0
  153. dsgrid_toolkit-0.3.3.dist-info/METADATA +193 -0
  154. dsgrid_toolkit-0.3.3.dist-info/RECORD +157 -0
  155. dsgrid_toolkit-0.3.3.dist-info/WHEEL +4 -0
  156. dsgrid_toolkit-0.3.3.dist-info/entry_points.txt +4 -0
  157. dsgrid_toolkit-0.3.3.dist-info/licenses/LICENSE +29 -0
dsgrid/query/query_submitter.py
@@ -0,0 +1,994 @@
1
+ import abc
2
+ import json
3
+ import logging
4
+ import shutil
5
+ from pathlib import Path
6
+ import copy
7
+ from zipfile import ZipFile
8
+
9
+ from chronify.utils.path_utils import check_overwrite
10
+ from semver import VersionInfo
11
+ from sqlalchemy import Connection
12
+
13
+ import dsgrid
14
+ from dsgrid.common import VALUE_COLUMN, BackendEngine
15
+ from dsgrid.config.dataset_config import DatasetConfig
16
+ from dsgrid.config.dimension_config import DimensionBaseConfig
17
+ from dsgrid.config.project_config import DatasetBaseDimensionNamesModel
18
+ from dsgrid.config.dimension_mapping_base import DimensionMappingReferenceModel
19
+ from dsgrid.config.time_dimension_base_config import TimeDimensionBaseConfig
20
+ from dsgrid.dataset.dataset_expression_handler import (
21
+ DatasetExpressionHandler,
22
+ evaluate_expression,
23
+ )
24
+ from dsgrid.utils.scratch_dir_context import ScratchDirContext
25
+ from dsgrid.dataset.models import ValueFormat, PivotedTableFormatModel
26
+ from dsgrid.dataset.dataset_schema_handler_base import DatasetSchemaHandlerBase
27
+ from dsgrid.dataset.table_format_handler_factory import make_table_format_handler
28
+ from dsgrid.dimension.base_models import DimensionCategory, DimensionType
29
+ from dsgrid.dimension.dimension_filters import SubsetDimensionFilterModel
30
+ from dsgrid.exceptions import DSGInvalidDataset, DSGInvalidParameter, DSGInvalidQuery
31
+ from dsgrid.dsgrid_rc import DsgridRuntimeConfig
32
+ from dsgrid.query.dataset_mapping_plan import MapOperationCheckpoint
33
+ from dsgrid.query.query_context import QueryContext
34
+ from dsgrid.query.report_factory import make_report
35
+ from dsgrid.registry.registry_manager import RegistryManager
36
+ from dsgrid.spark.functions import pivot
37
+ from dsgrid.spark.types import DataFrame
38
+ from dsgrid.project import Project
39
+ from dsgrid.utils.spark import (
40
+ custom_time_zone,
41
+ read_dataframe,
42
+ try_read_dataframe,
43
+ write_dataframe,
44
+ write_dataframe_and_auto_partition,
45
+ persist_table,
46
+ )
47
+ from dsgrid.utils.timing import timer_stats_collector, track_timing
48
+ from dsgrid.utils.files import delete_if_exists, compute_hash, load_data
49
+ from dsgrid.query.models import (
50
+ DatasetQueryModel,
51
+ ProjectQueryModel,
52
+ ColumnType,
53
+ CreateCompositeDatasetQueryModel,
54
+ CompositeDatasetQueryModel,
55
+ DatasetMetadataModel,
56
+ ProjectionDatasetModel,
57
+ StandaloneDatasetModel,
58
+ )
59
+ from dsgrid.utils.dataset import (
60
+ add_time_zone,
61
+ convert_time_zone_with_chronify_spark_hive,
62
+ convert_time_zone_with_chronify_spark_path,
63
+ convert_time_zone_with_chronify_duckdb,
64
+ convert_time_zone_by_column_with_chronify_spark_hive,
65
+ convert_time_zone_by_column_with_chronify_spark_path,
66
+ convert_time_zone_by_column_with_chronify_duckdb,
67
+ )
68
+ from dsgrid.config.dataset_schema_handler_factory import make_dataset_schema_handler
69
+ from dsgrid.config.date_time_dimension_config import DateTimeDimensionConfig
70
+ from dsgrid.exceptions import DSGInvalidOperation
71
+
72
+ logger = logging.getLogger(__name__)
73
+
74
+
75
+ class QuerySubmitterBase:
76
+ """Handles query submission"""
77
+
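+ # Output directory layout created below (see the helper methods that follow):
+ #   composite_datasets/               saved composite datasets
+ #   cached_tables/                    cached intermediate tables keyed by query hash
+ #   cached_project_mapped_datasets/   cached project-mapped datasets (possibly filtered)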
78
+ def __init__(self, output_dir: Path):
79
+ self._output_dir = output_dir
80
+ self._cached_tables_dir().mkdir(exist_ok=True, parents=True)
81
+ self._composite_datasets_dir().mkdir(exist_ok=True, parents=True)
82
+
83
+ # TODO #186: This location will need more consideration.
84
+ # We might want to store cached datasets in the spark-warehouse and let Spark manage it
85
+ # for us. However, would we share them on the HPC? What happens on HPC walltime timeouts
86
+ # where the tables are left in intermediate states?
87
+ # This is even more of a problem on AWS.
88
+ self._cached_project_mapped_datasets_dir().mkdir(exist_ok=True, parents=True)
89
+
90
+ @abc.abstractmethod
91
+ def submit(self, *args, **kwargs) -> DataFrame:
92
+ """Submit a query for execution"""
93
+
94
+ def _composite_datasets_dir(self):
95
+ return self._output_dir / "composite_datasets"
96
+
97
+ def _cached_tables_dir(self):
98
+ """Directory for intermediate tables made up of multiple project-mapped datasets."""
99
+ return self._output_dir / "cached_tables"
100
+
101
+ def _cached_project_mapped_datasets_dir(self):
102
+ """Directory for intermediate project-mapped datasets.
103
+ Data could be filtered.
104
+ """
105
+ return self._output_dir / "cached_project_mapped_datasets"
106
+
107
+ @staticmethod
108
+ def metadata_filename(path: Path):
109
+ return path / "metadata.json"
110
+
111
+ @staticmethod
112
+ def query_filename(path: Path):
113
+ return path / "query.json5"
114
+
115
+ @staticmethod
116
+ def table_filename(path: Path):
117
+ return path / "table.parquet"
118
+
119
+ @staticmethod
120
+ def _cached_table_filename(path: Path):
121
+ return path / "table.parquet"
122
+
123
+
124
+ class ProjectBasedQuerySubmitter(QuerySubmitterBase):
125
+ def __init__(self, project: Project, *args, **kwargs):
126
+ super().__init__(*args, **kwargs)
127
+ self._project = project
128
+
129
+ @property
130
+ def project(self):
131
+ return self._project
132
+
133
+ def _create_table_hash(self, context: QueryContext) -> tuple[str, str]:
134
+ """Create a hash that can be used to identify whether the following sequence
135
+ can be skipped based on a previous query:
136
+ - Apply expression across all datasets in the query.
137
+ - Apply filters.
138
+ - Apply aggregations.
139
+
140
+ Examples of changes that will invalidate the cached result:
141
+ - Change to the project section of the query
142
+ - Bump to project major version number
143
+ - Change to a dataset version
144
+ - Change to a project's dimension requirements for a dataset
145
+ - Change to a dataset dimension mapping
146
+ """
147
+ assert isinstance(context.model, ProjectQueryModel) or isinstance(
148
+ context.model, CreateCompositeDatasetQueryModel
149
+ )
150
+ data = {
151
+ "project_major_version": VersionInfo.parse(self._project.config.model.version).major,
152
+ "project_query": context.model.serialize_cached_content(),
153
+ "datasets": [
154
+ self._project.config.get_dataset(x.dataset_id).model_dump(mode="json")
155
+ for x in context.model.project.dataset.source_datasets
156
+ ],
157
+ }
158
+ text = json.dumps(data, indent=2)
159
+ hash_value = compute_hash(text.encode())
160
+ return text, hash_value
161
+
162
+ def _try_read_cache(self, context: QueryContext):
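+ # A cache entry lives in cached_tables/<hash>/ and holds table.parquet,
+ # metadata.json, and query.json5 (written by _persist_intermediate_result).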
163
+ _, hash_value = self._create_table_hash(context)
164
+ cached_dir = self._cached_tables_dir() / hash_value
165
+ filename = self._cached_table_filename(cached_dir)
166
+ df = try_read_dataframe(filename)
167
+ if df is not None:
168
+ logger.info("Load intermediate table from cache: %s", filename)
169
+ metadata_file = self.metadata_filename(cached_dir)
170
+ return df, DatasetMetadataModel.from_file(metadata_file)
171
+ return None, None
172
+
173
+ def _run_checks(self, model: ProjectQueryModel) -> DatasetBaseDimensionNamesModel:
174
+ subsets = set(self.project.config.list_dimension_names(DimensionCategory.SUBSET))
175
+ for agg in model.result.aggregations:
176
+ for _, column in agg.iter_dimensions_to_keep():
177
+ dimension_name = column.dimension_name
178
+ if dimension_name in subsets:
179
+ subset_dim = self._project.config.get_dimension(dimension_name)
180
+ dim_type = subset_dim.model.dimension_type
181
+ supp_names = " ".join(
182
+ self._project.config.get_supplemental_dimension_to_name_mapping()[dim_type]
183
+ )
184
+ base_names = [
185
+ x.model.name
186
+ for x in self._project.config.list_base_dimensions(dimension_type=dim_type)
187
+ ]
188
+ msg = (
189
+ f"Subset dimensions cannot be used in aggregations: "
190
+ f"{dimension_name=}. Only base and supplemental dimensions are "
191
+ f"allowed. base={base_names} supplemental={supp_names}"
192
+ )
193
+ raise DSGInvalidQuery(msg)
194
+
195
+ for report_inputs in model.result.reports:
196
+ report = make_report(report_inputs.report_type)
197
+ report.check_query(model)
198
+
199
+ with self._project.dimension_mapping_manager.db.engine.connect() as conn:
200
+ return self._check_datasets(model, conn)
201
+
202
+ def _check_datasets(
203
+ self, query_model: ProjectQueryModel, conn: Connection
204
+ ) -> DatasetBaseDimensionNamesModel:
205
+ base_dimension_names: DatasetBaseDimensionNamesModel | None = None
206
+ dataset_ids: list[str] = []
207
+ query_names: list[DatasetBaseDimensionNamesModel] = []
208
+ for dataset in query_model.project.dataset.source_datasets:
209
+ src_dataset_ids = dataset.list_source_dataset_ids()
210
+ dataset_ids += src_dataset_ids
211
+ if isinstance(dataset, StandaloneDatasetModel):
212
+ query_names.append(
213
+ self._project.config.get_dataset_base_dimension_names(dataset.dataset_id)
214
+ )
215
+ elif isinstance(dataset, ProjectionDatasetModel):
216
+ query_names += [
217
+ self._project.config.get_dataset_base_dimension_names(
218
+ dataset.initial_value_dataset_id
219
+ ),
220
+ self._project.config.get_dataset_base_dimension_names(
221
+ dataset.growth_rate_dataset_id
222
+ ),
223
+ ]
224
+ else:
225
+ msg = f"Unhandled dataset type: {dataset=}"
226
+ raise NotImplementedError(msg)
227
+
228
+ for dataset_id in src_dataset_ids:
229
+ dataset = self._project.load_dataset(dataset_id, conn=conn)
230
+ plan = query_model.project.get_dataset_mapping_plan(dataset_id)
231
+ if plan is None:
232
+ plan = dataset.handler.build_default_dataset_mapping_plan()
233
+ query_model.project.set_dataset_mapper(plan)
234
+ else:
235
+ dataset.handler.check_dataset_mapping_plan(plan, self._project.config)
236
+
237
+ for dataset_id, names in zip(dataset_ids, query_names):
238
+ self._fix_legacy_base_dimension_names(names, dataset_id)
239
+ if base_dimension_names is None:
240
+ base_dimension_names = names
241
+ elif base_dimension_names != names:
242
+ msg = (
243
+ "Datasets in a query must have the same base dimension query names: "
244
+ f"{dataset=} {base_dimension_names} {names}"
245
+ )
246
+ raise DSGInvalidQuery(msg)
247
+
248
+ assert base_dimension_names is not None
249
+ return base_dimension_names
250
+
251
+ def _fix_legacy_base_dimension_names(
252
+ self, names: DatasetBaseDimensionNamesModel, dataset_id: str
253
+ ) -> None:
254
+ for dim_type in DimensionType:
255
+ val = getattr(names, dim_type.value)
256
+ if val is None:
257
+ # This is a workaround for dsgrid projects created before the field
258
+ # base_dimension_names was added to InputDatasetModel.
259
+ dims = self._project.config.list_base_dimensions(dimension_type=dim_type)
260
+ if len(dims) > 1:
261
+ msg = (
262
+ "The dataset's base_dimension_names value is not set and "
263
+ f"there are multiple base dimensions of type {dim_type} in the project. "
264
+ f"Please re-register the dataset with {dataset_id=}."
265
+ )
266
+ raise DSGInvalidDataset(msg)
267
+ setattr(names, dim_type.value, dims[0].model.name)
268
+
269
+ def _run_query(
270
+ self,
271
+ scratch_dir_context: ScratchDirContext,
272
+ model: ProjectQueryModel,
273
+ load_cached_table: bool,
274
+ checkpoint_file: Path | None,
275
+ persist_intermediate_table: bool,
276
+ zip_file: bool = False,
277
+ overwrite: bool = False,
278
+ ):
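+ # Overall flow: validate the query, build a QueryContext, reuse a cached
+ # intermediate table when possible, otherwise map and combine the source
+ # datasets, then apply filters and aggregations and write the result table
+ # plus any requested reports.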
279
+ base_dimension_names = self._run_checks(model)
280
+ checkpoint = self._check_checkpoint_file(checkpoint_file, model)
281
+ context = QueryContext(
282
+ model,
283
+ base_dimension_names,
284
+ scratch_dir_context=scratch_dir_context,
285
+ checkpoint=checkpoint,
286
+ )
287
+ assert isinstance(context.model, ProjectQueryModel) or isinstance(
288
+ context.model, CreateCompositeDatasetQueryModel
289
+ )
290
+ context.model.project.version = str(self._project.version)
291
+ output_dir = self._output_dir / context.model.name
292
+ if output_dir.exists() and not overwrite:
293
+ msg = (
294
+ f"output directory {self._output_dir} and query name={context.model.name} will "
295
+ "overwrite an existing query results directory. "
296
+ "Choose a different path or pass force=True."
297
+ )
298
+ raise DSGInvalidParameter(msg)
299
+
300
+ df = None
301
+ if load_cached_table:
302
+ df, metadata = self._try_read_cache(context)
303
+ if df is None:
304
+ df_filenames = self._project.process_query(
305
+ context, self._cached_project_mapped_datasets_dir()
306
+ )
307
+ df = self._postprocess_datasets(context, scratch_dir_context, df_filenames)
308
+ is_cached = False
309
+ else:
310
+ context.metadata = metadata
311
+ is_cached = True
312
+
313
+ if context.model.result.aggregate_each_dataset:
314
+ # This wouldn't save any time.
315
+ persist_intermediate_table = False
316
+
317
+ if persist_intermediate_table and not is_cached:
318
+ df = self._persist_intermediate_result(context, df)
319
+
320
+ if not context.model.result.aggregate_each_dataset:
321
+ if context.model.result.dimension_filters:
322
+ df = self._apply_filters(df, context)
323
+ df = self._process_aggregations(df, context)
324
+
325
+ repartition = not persist_intermediate_table
326
+ table_filename = self._save_query_results(context, df, repartition, zip_file=zip_file)
327
+
328
+ for report_inputs in context.model.result.reports:
329
+ report = make_report(report_inputs.report_type)
330
+ output_dir = self._output_dir / context.model.name
331
+ report.generate(table_filename, output_dir, context, report_inputs.inputs)
332
+
333
+ return df, context
334
+
335
+ def _convert_time_zone(
336
+ self,
337
+ scratch_dir_context: ScratchDirContext,
338
+ model: ProjectQueryModel,
339
+ df,
340
+ context,
341
+ persist_intermediate_table: bool,
342
+ zip_file: bool = False,
343
+ ):
344
+ time_dim = copy.deepcopy(self._project.config.get_base_time_dimension())
345
+ if not isinstance(time_dim, DateTimeDimensionConfig):
346
+ msg = f"Only DateTimeDimensionConfig allowed for time zone conversion. {time_dim.__class__.__name__}"
347
+ raise DSGInvalidOperation(msg)
348
+ time_cols = list(context.get_dimension_column_names(DimensionType.TIME))
349
+ assert len(time_cols) == 1
350
+ time_col = next(iter(time_cols))
351
+ time_dim.model.time_column = time_col
352
+
353
+ config = dsgrid.runtime_config
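+ # The conversion path depends on the runtime backend: Spark with a Hive
+ # metastore, plain Spark (which persists the table to a path first), or
+ # DuckDB. All three paths delegate to chronify.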
354
+ if isinstance(model.result.time_zone, str) and model.result.time_zone != "geography":
355
+ if time_dim.supports_chronify():
356
+ match (config.backend_engine, config.use_hive_metastore):
357
+ case (BackendEngine.SPARK, True):
358
+ df = convert_time_zone_with_chronify_spark_hive(
359
+ df=df,
360
+ value_column=VALUE_COLUMN,
361
+ from_time_dim=time_dim,
362
+ time_zone=model.result.time_zone,
363
+ scratch_dir_context=scratch_dir_context,
364
+ )
365
+
366
+ case (BackendEngine.SPARK, False):
367
+ filename = persist_table(
368
+ df,
369
+ scratch_dir_context,
370
+ tag="project query before time mapping",
371
+ )
372
+ df = convert_time_zone_with_chronify_spark_path(
373
+ df=df,
374
+ filename=filename,
375
+ value_column=VALUE_COLUMN,
376
+ from_time_dim=time_dim,
377
+ time_zone=model.result.time_zone,
378
+ scratch_dir_context=scratch_dir_context,
379
+ )
380
+ case (BackendEngine.DUCKDB, _):
381
+ df = convert_time_zone_with_chronify_duckdb(
382
+ df=df,
383
+ value_column=VALUE_COLUMN,
384
+ from_time_dim=time_dim,
385
+ time_zone=model.result.time_zone,
386
+ scratch_dir_context=scratch_dir_context,
387
+ )
388
+
389
+ else:
390
+ msg = "time_dim must support Chronify"
391
+ raise DSGInvalidParameter(msg)
392
+
393
+ elif model.result.time_zone == "geography":
394
+ if "time_zone" not in df.columns:
395
+ geo_cols = list(context.get_dimension_column_names(DimensionType.GEOGRAPHY))
396
+ assert len(geo_cols) == 1
397
+ geo_col = next(iter(geo_cols))
398
+ geo_dim = self._project.config.get_base_dimension(DimensionType.GEOGRAPHY)
399
+ if model.result.replace_ids_with_names:
400
+ dim_key = "name"
401
+ else:
402
+ dim_key = "id"
403
+ df = add_time_zone(df, geo_dim, df_key=geo_col, dim_key=dim_key)
404
+
405
+ if time_dim.supports_chronify():
406
+ # use chronify
407
+ match (config.backend_engine, config.use_hive_metastore):
408
+ case (BackendEngine.SPARK, True):
409
+ df = convert_time_zone_by_column_with_chronify_spark_hive(
410
+ df=df,
411
+ value_column=VALUE_COLUMN,
412
+ from_time_dim=time_dim,
413
+ time_zone_column="time_zone",
414
+ scratch_dir_context=scratch_dir_context,
415
+ wrap_time_allowed=False,
416
+ )
417
+ case (BackendEngine.SPARK, False):
418
+ filename = persist_table(
419
+ df,
420
+ scratch_dir_context,
421
+ tag="project query before time mapping",
422
+ )
423
+ df = convert_time_zone_by_column_with_chronify_spark_path(
424
+ df=df,
425
+ filename=filename,
426
+ value_column=VALUE_COLUMN,
427
+ from_time_dim=time_dim,
428
+ time_zone_column="time_zone",
429
+ scratch_dir_context=scratch_dir_context,
430
+ wrap_time_allowed=False,
431
+ )
432
+ case (BackendEngine.DUCKDB, _):
433
+ df = convert_time_zone_by_column_with_chronify_duckdb(
434
+ df=df,
435
+ value_column=VALUE_COLUMN,
436
+ from_time_dim=time_dim,
437
+ time_zone_column="time_zone",
438
+ scratch_dir_context=scratch_dir_context,
439
+ wrap_time_allowed=False,
440
+ )
441
+
442
+ else:
443
+ msg = "time_dim must support Chronify"
444
+ raise DSGInvalidParameter(msg)
445
+ else:
446
+ msg = f"Unknown input {model.result.time_zone=}"
447
+ raise DSGInvalidParameter(msg)
448
+
449
+ repartition = not persist_intermediate_table
450
+ table_filename = self._save_query_results(context, df, repartition, zip_file=zip_file)
451
+
452
+ for report_inputs in context.model.result.reports:
453
+ report = make_report(report_inputs.report_type)
454
+ output_dir = self._output_dir / context.model.name
455
+ report.generate(table_filename, output_dir, context, report_inputs.inputs)
456
+
457
+ return df, context
458
+
459
+ def _check_checkpoint_file(
460
+ self, checkpoint_file: Path | None, model: ProjectQueryModel
461
+ ) -> MapOperationCheckpoint | None:
462
+ if checkpoint_file is None:
463
+ return None
464
+
465
+ checkpoint = MapOperationCheckpoint.from_file(checkpoint_file)
466
+ confirmed_checkpoint = False
467
+ for dataset in model.project.dataset.source_datasets:
468
+ if dataset.get_dataset_id() == checkpoint.dataset_id:
469
+ for plan in model.project.mapping_plans:
470
+ if plan.dataset_id == checkpoint.dataset_id:
471
+ if plan.compute_hash() == checkpoint.mapping_plan_hash:
472
+ confirmed_checkpoint = True
473
+ else:
474
+ msg = (
475
+ f"The hash of the mapping plan for dataset {checkpoint.dataset_id} "
476
+ "does not match the checkpoint file. Cannot use the checkpoint."
477
+ )
478
+ raise DSGInvalidParameter(msg)
479
+ if not confirmed_checkpoint:
480
+ msg = f"Checkpoint {checkpoint_file} does not match any dataset in the query."
481
+ raise DSGInvalidParameter(msg)
482
+
483
+ return checkpoint
484
+
485
+ @track_timing(timer_stats_collector)
486
+ def _persist_intermediate_result(self, context: QueryContext, df):
487
+ text, hash_value = self._create_table_hash(context)
488
+ cached_dir = self._cached_tables_dir() / hash_value
489
+ if cached_dir.exists():
490
+ shutil.rmtree(cached_dir)
491
+ cached_dir.mkdir()
492
+ filename = self._cached_table_filename(cached_dir)
493
+ df = write_dataframe_and_auto_partition(df, filename)
494
+
495
+ self.metadata_filename(cached_dir).write_text(
496
+ context.metadata.model_dump_json(indent=2), encoding="utf-8"
497
+ )
498
+ self.query_filename(cached_dir).write_text(text, encoding="utf-8")
499
+ logger.debug("Persisted intermediate table to %s", filename)
500
+ return df
501
+
502
+ def _postprocess_datasets(
503
+ self,
504
+ context: QueryContext,
505
+ scratch_dir_context: ScratchDirContext,
506
+ df_filenames: dict[str, Path],
507
+ ) -> DataFrame:
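+ # When aggregate_each_dataset is set, filters and aggregations are applied to
+ # each dataset individually before the datasets are combined with the query's
+ # dataset expression.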
508
+ if context.model.result.aggregate_each_dataset:
509
+ for dataset_id, path in df_filenames.items():
510
+ df = read_dataframe(path)
511
+ if context.model.result.dimension_filters:
512
+ df = self._apply_filters(df, context)
513
+ df = self._process_aggregations(df, context, dataset_id=dataset_id)
514
+ path = scratch_dir_context.get_temp_filename(suffix=".parquet")
515
+ write_dataframe(df, path)
516
+ df_filenames[dataset_id] = path
517
+
518
+ # All dataset columns need to be in the same order.
519
+ context.consolidate_dataset_metadata()
520
+ datasets = self._convert_datasets(context, df_filenames)
521
+ assert isinstance(context.model, ProjectQueryModel) or isinstance(
522
+ context.model, CreateCompositeDatasetQueryModel
523
+ )
524
+ assert context.model.project.dataset.expression is not None
525
+ return evaluate_expression(context.model.project.dataset.expression, datasets).df
526
+
527
+ def _convert_datasets(self, context: QueryContext, filenames: dict[str, Path]):
528
+ dim_columns, time_columns = self._get_dimension_columns(context)
529
+ expected_columns = time_columns + dim_columns
530
+ expected_columns.append(VALUE_COLUMN)
531
+
532
+ datasets = {}
533
+ for dataset_id, path in filenames.items():
534
+ df = read_dataframe(path)
535
+ unexpected = sorted(set(df.columns).difference(expected_columns))
536
+ if unexpected:
537
+ msg = f"Unexpected columns are present in {dataset_id=} {unexpected=}"
538
+ raise DSGInvalidDataset(msg)
539
+ datasets[dataset_id] = DatasetExpressionHandler(
540
+ df.select(*expected_columns), time_columns + dim_columns, [VALUE_COLUMN]
541
+ )
542
+ return datasets
543
+
544
+ def _get_dimension_columns(self, context: QueryContext) -> tuple[list[str], list[str]]:
545
+ match context.model.result.column_type:
546
+ case ColumnType.DIMENSION_NAMES:
547
+ dim_columns = context.get_all_dimension_column_names(exclude={DimensionType.TIME})
548
+ time_columns = context.get_dimension_column_names(DimensionType.TIME)
549
+ case ColumnType.DIMENSION_TYPES:
550
+ dim_columns = {x.value for x in DimensionType if x != DimensionType.TIME}
551
+ time_columns = context.get_dimension_column_names(DimensionType.TIME)
552
+ case _:
553
+ msg = f"BUG: unhandled {context.model.result.column_type=}"
554
+ raise NotImplementedError(msg)
555
+
556
+ return sorted(dim_columns), sorted(time_columns)
557
+
558
+ def _process_aggregations(
559
+ self, df: DataFrame, context: QueryContext, dataset_id: str | None = None
560
+ ) -> DataFrame:
561
+ handler = make_table_format_handler(
562
+ ValueFormat.STACKED, self._project.config, dataset_id=dataset_id
563
+ )
564
+ df = handler.process_aggregations(df, context.model.result.aggregations, context)
565
+
566
+ if context.model.result.replace_ids_with_names:
567
+ df = handler.replace_ids_with_names(df)
568
+
569
+ if context.model.result.sort_columns:
570
+ df = df.sort(*context.model.result.sort_columns)
571
+
572
+ if isinstance(context.model.result.table_format, PivotedTableFormatModel):
573
+ df = _pivot_table(df, context)
574
+
575
+ return df
576
+
577
+ def _process_aggregations_and_save(
578
+ self,
579
+ df: DataFrame,
580
+ context: QueryContext,
581
+ repartition: bool,
582
+ zip_file: bool = False,
583
+ ) -> DataFrame:
584
+ df = self._process_aggregations(df, context)
585
+
586
+ self._save_query_results(context, df, repartition, zip_file=zip_file)
587
+ return df
588
+
589
+ def _apply_filters(self, df, context: QueryContext):
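+ # Subset dimension filters are applied by joining against the subset's record
+ # IDs; other filters are applied directly to the dimension's column.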
590
+ for dim_filter in context.model.result.dimension_filters:
591
+ column_names = context.get_dimension_column_names(dim_filter.dimension_type)
592
+ if len(column_names) > 1:
593
+ msg = f"Cannot filter {dim_filter} when there are multiple {column_names=}"
594
+ raise NotImplementedError(msg)
595
+ if isinstance(dim_filter, SubsetDimensionFilterModel):
596
+ records = dim_filter.get_filtered_records_dataframe(
597
+ self.project.config.get_dimension
598
+ )
599
+ column = next(iter(column_names))
600
+ df = df.join(
601
+ records.select("id"),
602
+ on=getattr(df, column) == getattr(records, "id"),
603
+ ).drop("id")
604
+ else:
605
+ query_name = dim_filter.dimension_name
606
+ if query_name not in df.columns:
607
+ # Consider catching this exception and still write to a file.
608
+ # It could mean writing a lot of data the user doesn't want.
609
+ msg = f"filter column {query_name} is not in the dataframe: {df.columns}"
610
+ raise DSGInvalidParameter(msg)
611
+ df = dim_filter.apply_filter(df, column=query_name)
612
+ return df
613
+
614
+ @track_timing(timer_stats_collector)
615
+ def _save_query_results(
616
+ self,
617
+ context: QueryContext,
618
+ df,
619
+ repartition,
620
+ aggregation_name=None,
621
+ zip_file=False,
622
+ ):
623
+ output_dir = self._output_dir / context.model.name
624
+ output_dir.mkdir(exist_ok=True)
625
+ if aggregation_name is not None:
626
+ output_dir /= aggregation_name
627
+ output_dir.mkdir(exist_ok=True)
628
+ filename = output_dir / f"table.{context.model.result.output_format}"
629
+ self._save_result(context, df, filename, repartition)
630
+ if zip_file:
631
+ zip_name = Path(str(output_dir) + ".zip")
632
+ with ZipFile(zip_name, "w") as zipf:
633
+ for path in output_dir.rglob("*"):
634
+ zipf.write(path)
635
+ return filename
636
+
637
+ def _save_result(self, context: QueryContext, df, filename, repartition):
638
+ output_dir = filename.parent
639
+ suffix = filename.suffix
640
+ if suffix == ".csv":
641
+ df.toPandas().to_csv(filename, header=True, index=False)
642
+ elif suffix == ".parquet":
643
+ if repartition:
644
+ df = write_dataframe_and_auto_partition(df, filename)
645
+ else:
646
+ delete_if_exists(filename)
647
+ write_dataframe(df, filename, overwrite=True)
648
+ else:
649
+ msg = f"Unsupported output_format={suffix}"
650
+ raise NotImplementedError(msg)
651
+ self.query_filename(output_dir).write_text(context.model.serialize_with_hash()[1])
652
+ self.metadata_filename(output_dir).write_text(context.metadata.model_dump_json(indent=2))
653
+ logger.info("Wrote query=%s output table to %s", context.model.name, filename)
654
+
655
+
656
+ class ProjectQuerySubmitter(ProjectBasedQuerySubmitter):
657
+ """Submits queries for a project."""
658
+
659
+ @track_timing(timer_stats_collector)
660
+ def submit(
661
+ self,
662
+ model: ProjectQueryModel,
663
+ scratch_dir: Path | None = None,
664
+ checkpoint_file: Path | None = None,
665
+ persist_intermediate_table: bool = True,
666
+ load_cached_table: bool = True,
667
+ zip_file: bool = False,
668
+ overwrite: bool = False,
669
+ ) -> DataFrame:
670
+ """Submits a project query to consolidate datasets and produce result tables.
671
+
672
+ Parameters
673
+ ----------
674
+ model : ProjectQueryModel
+ scratch_dir : Path, optional
+ Directory for temporary files. Defaults to the dsgrid runtime configuration's
+ scratch directory.
+ checkpoint_file : Path, optional
+ Optional checkpoint file from which to resume the operation.
677
+ persist_intermediate_table : bool, optional
678
+ Persist the intermediate consolidated table.
679
+ load_cached_table : bool, optional
680
+ Load a cached consolidated table if the query matches an existing query.
681
+ zip_file : bool, optional
682
+ Create a zip file with all output files.
683
+ overwrite : bool
684
+ If True, overwrite any existing output directory.
685
+
686
+ Returns
687
+ -------
688
+ pyspark.sql.DataFrame
689
+
690
+ Raises
691
+ ------
692
+ DSGInvalidParameter
693
+ Raised if the model defines a project version or if the output directory
+ already exists and overwrite is False.
694
+ DSGInvalidQuery
695
+ Raised if the query is invalid
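+
+ Examples
+ --------
+ A minimal usage sketch; ``project`` is an already-loaded dsgrid Project,
+ ``query`` is an already-constructed ProjectQueryModel, and the output
+ directory name is illustrative:
+
+ >>> submitter = ProjectQuerySubmitter(project, Path("query_output"))
+ >>> df = submitter.submit(query, overwrite=True)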
696
+ """
697
+ tz = self._project.config.get_base_time_dimension().get_time_zone()
698
+ assert tz is not None, "Project base time dimension must have a time zone"
699
+
700
+ scratch_dir = scratch_dir or DsgridRuntimeConfig.load().get_scratch_dir()
701
+ with ScratchDirContext(scratch_dir) as scratch_dir_context:
702
+ # Ensure that queries that aggregate time reflect the project's time zone instead
703
+ # of the local computer.
704
+ # If any other settings get customized here, handle them in restart_spark()
705
+ # as well. This change won't persist Spark session restarts.
706
+ with custom_time_zone(tz):
707
+ df, context = self._run_query(
708
+ scratch_dir_context,
709
+ model,
710
+ load_cached_table,
711
+ checkpoint_file=checkpoint_file,
712
+ persist_intermediate_table=persist_intermediate_table,
713
+ zip_file=zip_file,
714
+ overwrite=overwrite,
715
+ )
716
+ if model.result.time_zone:
717
+ df, context = self._convert_time_zone(
718
+ scratch_dir_context,
719
+ model,
720
+ df,
721
+ context,
722
+ persist_intermediate_table=persist_intermediate_table,
723
+ zip_file=zip_file,
724
+ )
725
+ context.finalize()
726
+
727
+ return df
728
+
729
+
730
+ class CompositeDatasetQuerySubmitter(ProjectBasedQuerySubmitter):
731
+ """Submits queries for a composite dataset."""
732
+
733
+ @track_timing(timer_stats_collector)
734
+ def create_dataset(
735
+ self,
736
+ model: CreateCompositeDatasetQueryModel,
737
+ scratch_dir: Path | None = None,
738
+ persist_intermediate_table=False,
739
+ load_cached_table=True,
740
+ force=False,
741
+ ):
742
+ """Create a composite dataset from a project.
743
+
744
+ Parameters
745
+ ----------
746
+ model : CreateCompositeDatasetQueryModel
+ scratch_dir : Path, optional
+ Directory for temporary files. Defaults to the dsgrid runtime configuration's
+ scratch directory.
747
+ persist_intermediate_table : bool, optional
748
+ Persist the intermediate consolidated table.
749
+ load_cached_table : bool, optional
750
+ Load a cached consolidated table if the query matches an existing query.
751
+ force : bool
752
+ If True, overwrite any existing output directory.
753
+
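+ Examples
+ --------
+ A minimal usage sketch; ``project`` is an already-loaded dsgrid Project and
+ ``model`` is an already-constructed CreateCompositeDatasetQueryModel. The
+ resulting composite dataset can later be queried with
+ CompositeDatasetQuerySubmitter.submit.
+
+ >>> submitter = CompositeDatasetQuerySubmitter(project, Path("composite_output"))
+ >>> submitter.create_dataset(model, force=True)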
754
+ """
755
+ tz = self._project.config.get_base_time_dimension().get_time_zone()
756
+ scratch_dir = scratch_dir or DsgridRuntimeConfig.load().get_scratch_dir()
757
+ with ScratchDirContext(scratch_dir) as scratch_dir_context:
758
+ # Ensure that queries that aggregate time reflect the project's time zone instead
759
+ # of the local computer.
760
+ # If any other settings get customized here, handle them in restart_spark()
761
+ # as well. This change won't persist Spark session restarts.
762
+ with custom_time_zone(tz): # type: ignore
763
+ df, context = self._run_query(
764
+ scratch_dir_context,
765
+ model,
766
+ load_cached_table,
767
+ None,
768
+ persist_intermediate_table,
769
+ overwrite=force,
770
+ )
771
+ self._save_composite_dataset(context, df, not persist_intermediate_table)
772
+ context.finalize()
773
+
774
+ @track_timing(timer_stats_collector)
775
+ def submit(
776
+ self,
777
+ query: CompositeDatasetQueryModel,
778
+ scratch_dir: Path | None = None,
779
+ ) -> DataFrame:
780
+ """Submit a query to an composite dataset and produce result tables.
781
+
782
+ Parameters
783
+ ----------
784
+ query : CompositeDatasetQueryModel
785
+ scratch_dir : Path | None
786
+ """
787
+ tz = self._project.config.get_base_time_dimension().get_time_zone()
788
+ assert tz is not None
789
+ scratch_dir = scratch_dir or DsgridRuntimeConfig.load().get_scratch_dir()
790
+ # orig_query = self._load_composite_dataset_query(query.dataset_id)
791
+ with ScratchDirContext(scratch_dir) as scratch_dir_context:
792
+ df, metadata = self._read_dataset(query.dataset_id)
793
+ base_dimension_names = DatasetBaseDimensionNamesModel()
794
+ for dim_type in DimensionType:
795
+ field = dim_type.value
796
+ query_names = getattr(metadata.dimensions, field)
797
+ if len(query_names) > 1:
798
+ msg = (
799
+ "Composite datasets must have a single query name for each dimension: "
800
+ f"{dim_type} {query_names}"
801
+ )
802
+ raise DSGInvalidQuery(msg)
803
+ setattr(base_dimension_names, field, query_names[0].dimension_name)
804
+ context = QueryContext(query, base_dimension_names, scratch_dir_context)
805
+ context.metadata = metadata
806
+ # Refer to the comment in ProjectQuerySubmitter.submit for an explanation or if
807
+ # you add a new customization.
808
+ with custom_time_zone(tz): # type: ignore
809
+ df = self._process_aggregations_and_save(df, context, repartition=False)
810
+ context.finalize()
811
+ return df
812
+
813
+ def _load_composite_dataset_query(self, dataset_id):
814
+ filename = self._composite_datasets_dir() / dataset_id / "query.json5"
815
+ return CreateCompositeDatasetQueryModel.from_file(filename)
816
+
817
+ def _read_dataset(self, dataset_id) -> tuple[DataFrame, DatasetMetadataModel]:
818
+ filename = self._composite_datasets_dir() / dataset_id / "table.parquet"
819
+ if not filename.exists():
820
+ msg = f"There is no composite dataset with dataset_id={dataset_id}"
821
+ raise DSGInvalidParameter(msg)
822
+ metadata_file = self.metadata_filename(self._composite_datasets_dir() / dataset_id)
823
+ return (
824
+ read_dataframe(filename),
825
+ DatasetMetadataModel(**load_data(metadata_file)),
826
+ )
827
+
828
+ @track_timing(timer_stats_collector)
829
+ def _save_composite_dataset(self, context: QueryContext, df, repartition):
830
+ output_dir = self._composite_datasets_dir() / context.model.dataset_id
831
+ output_dir.mkdir(exist_ok=True)
832
+ filename = output_dir / "table.parquet"
833
+ self._save_result(context, df, filename, repartition)
834
+ self.metadata_filename(output_dir).write_text(context.metadata.model_dump_json(indent=2))
835
+
836
+
837
+ class DatasetQuerySubmitter(QuerySubmitterBase):
838
+ """Submits queries for a project."""
839
+
840
+ @track_timing(timer_stats_collector)
841
+ def submit(
842
+ self,
843
+ query: DatasetQueryModel,
844
+ mgr: RegistryManager,
845
+ scratch_dir: Path | None = None,
846
+ checkpoint_file: Path | None = None,
847
+ overwrite: bool = False,
848
+ ) -> DataFrame:
849
+ """Submits a dataset query to produce a result table."""
850
+ if not query.to_dimension_references:
851
+ msg = "A dataset query must specify at least one dimension to map."
852
+ raise DSGInvalidQuery(msg)
853
+
854
+ dataset_config = mgr.dataset_manager.get_by_id(query.dataset_id)
855
+ to_dimension_mapping_refs, dims = self._build_mappings(query, dataset_config, mgr)
856
+ handler = make_dataset_schema_handler(
857
+ conn=None,
858
+ config=dataset_config,
859
+ dimension_mgr=mgr.dimension_manager,
860
+ dimension_mapping_mgr=mgr.dimension_mapping_manager,
861
+ store=mgr.dataset_manager.store,
862
+ mapping_references=to_dimension_mapping_refs,
863
+ )
864
+
865
+ base_dim_names = DatasetBaseDimensionNamesModel()
866
+ scratch_dir = scratch_dir or DsgridRuntimeConfig.load().get_scratch_dir()
867
+ time_dim = dims.get(DimensionType.TIME) or dataset_config.get_time_dimension()
868
+ time_zone = None if time_dim is None else time_dim.get_time_zone()
869
+ checkpoint = self._check_checkpoint_file(checkpoint_file, query)
870
+ with ScratchDirContext(scratch_dir) as scratch_dir_context:
871
+ context = QueryContext(
872
+ query, base_dim_names, scratch_dir_context, checkpoint=checkpoint
873
+ )
874
+ output_dir = self._query_output_dir(context)
875
+ check_overwrite(output_dir, overwrite)
876
+ args = (context, handler)
877
+ kwargs = {"time_dimension": dims.get(DimensionType.TIME)}
878
+ if time_dim is not None and time_zone is not None:
879
+ with custom_time_zone(time_zone):
880
+ df = self._run_query(*args, **kwargs)
881
+ else:
882
+ df = self._run_query(*args, **kwargs)
883
+ return df
884
+
885
+ def _build_mappings(
886
+ self, query: DatasetQueryModel, config: DatasetConfig, mgr: RegistryManager
887
+ ) -> tuple[list[DimensionMappingReferenceModel], dict[DimensionType, DimensionBaseConfig]]:
888
+ config = mgr.dataset_manager.get_by_id(query.dataset_id)
889
+ to_dimension_mapping_refs: list[DimensionMappingReferenceModel] = []
890
+ mapped_dimension_types: set[DimensionType] = set()
891
+ dimensions: dict[DimensionType, DimensionBaseConfig] = {}
892
+ with mgr.dimension_mapping_manager.db.engine.connect() as conn:
893
+ graph = mgr.dimension_mapping_manager.build_graph(conn=conn)
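+ # For each requested "to" dimension, find registered mappings from the
+ # dataset's dimension of that type to the target using the mapping graph.
+ # Time dimensions are skipped here; they are passed to the schema handler
+ # separately.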
894
+ for to_dim_ref in query.to_dimension_references:
895
+ to_dim = mgr.dimension_manager.get_by_id(
896
+ to_dim_ref.dimension_id, version=to_dim_ref.version
897
+ )
898
+ if to_dim.model.dimension_type in mapped_dimension_types:
899
+ msg = f"A dataset query cannot map multiple dimensions of the same type: {to_dim.model.dimension_type}"
900
+ raise DSGInvalidQuery(msg)
901
+ dataset_dim = config.get_dimension(to_dim.model.dimension_type)
902
+ assert dataset_dim is not None
903
+ if to_dim.model.dimension_id == dataset_dim.model.dimension_id:
904
+ if to_dim.model.version != dataset_dim.model.version:
905
+ msg = (
906
+ f"A to_dimension_reference cannot point to a different version of a "
907
+ f"dataset's dimension dimension: dataset version = {dataset_dim.model.version}, "
908
+ f"dimension version = {to_dim.model.version}"
909
+ )
910
+ raise DSGInvalidQuery(msg)
911
+ # No mapping is required.
912
+ continue
913
+ if to_dim.model.dimension_type != DimensionType.TIME:
914
+ refs = mgr.dimension_mapping_manager.list_mappings_between_dimensions(
915
+ graph,
916
+ dataset_dim.model.dimension_id,
917
+ to_dim.model.dimension_id,
918
+ )
919
+ to_dimension_mapping_refs += refs
920
+ mapped_dimension_types.add(to_dim.model.dimension_type)
921
+ dimensions[to_dim.model.dimension_type] = to_dim
922
+ return to_dimension_mapping_refs, dimensions
923
+
924
+ def _check_checkpoint_file(
925
+ self, checkpoint_file: Path | None, query: DatasetQueryModel
926
+ ) -> MapOperationCheckpoint | None:
927
+ if checkpoint_file is None:
928
+ return None
929
+
930
+ if query.mapping_plan is None:
931
+ msg = f"Query {query.name} does not have a mapping plan. A checkpoint file cannot be used."
932
+ raise DSGInvalidQuery(msg)
933
+
934
+ checkpoint = MapOperationCheckpoint.from_file(checkpoint_file)
935
+ if query.dataset_id != checkpoint.dataset_id:
936
+ msg = (
937
+ f"The dataset_id in the checkpoint file {checkpoint.dataset_id} does not match "
938
+ f"the query dataset_id {query.dataset_id}."
939
+ )
940
+ raise DSGInvalidQuery(msg)
941
+
942
+ if query.mapping_plan.compute_hash() != checkpoint.mapping_plan_hash:
943
+ msg = (
944
+ f"The hash of the mapping plan for dataset {checkpoint.dataset_id} "
945
+ "does not match the checkpoint file. Cannot use the checkpoint."
946
+ )
947
+ raise DSGInvalidParameter(msg)
948
+
949
+ return checkpoint
950
+
951
+ def _run_query(
952
+ self,
953
+ context: QueryContext,
954
+ handler: DatasetSchemaHandlerBase,
955
+ time_dimension: TimeDimensionBaseConfig | None,
956
+ ) -> DataFrame:
957
+ df = handler.make_mapped_dataframe(context, time_dimension=time_dimension)
958
+ df = self._postprocess(context, df)
959
+ self._save_results(context, df)
960
+ return df
961
+
962
+ def _postprocess(self, context: QueryContext, df: DataFrame) -> DataFrame:
963
+ if context.model.result.sort_columns:
964
+ df = df.sort(*context.model.result.sort_columns)
965
+
966
+ if isinstance(context.model.result.table_format, PivotedTableFormatModel):
967
+ df = _pivot_table(df, context)
968
+
969
+ return df
970
+
971
+ def _query_output_dir(self, context: QueryContext) -> Path:
972
+ return self._output_dir / context.model.name
973
+
974
+ @track_timing(timer_stats_collector)
975
+ def _save_results(self, context: QueryContext, df) -> Path:
976
+ output_dir = self._query_output_dir(context)
977
+ output_dir.mkdir(exist_ok=True)
978
+ filename = output_dir / f"table.{context.model.result.output_format}"
979
+ suffix = filename.suffix
980
+ if suffix == ".csv":
981
+ df.toPandas().to_csv(filename, header=True, index=False)
982
+ elif suffix == ".parquet":
983
+ df = write_dataframe_and_auto_partition(df, filename)
984
+ else:
985
+ msg = f"Unsupported output_format={suffix}"
986
+ raise NotImplementedError(msg)
987
+
988
+ logger.info("Wrote query=%s output table to %s", context.model.name, filename)
989
+ return filename
990
+
991
+
992
+ def _pivot_table(df: DataFrame, context: QueryContext):
993
+ pivoted_column = context.convert_to_pivoted()
994
+ return pivot(df, pivoted_column, VALUE_COLUMN)