dsgrid_toolkit-0.3.3-cp313-cp313-win_amd64.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (157)
  1. build_backend.py +93 -0
  2. dsgrid/__init__.py +22 -0
  3. dsgrid/api/__init__.py +0 -0
  4. dsgrid/api/api_manager.py +179 -0
  5. dsgrid/api/app.py +419 -0
  6. dsgrid/api/models.py +60 -0
  7. dsgrid/api/response_models.py +116 -0
  8. dsgrid/apps/__init__.py +0 -0
  9. dsgrid/apps/project_viewer/app.py +216 -0
  10. dsgrid/apps/registration_gui.py +444 -0
  11. dsgrid/chronify.py +32 -0
  12. dsgrid/cli/__init__.py +0 -0
  13. dsgrid/cli/common.py +120 -0
  14. dsgrid/cli/config.py +176 -0
  15. dsgrid/cli/download.py +13 -0
  16. dsgrid/cli/dsgrid.py +157 -0
  17. dsgrid/cli/dsgrid_admin.py +92 -0
  18. dsgrid/cli/install_notebooks.py +62 -0
  19. dsgrid/cli/query.py +729 -0
  20. dsgrid/cli/registry.py +1862 -0
  21. dsgrid/cloud/__init__.py +0 -0
  22. dsgrid/cloud/cloud_storage_interface.py +140 -0
  23. dsgrid/cloud/factory.py +31 -0
  24. dsgrid/cloud/fake_storage_interface.py +37 -0
  25. dsgrid/cloud/s3_storage_interface.py +156 -0
  26. dsgrid/common.py +36 -0
  27. dsgrid/config/__init__.py +0 -0
  28. dsgrid/config/annual_time_dimension_config.py +194 -0
  29. dsgrid/config/common.py +142 -0
  30. dsgrid/config/config_base.py +148 -0
  31. dsgrid/config/dataset_config.py +907 -0
  32. dsgrid/config/dataset_schema_handler_factory.py +46 -0
  33. dsgrid/config/date_time_dimension_config.py +136 -0
  34. dsgrid/config/dimension_config.py +54 -0
  35. dsgrid/config/dimension_config_factory.py +65 -0
  36. dsgrid/config/dimension_mapping_base.py +350 -0
  37. dsgrid/config/dimension_mappings_config.py +48 -0
  38. dsgrid/config/dimensions.py +1025 -0
  39. dsgrid/config/dimensions_config.py +71 -0
  40. dsgrid/config/file_schema.py +190 -0
  41. dsgrid/config/index_time_dimension_config.py +80 -0
  42. dsgrid/config/input_dataset_requirements.py +31 -0
  43. dsgrid/config/mapping_tables.py +209 -0
  44. dsgrid/config/noop_time_dimension_config.py +42 -0
  45. dsgrid/config/project_config.py +1462 -0
  46. dsgrid/config/registration_models.py +188 -0
  47. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  48. dsgrid/config/simple_models.py +49 -0
  49. dsgrid/config/supplemental_dimension.py +29 -0
  50. dsgrid/config/time_dimension_base_config.py +192 -0
  51. dsgrid/data_models.py +155 -0
  52. dsgrid/dataset/__init__.py +0 -0
  53. dsgrid/dataset/dataset.py +123 -0
  54. dsgrid/dataset/dataset_expression_handler.py +86 -0
  55. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  56. dsgrid/dataset/dataset_schema_handler_base.py +945 -0
  57. dsgrid/dataset/dataset_schema_handler_one_table.py +209 -0
  58. dsgrid/dataset/dataset_schema_handler_two_table.py +322 -0
  59. dsgrid/dataset/growth_rates.py +162 -0
  60. dsgrid/dataset/models.py +51 -0
  61. dsgrid/dataset/table_format_handler_base.py +257 -0
  62. dsgrid/dataset/table_format_handler_factory.py +17 -0
  63. dsgrid/dataset/unpivoted_table.py +121 -0
  64. dsgrid/dimension/__init__.py +0 -0
  65. dsgrid/dimension/base_models.py +230 -0
  66. dsgrid/dimension/dimension_filters.py +308 -0
  67. dsgrid/dimension/standard.py +252 -0
  68. dsgrid/dimension/time.py +352 -0
  69. dsgrid/dimension/time_utils.py +103 -0
  70. dsgrid/dsgrid_rc.py +88 -0
  71. dsgrid/exceptions.py +105 -0
  72. dsgrid/filesystem/__init__.py +0 -0
  73. dsgrid/filesystem/cloud_filesystem.py +32 -0
  74. dsgrid/filesystem/factory.py +32 -0
  75. dsgrid/filesystem/filesystem_interface.py +136 -0
  76. dsgrid/filesystem/local_filesystem.py +74 -0
  77. dsgrid/filesystem/s3_filesystem.py +118 -0
  78. dsgrid/loggers.py +132 -0
  79. dsgrid/minimal_patterns.cp313-win_amd64.pyd +0 -0
  80. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +949 -0
  81. dsgrid/notebooks/registration.ipynb +48 -0
  82. dsgrid/notebooks/start_notebook.sh +11 -0
  83. dsgrid/project.py +451 -0
  84. dsgrid/query/__init__.py +0 -0
  85. dsgrid/query/dataset_mapping_plan.py +142 -0
  86. dsgrid/query/derived_dataset.py +388 -0
  87. dsgrid/query/models.py +728 -0
  88. dsgrid/query/query_context.py +287 -0
  89. dsgrid/query/query_submitter.py +994 -0
  90. dsgrid/query/report_factory.py +19 -0
  91. dsgrid/query/report_peak_load.py +70 -0
  92. dsgrid/query/reports_base.py +20 -0
  93. dsgrid/registry/__init__.py +0 -0
  94. dsgrid/registry/bulk_register.py +165 -0
  95. dsgrid/registry/common.py +287 -0
  96. dsgrid/registry/config_update_checker_base.py +63 -0
  97. dsgrid/registry/data_store_factory.py +34 -0
  98. dsgrid/registry/data_store_interface.py +74 -0
  99. dsgrid/registry/dataset_config_generator.py +158 -0
  100. dsgrid/registry/dataset_registry_manager.py +950 -0
  101. dsgrid/registry/dataset_update_checker.py +16 -0
  102. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  103. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  104. dsgrid/registry/dimension_registry_manager.py +413 -0
  105. dsgrid/registry/dimension_update_checker.py +16 -0
  106. dsgrid/registry/duckdb_data_store.py +207 -0
  107. dsgrid/registry/filesystem_data_store.py +150 -0
  108. dsgrid/registry/filter_registry_manager.py +123 -0
  109. dsgrid/registry/project_config_generator.py +57 -0
  110. dsgrid/registry/project_registry_manager.py +1623 -0
  111. dsgrid/registry/project_update_checker.py +48 -0
  112. dsgrid/registry/registration_context.py +223 -0
  113. dsgrid/registry/registry_auto_updater.py +316 -0
  114. dsgrid/registry/registry_database.py +667 -0
  115. dsgrid/registry/registry_interface.py +446 -0
  116. dsgrid/registry/registry_manager.py +558 -0
  117. dsgrid/registry/registry_manager_base.py +367 -0
  118. dsgrid/registry/versioning.py +92 -0
  119. dsgrid/rust_ext/__init__.py +14 -0
  120. dsgrid/rust_ext/find_minimal_patterns.py +129 -0
  121. dsgrid/spark/__init__.py +0 -0
  122. dsgrid/spark/functions.py +589 -0
  123. dsgrid/spark/types.py +110 -0
  124. dsgrid/tests/__init__.py +0 -0
  125. dsgrid/tests/common.py +140 -0
  126. dsgrid/tests/make_us_data_registry.py +265 -0
  127. dsgrid/tests/register_derived_datasets.py +103 -0
  128. dsgrid/tests/utils.py +25 -0
  129. dsgrid/time/__init__.py +0 -0
  130. dsgrid/time/time_conversions.py +80 -0
  131. dsgrid/time/types.py +67 -0
  132. dsgrid/units/__init__.py +0 -0
  133. dsgrid/units/constants.py +113 -0
  134. dsgrid/units/convert.py +71 -0
  135. dsgrid/units/energy.py +145 -0
  136. dsgrid/units/power.py +87 -0
  137. dsgrid/utils/__init__.py +0 -0
  138. dsgrid/utils/dataset.py +830 -0
  139. dsgrid/utils/files.py +179 -0
  140. dsgrid/utils/filters.py +125 -0
  141. dsgrid/utils/id_remappings.py +100 -0
  142. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  143. dsgrid/utils/py_expression_eval/README.md +8 -0
  144. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  145. dsgrid/utils/py_expression_eval/tests.py +283 -0
  146. dsgrid/utils/run_command.py +70 -0
  147. dsgrid/utils/scratch_dir_context.py +65 -0
  148. dsgrid/utils/spark.py +918 -0
  149. dsgrid/utils/spark_partition.py +98 -0
  150. dsgrid/utils/timing.py +239 -0
  151. dsgrid/utils/utilities.py +221 -0
  152. dsgrid/utils/versioning.py +36 -0
  153. dsgrid_toolkit-0.3.3.dist-info/METADATA +193 -0
  154. dsgrid_toolkit-0.3.3.dist-info/RECORD +157 -0
  155. dsgrid_toolkit-0.3.3.dist-info/WHEEL +4 -0
  156. dsgrid_toolkit-0.3.3.dist-info/entry_points.txt +4 -0
  157. dsgrid_toolkit-0.3.3.dist-info/licenses/LICENSE +29 -0
dsgrid/data_models.py ADDED
@@ -0,0 +1,155 @@
+ """Base functionality for all Pydantic data models used in dsgrid"""
+
+ import logging
+ from enum import Enum
+ from pathlib import Path
+ from typing import Any, Self
+
+ from pydantic import ConfigDict, BaseModel, Field, ValidationError
+
+ from dsgrid.exceptions import DSGInvalidParameter
+ from dsgrid.utils.files import in_other_dir, load_data, dump_data
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def make_model_config(**kwargs) -> ConfigDict:
+     """Return a Pydantic config"""
+     return ConfigDict(
+         str_strip_whitespace=True,
+         validate_assignment=True,
+         validate_default=True,
+         extra="forbid",
+         use_enum_values=False,
+         arbitrary_types_allowed=True,
+         populate_by_name=True,
+         **kwargs,
+     )
+
+
+ class DSGBaseModel(BaseModel):
+     """Base data model for all dsgrid data models"""
+
+     model_config = make_model_config()
+
+     @classmethod
+     def load(cls, filename):
+         """Load a data model from a file.
+
+         Temporarily changes to the file's parent directory so that Pydantic
+         validators can load relative file paths within the file.
+
+         Parameters
+         ----------
+         filename : str
+         """
+         filename = Path(filename)
+         if not filename.is_file():
+             msg = f"{filename} is not a file"
+             raise DSGInvalidParameter(msg)
+
+         with in_other_dir(filename.parent):
+             try:
+                 return cls(**load_data(filename.name))
+             except ValidationError:
+                 logger.exception("Failed to validate %s", filename)
+                 raise
+
+     @classmethod
+     def get_fields_with_extra_attribute(cls, attribute):
+         """Return the names of fields whose json_schema_extra sets attribute."""
+         fields = set()
+         for f, attrs in cls.model_fields.items():
+             if attrs.json_schema_extra and attrs.json_schema_extra.get(attribute):
+                 fields.add(f)
+         return fields
+
+     def model_dump(self, *args, by_alias=True, **kwargs):
+         return super().model_dump(*args, by_alias=by_alias, **self._handle_kwargs(**kwargs))
+
+     def model_dump_json(self, *args, by_alias=True, **kwargs):
+         return super().model_dump_json(*args, by_alias=by_alias, **self._handle_kwargs(**kwargs))
+
+     @staticmethod
+     def _handle_kwargs(**kwargs):
+         return {k: v for k, v in kwargs.items() if k not in ("by_alias",)}
+
+     def serialize(self, *args, **kwargs) -> dict[str, Any]:
+         return self.model_dump(*args, mode="json", **kwargs)
+
+     @classmethod
+     def from_file(cls, filename: Path) -> Self:
+         """Deserialize the model from a file. Unlike the load method,
+         this does not change directories.
+         """
+         return cls(**load_data(filename))
+
+     def to_file(self, filename: Path) -> None:
+         """Serialize the model to a file."""
+         data = self.serialize()
+         dump_data(data, filename, indent=2)
+
+
+ class DSGBaseDatabaseModel(DSGBaseModel):
+     """Base model for all configs stored in the database."""
+
+     id: int | None = Field(
+         default=None,
+         description="Registry database ID",
+         json_schema_extra={
+             "dsgrid_internal": True,
+         },
+     )
+     version: str | None = Field(
+         default=None,
+         title="version",
+         description="Version, generated by dsgrid",
+         json_schema_extra={
+             "dsgrid_internal": True,
+             "updateable": False,
+         },
+     )
+
+
+ class EnumValue:
+     """Class to define a DSGEnum value"""
+
+     def __init__(self, value, description, **kwargs):
+         self.value = value
+         self.description = description
+         for kwarg, val in kwargs.items():
+             self.__setattr__(kwarg, val)
+
+
+ class DSGEnum(Enum):
+     """dsgrid Enum class"""
+
+     def __new__(cls, *args):
+         obj = object.__new__(cls)
+         assert len(args) in (1, 2)
+         if isinstance(args[0], EnumValue):
+             obj._value_ = args[0].value
+             obj.description = args[0].description
+             for attr, val in args[0].__dict__.items():
+                 if attr not in ("value", "description"):
+                     setattr(obj, attr, val)
+         elif len(args) == 2:
+             obj._value_ = args[0]
+             obj.description = args[1]
+         else:
+             obj._value_ = args[0]
+             obj.description = None
+         return obj
+
+     @classmethod
+     def format_for_docs(cls):
+         """Return a string of the enum values formatted for docs."""
+         return str([e.value for e in cls]).replace("'", "``")
+
+     @classmethod
+     def format_descriptions_for_docs(cls):
+         """Return a dict of enum values and descriptions formatted for docs."""
+         desc = {}
+         for e in cls:
+             desc[f"``{e.value}``"] = f"{e.description}"
+         return desc
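
For orientation, a minimal sketch of how these base classes are meant to be used, assuming the wheel is installed; Status and ExampleConfig are illustrative names, not part of the package:

    from dsgrid.data_models import DSGBaseModel, DSGEnum, EnumValue

    class Status(DSGEnum):
        # Members built from EnumValue carry descriptions for the generated docs.
        ACTIVE = EnumValue(value="active", description="Record is active")
        RETIRED = EnumValue(value="retired", description="Record is retired")

    class ExampleConfig(DSGBaseModel):
        # Inherits the strict model config: extra fields are forbidden and
        # string whitespace is stripped during validation.
        name: str
        status: Status = Status.ACTIVE

    cfg = ExampleConfig(name="  my-config  ")
    assert cfg.name == "my-config"  # str_strip_whitespace=True
    print(Status.format_descriptions_for_docs())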
dsgrid/dataset/dataset.py ADDED
@@ -0,0 +1,123 @@
+ """Provides access to a dataset."""
+
+ import abc
+ import logging
+
+ from sqlalchemy import Connection
+
+ from dsgrid.config.dataset_config import DatasetConfig
+ from dsgrid.config.dataset_schema_handler_factory import make_dataset_schema_handler
+ from dsgrid.config.dimension_mapping_base import DimensionMappingReferenceListModel
+ from dsgrid.config.project_config import ProjectConfig
+ from dsgrid.query.query_context import QueryContext
+ from dsgrid.dataset.dataset_schema_handler_base import DatasetSchemaHandlerBase
+ from dsgrid.registry.data_store_interface import DataStoreInterface
+ from dsgrid.registry.dimension_mapping_registry_manager import DimensionMappingRegistryManager
+ from dsgrid.registry.dimension_registry_manager import DimensionRegistryManager
+ from dsgrid.spark.types import DataFrame
+
+ logger = logging.getLogger(__name__)
+
+
+ class DatasetBase(abc.ABC):
+     """Base class for datasets"""
+
+     def __init__(self, schema_handler: DatasetSchemaHandlerBase):
+         self._config = schema_handler.config
+         self._handler = schema_handler
+         self._id = schema_handler.config.model.dataset_id
+         # Can't use dashes in view names. This will need to be handled when we implement
+         # queries based on dataset ID.
+
+     @property
+     def config(self) -> DatasetConfig:
+         return self._config
+
+     @property
+     def dataset_id(self) -> str:
+         return self._id
+
+     @property
+     def handler(self) -> DatasetSchemaHandlerBase:
+         return self._handler
+
+
+ class Dataset(DatasetBase):
+     """Represents a dataset used within a project."""
+
+     @classmethod
+     def load(
+         cls,
+         config: DatasetConfig,
+         dimension_mgr: DimensionRegistryManager,
+         dimension_mapping_mgr: DimensionMappingRegistryManager,
+         store: DataStoreInterface,
+         mapping_references: list[DimensionMappingReferenceListModel],
+         conn: Connection | None = None,
+     ):
+         """Load a dataset from a store.
+
+         Parameters
+         ----------
+         config : DatasetConfig
+         dimension_mgr : DimensionRegistryManager
+         dimension_mapping_mgr : DimensionMappingRegistryManager
+         store : DataStoreInterface
+         mapping_references : list[DimensionMappingReferenceListModel]
+         conn : Connection | None
+
+         Returns
+         -------
+         Dataset
+         """
+         return cls(
+             make_dataset_schema_handler(
+                 conn,
+                 config,
+                 dimension_mgr,
+                 dimension_mapping_mgr,
+                 store=store,
+                 mapping_references=mapping_references,
+             )
+         )
+
+     def make_project_dataframe(
+         self, query: QueryContext, project_config: ProjectConfig
+     ) -> DataFrame:
+         return self._handler.make_project_dataframe(query, project_config)
+
+
+ class StandaloneDataset(DatasetBase):
+     """Represents a dataset used outside of a project."""
+
+     @classmethod
+     def load(
+         cls,
+         config: DatasetConfig,
+         dimension_mgr: DimensionRegistryManager,
+         dimension_mapping_mgr: DimensionMappingRegistryManager,
+         store: DataStoreInterface,
+         mapping_references: list[DimensionMappingReferenceListModel] | None = None,
+         conn: Connection | None = None,
+     ):
+         """Load a dataset from a store.
+
+         Parameters
+         ----------
+         config : DatasetConfig
+         dimension_mgr : DimensionRegistryManager
+         dimension_mapping_mgr : DimensionMappingRegistryManager
+         store : DataStoreInterface
+         mapping_references : list[DimensionMappingReferenceListModel] | None
+         conn : Connection | None
+
+         Returns
+         -------
+         StandaloneDataset
+         """
+         return cls(
+             make_dataset_schema_handler(
+                 conn,
+                 config,
+                 dimension_mgr,
+                 dimension_mapping_mgr,
+                 store=store,
+                 mapping_references=mapping_references,
+             )
+         )
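
A sketch of the intended call pattern; the registry manager and store objects are assumed to have been obtained from a dsgrid registry elsewhere, so this is illustrative rather than self-contained:

    from dsgrid.dataset.dataset import StandaloneDataset

    # dataset_config, dimension_mgr, dimension_mapping_mgr, and data_store are
    # placeholders for objects supplied by an existing registry.
    dataset = StandaloneDataset.load(
        config=dataset_config,
        dimension_mgr=dimension_mgr,
        dimension_mapping_mgr=dimension_mapping_mgr,
        store=data_store,
    )
    print(dataset.dataset_id)  # property provided by DatasetBase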
dsgrid/dataset/dataset_expression_handler.py ADDED
@@ -0,0 +1,86 @@
+ import operator
+
+ from dsgrid.exceptions import DSGInvalidOperation
+ from dsgrid.spark.functions import join_multiple_columns
+ from dsgrid.spark.types import DataFrame
+ from dsgrid.utils.py_expression_eval import Parser
+
+
+ class DatasetExpressionHandler:
+     """Abstracts SQL expressions for dataset combinations with mathematical expressions."""
+
+     def __init__(self, df: DataFrame, dimension_columns: list[str], value_columns: list[str]):
+         self.df = df
+         self.dimension_columns = dimension_columns
+         self.value_columns = value_columns
+
+     def _op(self, other, op):
+         orig_self_count = self.df.count()
+         orig_other_count = other.df.count()
+         if orig_self_count != orig_other_count:
+             msg = (
+                 f"{op=} requires that the datasets have the same length "
+                 f"{orig_self_count=} {orig_other_count=}"
+             )
+             raise DSGInvalidOperation(msg)
+
+         def renamed(col):
+             return col + "_other"
+
+         other_df = other.df
+         for column in self.value_columns:
+             other_df = other_df.withColumnRenamed(column, renamed(column))
+         df = join_multiple_columns(self.df, other_df, self.dimension_columns)
+
+         for column in self.value_columns:
+             other_column = renamed(column)
+             df = df.withColumn(column, op(getattr(df, column), getattr(df, other_column)))
+
+         df = df.select(*self.df.columns)
+         joined_count = df.count()
+         if joined_count != orig_self_count:
+             msg = (
+                 f"join for operation {op=} has a different row count than the original. "
+                 f"{orig_self_count=} {joined_count=}"
+             )
+             raise DSGInvalidOperation(msg)
+
+         return DatasetExpressionHandler(df, self.dimension_columns, self.value_columns)
+
+     def __add__(self, other):
+         return self._op(other, operator.add)
+
+     def __mul__(self, other):
+         return self._op(other, operator.mul)
+
+     def __sub__(self, other):
+         return self._op(other, operator.sub)
+
+     def __or__(self, other):
+         if self.df.columns != other.df.columns:
+             msg = (
+                 "Union is only allowed when datasets have identical columns: "
+                 f"{self.df.columns=} vs {other.df.columns=}"
+             )
+             raise DSGInvalidOperation(msg)
+         return DatasetExpressionHandler(
+             self.df.union(other.df), self.dimension_columns, self.value_columns
+         )
+
+
+ def evaluate_expression(expr: str, dataset_mapping: dict[str, DatasetExpressionHandler]):
+     """Evaluates an expression containing dataset IDs.
+
+     Parameters
+     ----------
+     expr : str
+         Dataset combination expression, such as "dataset1 | dataset2"
+     dataset_mapping : dict[str, DatasetExpressionHandler]
+         Maps dataset ID to dataset. Each dataset_id in expr must be present in the mapping.
+
+     Returns
+     -------
+     DatasetExpressionHandler
+     """
+     return Parser().parse(expr).evaluate(dataset_mapping)
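
The operator overloads above are what make expression evaluation work: the parser resolves each dataset ID to its handler and applies +, -, *, or | to the handlers themselves. A minimal sketch, assuming a Spark-backed session (the dataframes and IDs are illustrative):

    from pyspark.sql import SparkSession
    from dsgrid.dataset.dataset_expression_handler import (
        DatasetExpressionHandler,
        evaluate_expression,
    )

    spark = SparkSession.builder.getOrCreate()
    columns = ["geography", "sector", "value"]
    df1 = spark.createDataFrame([("CO", "com", 1.0)], columns)
    df2 = spark.createDataFrame([("CO", "com", 2.0)], columns)

    dataset1 = DatasetExpressionHandler(df1, ["geography", "sector"], ["value"])
    dataset2 = DatasetExpressionHandler(df2, ["geography", "sector"], ["value"])

    # "|" unions the rows; "dataset1 + dataset2" would instead join on the
    # dimension columns and add the value columns element-wise.
    result = evaluate_expression(
        "dataset1 | dataset2", {"dataset1": dataset1, "dataset2": dataset2}
    )
    print(result.df.count())  # 2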
dsgrid/dataset/dataset_mapping_manager.py ADDED
@@ -0,0 +1,121 @@
+ import logging
+ from pathlib import Path
+ from typing import Self
+
+ from dsgrid.query.dataset_mapping_plan import (
+     DatasetMappingPlan,
+     MapOperation,
+     MapOperationCheckpoint,
+ )
+ from dsgrid.spark.types import DataFrame
+ from dsgrid.utils.files import delete_if_exists
+ from dsgrid.utils.spark import read_dataframe, write_dataframe
+ from dsgrid.utils.scratch_dir_context import ScratchDirContext
+
+ logger = logging.getLogger(__name__)
+
+
+ class DatasetMappingManager:
+     """Manages the mapping operations for a dataset."""
+
+     def __init__(
+         self,
+         dataset_id: str,
+         plan: DatasetMappingPlan,
+         scratch_dir_context: ScratchDirContext,
+         checkpoint: MapOperationCheckpoint | None = None,
+     ):
+         self._dataset_id = dataset_id
+         self._plan = plan
+         self._scratch_dir_context = scratch_dir_context
+         self._checkpoint = checkpoint
+         self._checkpoint_file: Path | None = None
+
+     def __enter__(self) -> Self:
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback) -> None:
+         # Don't clean up if an exception occurred; the checkpoint allows resuming.
+         if exc_type is None:
+             self.cleanup()
+
+     @property
+     def plan(self) -> DatasetMappingPlan:
+         """Return the mapping plan for the dataset."""
+         return self._plan
+
+     @property
+     def scratch_dir_context(self) -> ScratchDirContext:
+         """Return the scratch_dir_context."""
+         return self._scratch_dir_context
+
+     def try_read_checkpointed_table(self) -> DataFrame | None:
+         """Read the checkpointed table for the dataset, if it exists."""
+         if self._checkpoint is None:
+             return None
+         return read_dataframe(self._checkpoint.persisted_table_filename)
+
+     def get_completed_mapping_operations(self) -> set[str]:
+         """Return the names of completed mapping operations."""
+         return (
+             set() if self._checkpoint is None else set(self._checkpoint.completed_operation_names)
+         )
+
+     def has_completed_operation(self, op: MapOperation) -> bool:
+         """Return True if the mapping operation has been completed."""
+         return op.name in self.get_completed_mapping_operations()
+
+     def persist_table(self, df: DataFrame, op: MapOperation) -> DataFrame:
+         """Persist the intermediate table to the filesystem and return the persisted DataFrame."""
+         persisted_file = self._scratch_dir_context.get_temp_filename(
+             suffix=".parquet", add_tracked_path=False
+         )
+         write_dataframe(df, persisted_file)
+         self.save_checkpoint(persisted_file, op)
+         logger.info("Persisted mapping operation name=%s to %s", op.name, persisted_file)
+         return read_dataframe(persisted_file)
+
+     def save_checkpoint(self, persisted_table: Path, op: MapOperation) -> None:
+         """Save a checkpoint after persisting an operation to the filesystem."""
+         completed_operation_names: list[str] = []
+         for mapping_op in self._plan.list_mapping_operations():
+             completed_operation_names.append(mapping_op.name)
+             if mapping_op.name == op.name:
+                 break
+
+         checkpoint = MapOperationCheckpoint(
+             dataset_id=self._dataset_id,
+             completed_operation_names=completed_operation_names,
+             persisted_table_filename=persisted_table,
+             mapping_plan_hash=self._plan.compute_hash(),
+         )
+         checkpoint_filename = self._scratch_dir_context.get_temp_filename(
+             suffix=".json", add_tracked_path=False
+         )
+         checkpoint.to_file(checkpoint_filename)
+         if self._checkpoint_file is not None and not self._plan.keep_intermediate_files:
+             assert self._checkpoint is not None, self._checkpoint
+             logger.info("Remove previous checkpoint: %s", self._checkpoint_file)
+             delete_if_exists(self._checkpoint_file)
+             delete_if_exists(self._checkpoint.persisted_table_filename)
+
+         self._checkpoint = checkpoint
+         self._checkpoint_file = checkpoint_filename
+         logger.info("Saved checkpoint in %s", self._checkpoint_file)
+
+     def cleanup(self) -> None:
+         """Clean up the intermediate files. Call only if the operation completed successfully."""
+         if self._plan.keep_intermediate_files:
+             logger.info("Keeping intermediate files for dataset %s", self._dataset_id)
+             return
+
+         if self._checkpoint_file is not None:
+             logger.info("Removing checkpoint filename %s", self._checkpoint_file)
+             delete_if_exists(self._checkpoint_file)
+             self._checkpoint_file = None
+         if self._checkpoint is not None:
+             logger.info(
+                 "Removing persisted intermediate table filename %s",
+                 self._checkpoint.persisted_table_filename,
+             )
+             delete_if_exists(self._checkpoint.persisted_table_filename)
+             self._checkpoint = None
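
Taken together, this class supports a resumable mapping loop. A sketch of that flow under stated assumptions: plan is a DatasetMappingPlan, context a ScratchDirContext, checkpoint a previously saved MapOperationCheckpoint or None, and apply_mapping_operation a placeholder for the real per-operation mapping step:

    from dsgrid.dataset.dataset_mapping_manager import DatasetMappingManager

    with DatasetMappingManager("my_dataset", plan, context, checkpoint=checkpoint) as mgr:
        df = mgr.try_read_checkpointed_table()  # resume from a prior run, if any
        for op in mgr.plan.list_mapping_operations():
            if mgr.has_completed_operation(op):
                continue  # already covered by the checkpoint
            df = apply_mapping_operation(df, op)  # placeholder mapping step
            df = mgr.persist_table(df, op)  # writes parquet and saves a checkpoint
    # On a clean exit, cleanup() removes the checkpoint and intermediate table;
    # on an exception they are kept so the run can be resumed.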