dsgrid-toolkit 0.3.3__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- build_backend.py +93 -0
- dsgrid/__init__.py +22 -0
- dsgrid/api/__init__.py +0 -0
- dsgrid/api/api_manager.py +179 -0
- dsgrid/api/app.py +419 -0
- dsgrid/api/models.py +60 -0
- dsgrid/api/response_models.py +116 -0
- dsgrid/apps/__init__.py +0 -0
- dsgrid/apps/project_viewer/app.py +216 -0
- dsgrid/apps/registration_gui.py +444 -0
- dsgrid/chronify.py +32 -0
- dsgrid/cli/__init__.py +0 -0
- dsgrid/cli/common.py +120 -0
- dsgrid/cli/config.py +176 -0
- dsgrid/cli/download.py +13 -0
- dsgrid/cli/dsgrid.py +157 -0
- dsgrid/cli/dsgrid_admin.py +92 -0
- dsgrid/cli/install_notebooks.py +62 -0
- dsgrid/cli/query.py +729 -0
- dsgrid/cli/registry.py +1862 -0
- dsgrid/cloud/__init__.py +0 -0
- dsgrid/cloud/cloud_storage_interface.py +140 -0
- dsgrid/cloud/factory.py +31 -0
- dsgrid/cloud/fake_storage_interface.py +37 -0
- dsgrid/cloud/s3_storage_interface.py +156 -0
- dsgrid/common.py +36 -0
- dsgrid/config/__init__.py +0 -0
- dsgrid/config/annual_time_dimension_config.py +194 -0
- dsgrid/config/common.py +142 -0
- dsgrid/config/config_base.py +148 -0
- dsgrid/config/dataset_config.py +907 -0
- dsgrid/config/dataset_schema_handler_factory.py +46 -0
- dsgrid/config/date_time_dimension_config.py +136 -0
- dsgrid/config/dimension_config.py +54 -0
- dsgrid/config/dimension_config_factory.py +65 -0
- dsgrid/config/dimension_mapping_base.py +350 -0
- dsgrid/config/dimension_mappings_config.py +48 -0
- dsgrid/config/dimensions.py +1025 -0
- dsgrid/config/dimensions_config.py +71 -0
- dsgrid/config/file_schema.py +190 -0
- dsgrid/config/index_time_dimension_config.py +80 -0
- dsgrid/config/input_dataset_requirements.py +31 -0
- dsgrid/config/mapping_tables.py +209 -0
- dsgrid/config/noop_time_dimension_config.py +42 -0
- dsgrid/config/project_config.py +1462 -0
- dsgrid/config/registration_models.py +188 -0
- dsgrid/config/representative_period_time_dimension_config.py +194 -0
- dsgrid/config/simple_models.py +49 -0
- dsgrid/config/supplemental_dimension.py +29 -0
- dsgrid/config/time_dimension_base_config.py +192 -0
- dsgrid/data_models.py +155 -0
- dsgrid/dataset/__init__.py +0 -0
- dsgrid/dataset/dataset.py +123 -0
- dsgrid/dataset/dataset_expression_handler.py +86 -0
- dsgrid/dataset/dataset_mapping_manager.py +121 -0
- dsgrid/dataset/dataset_schema_handler_base.py +945 -0
- dsgrid/dataset/dataset_schema_handler_one_table.py +209 -0
- dsgrid/dataset/dataset_schema_handler_two_table.py +322 -0
- dsgrid/dataset/growth_rates.py +162 -0
- dsgrid/dataset/models.py +51 -0
- dsgrid/dataset/table_format_handler_base.py +257 -0
- dsgrid/dataset/table_format_handler_factory.py +17 -0
- dsgrid/dataset/unpivoted_table.py +121 -0
- dsgrid/dimension/__init__.py +0 -0
- dsgrid/dimension/base_models.py +230 -0
- dsgrid/dimension/dimension_filters.py +308 -0
- dsgrid/dimension/standard.py +252 -0
- dsgrid/dimension/time.py +352 -0
- dsgrid/dimension/time_utils.py +103 -0
- dsgrid/dsgrid_rc.py +88 -0
- dsgrid/exceptions.py +105 -0
- dsgrid/filesystem/__init__.py +0 -0
- dsgrid/filesystem/cloud_filesystem.py +32 -0
- dsgrid/filesystem/factory.py +32 -0
- dsgrid/filesystem/filesystem_interface.py +136 -0
- dsgrid/filesystem/local_filesystem.py +74 -0
- dsgrid/filesystem/s3_filesystem.py +118 -0
- dsgrid/loggers.py +132 -0
- dsgrid/minimal_patterns.cp313-win_amd64.pyd +0 -0
- dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +949 -0
- dsgrid/notebooks/registration.ipynb +48 -0
- dsgrid/notebooks/start_notebook.sh +11 -0
- dsgrid/project.py +451 -0
- dsgrid/query/__init__.py +0 -0
- dsgrid/query/dataset_mapping_plan.py +142 -0
- dsgrid/query/derived_dataset.py +388 -0
- dsgrid/query/models.py +728 -0
- dsgrid/query/query_context.py +287 -0
- dsgrid/query/query_submitter.py +994 -0
- dsgrid/query/report_factory.py +19 -0
- dsgrid/query/report_peak_load.py +70 -0
- dsgrid/query/reports_base.py +20 -0
- dsgrid/registry/__init__.py +0 -0
- dsgrid/registry/bulk_register.py +165 -0
- dsgrid/registry/common.py +287 -0
- dsgrid/registry/config_update_checker_base.py +63 -0
- dsgrid/registry/data_store_factory.py +34 -0
- dsgrid/registry/data_store_interface.py +74 -0
- dsgrid/registry/dataset_config_generator.py +158 -0
- dsgrid/registry/dataset_registry_manager.py +950 -0
- dsgrid/registry/dataset_update_checker.py +16 -0
- dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
- dsgrid/registry/dimension_mapping_update_checker.py +16 -0
- dsgrid/registry/dimension_registry_manager.py +413 -0
- dsgrid/registry/dimension_update_checker.py +16 -0
- dsgrid/registry/duckdb_data_store.py +207 -0
- dsgrid/registry/filesystem_data_store.py +150 -0
- dsgrid/registry/filter_registry_manager.py +123 -0
- dsgrid/registry/project_config_generator.py +57 -0
- dsgrid/registry/project_registry_manager.py +1623 -0
- dsgrid/registry/project_update_checker.py +48 -0
- dsgrid/registry/registration_context.py +223 -0
- dsgrid/registry/registry_auto_updater.py +316 -0
- dsgrid/registry/registry_database.py +667 -0
- dsgrid/registry/registry_interface.py +446 -0
- dsgrid/registry/registry_manager.py +558 -0
- dsgrid/registry/registry_manager_base.py +367 -0
- dsgrid/registry/versioning.py +92 -0
- dsgrid/rust_ext/__init__.py +14 -0
- dsgrid/rust_ext/find_minimal_patterns.py +129 -0
- dsgrid/spark/__init__.py +0 -0
- dsgrid/spark/functions.py +589 -0
- dsgrid/spark/types.py +110 -0
- dsgrid/tests/__init__.py +0 -0
- dsgrid/tests/common.py +140 -0
- dsgrid/tests/make_us_data_registry.py +265 -0
- dsgrid/tests/register_derived_datasets.py +103 -0
- dsgrid/tests/utils.py +25 -0
- dsgrid/time/__init__.py +0 -0
- dsgrid/time/time_conversions.py +80 -0
- dsgrid/time/types.py +67 -0
- dsgrid/units/__init__.py +0 -0
- dsgrid/units/constants.py +113 -0
- dsgrid/units/convert.py +71 -0
- dsgrid/units/energy.py +145 -0
- dsgrid/units/power.py +87 -0
- dsgrid/utils/__init__.py +0 -0
- dsgrid/utils/dataset.py +830 -0
- dsgrid/utils/files.py +179 -0
- dsgrid/utils/filters.py +125 -0
- dsgrid/utils/id_remappings.py +100 -0
- dsgrid/utils/py_expression_eval/LICENSE +19 -0
- dsgrid/utils/py_expression_eval/README.md +8 -0
- dsgrid/utils/py_expression_eval/__init__.py +847 -0
- dsgrid/utils/py_expression_eval/tests.py +283 -0
- dsgrid/utils/run_command.py +70 -0
- dsgrid/utils/scratch_dir_context.py +65 -0
- dsgrid/utils/spark.py +918 -0
- dsgrid/utils/spark_partition.py +98 -0
- dsgrid/utils/timing.py +239 -0
- dsgrid/utils/utilities.py +221 -0
- dsgrid/utils/versioning.py +36 -0
- dsgrid_toolkit-0.3.3.dist-info/METADATA +193 -0
- dsgrid_toolkit-0.3.3.dist-info/RECORD +157 -0
- dsgrid_toolkit-0.3.3.dist-info/WHEEL +4 -0
- dsgrid_toolkit-0.3.3.dist-info/entry_points.txt +4 -0
- dsgrid_toolkit-0.3.3.dist-info/licenses/LICENSE +29 -0
dsgrid/data_models.py
ADDED
@@ -0,0 +1,155 @@
"""Base functionality for all Pydantic data models used in dsgrid"""

import logging
from enum import Enum
from pathlib import Path
from typing import Any, Self

from pydantic import ConfigDict, BaseModel, Field, ValidationError

from dsgrid.exceptions import DSGInvalidParameter
from dsgrid.utils.files import in_other_dir, load_data, dump_data


logger = logging.getLogger(__name__)


def make_model_config(**kwargs) -> ConfigDict:
    """Return a Pydantic config"""
    return ConfigDict(
        str_strip_whitespace=True,
        validate_assignment=True,
        validate_default=True,
        extra="forbid",
        use_enum_values=False,
        arbitrary_types_allowed=True,
        populate_by_name=True,
        **kwargs,
    )


class DSGBaseModel(BaseModel):
    """Base data model for all dsgrid data models"""

    model_config = make_model_config()

    @classmethod
    def load(cls, filename):
        """Load a data model from a file.
        Temporarily changes to the file's parent directory so that Pydantic
        validators can load relative file paths within the file.

        Parameters
        ----------
        filename : str

        """
        filename = Path(filename)
        if not filename.is_file():
            msg = f"{filename} is not a file"
            raise DSGInvalidParameter(msg)

        with in_other_dir(filename.parent):
            try:
                return cls(**load_data(filename.name))
            except ValidationError:
                logger.exception("Failed to validate %s", filename)
                raise

    @classmethod
    def get_fields_with_extra_attribute(cls, attribute):
        fields = set()
        for f, attrs in cls.model_fields.items():
            if attrs.json_schema_extra.get(attribute):
                fields.add(f)
        return fields

    def model_dump(self, *args, by_alias=True, **kwargs):
        return super().model_dump(*args, by_alias=by_alias, **self._handle_kwargs(**kwargs))

    def model_dump_json(self, *args, by_alias=True, **kwargs):
        return super().model_dump_json(*args, by_alias=by_alias, **self._handle_kwargs(**kwargs))

    @staticmethod
    def _handle_kwargs(**kwargs):
        return {k: v for k, v in kwargs.items() if k not in ("by_alias",)}

    def serialize(self, *args, **kwargs) -> dict[str, Any]:
        return self.model_dump(*args, mode="json", **kwargs)

    @classmethod
    def from_file(cls, filename: Path) -> Self:
        """Deserialize the model from a file. Unlike the load method,
        this does not change directories.
        """
        return cls(**load_data(filename))

    def to_file(self, filename: Path) -> None:
        """Serialize the model to a file."""
        data = self.serialize()
        dump_data(data, filename, indent=2)


class DSGBaseDatabaseModel(DSGBaseModel):
    """Base model for all configs stored in the database."""

    id: int | None = Field(
        default=None,
        description="Registry database ID",
        json_schema_extra={
            "dsgrid_internal": True,
        },
    )
    version: str | None = Field(
        default=None,
        title="version",
        description="Version, generated by dsgrid",
        json_schema_extra={
            "dsgrid_internal": True,
            "updateable": False,
        },
    )


class EnumValue:
    """Class to define a DSGEnum value"""

    def __init__(self, value, description, **kwargs):
        self.value = value
        self.description = description
        for kwarg, val in kwargs.items():
            self.__setattr__(kwarg, val)


class DSGEnum(Enum):
    """dsgrid Enum class"""

    def __new__(cls, *args):
        obj = object.__new__(cls)
        assert len(args) in (1, 2)
        if isinstance(args[0], EnumValue):
            obj._value_ = args[0].value
            obj.description = args[0].description
            for attr, val in args[0].__dict__.items():
                if attr not in ("value", "description"):
                    setattr(obj, attr, val)
        elif len(args) == 2:
            obj._value_ = args[0]
            obj.description = args[1]
        else:
            obj._value_ = args[0]
            obj.description = None
        return obj

    @classmethod
    def format_for_docs(cls):
        """Returns set of formatted enum values for docs."""
        return str([e.value for e in cls]).replace("'", "``")

    @classmethod
    def format_descriptions_for_docs(cls):
        """Returns formatted dict of enum values and descriptions for docs."""
        desc = {}
        for e in cls:
            desc[f"``{e.value}``"] = f"{e.description}"
        return desc
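Usage sketch (not part of the wheel): SectorModel, Scope, and their fields below are invented for illustration; only DSGBaseModel, DSGEnum, and EnumValue come from the module above.

# Illustrative only: a config model built on DSGBaseModel and an enum built on DSGEnum.
from pydantic import Field

from dsgrid.data_models import DSGBaseModel, DSGEnum, EnumValue


class SectorModel(DSGBaseModel):
    """Hypothetical config model; field names are made up for this example."""

    name: str = Field(description="Sector name", json_schema_extra={"updateable": True})
    abbreviation: str = Field(description="Short code", json_schema_extra={"updateable": False})


class Scope(DSGEnum):
    """Hypothetical enum whose members carry descriptions."""

    PROJECT = EnumValue(value="project", description="Applies to a project")
    DATASET = EnumValue(value="dataset", description="Applies to a dataset")


model = SectorModel(name=" Residential ", abbreviation="res")
assert model.name == "Residential"  # str_strip_whitespace=True strips the padding
assert SectorModel.get_fields_with_extra_attribute("updateable") == {"name"}
print(Scope.format_descriptions_for_docs())  # {'``project``': 'Applies to a project', ...}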
dsgrid/dataset/dataset.py
ADDED
@@ -0,0 +1,123 @@
"""Provides access to a dataset."""

import abc
import logging

from sqlalchemy import Connection

from dsgrid.config.dataset_config import DatasetConfig
from dsgrid.config.dataset_schema_handler_factory import make_dataset_schema_handler
from dsgrid.config.dimension_mapping_base import DimensionMappingReferenceListModel
from dsgrid.config.project_config import ProjectConfig
from dsgrid.query.query_context import QueryContext
from dsgrid.dataset.dataset_schema_handler_base import DatasetSchemaHandlerBase
from dsgrid.registry.data_store_interface import DataStoreInterface
from dsgrid.registry.dimension_mapping_registry_manager import DimensionMappingRegistryManager
from dsgrid.registry.dimension_registry_manager import DimensionRegistryManager
from dsgrid.spark.types import DataFrame

logger = logging.getLogger(__name__)


class DatasetBase(abc.ABC):
    """Base class for datasets"""

    def __init__(self, schema_handler: DatasetSchemaHandlerBase):
        self._config = schema_handler.config
        self._handler = schema_handler
        self._id = schema_handler.config.model.dataset_id
        # Can't use dashes in view names. This will need to be handled when we implement
        # queries based on dataset ID.

    @property
    def config(self) -> DatasetConfig:
        return self._config

    @property
    def dataset_id(self) -> str:
        return self._id

    @property
    def handler(self) -> DatasetSchemaHandlerBase:
        return self._handler


class Dataset(DatasetBase):
    """Represents a dataset used within a project."""

    @classmethod
    def load(
        cls,
        config: DatasetConfig,
        dimension_mgr: DimensionRegistryManager,
        dimension_mapping_mgr: DimensionMappingRegistryManager,
        store: DataStoreInterface,
        mapping_references: list[DimensionMappingReferenceListModel],
        conn: Connection | None = None,
    ):
        """Load a dataset from a store.

        Parameters
        ----------
        config : DatasetConfig
        dimension_mgr : DimensionRegistryManager
        dimension_mapping_mgr : DimensionMappingRegistryManager
        mapping_references: list[DimensionMappingReferenceListModel]

        Returns
        -------
        Dataset

        """
        return cls(
            make_dataset_schema_handler(
                conn,
                config,
                dimension_mgr,
                dimension_mapping_mgr,
                store=store,
                mapping_references=mapping_references,
            )
        )

    def make_project_dataframe(
        self, query: QueryContext, project_config: ProjectConfig
    ) -> DataFrame:
        return self._handler.make_project_dataframe(query, project_config)


class StandaloneDataset(DatasetBase):
    """Represents a dataset used outside of a project."""

    @classmethod
    def load(
        cls,
        config: DatasetConfig,
        dimension_mgr: DimensionRegistryManager,
        dimension_mapping_mgr: DimensionMappingRegistryManager,
        store: DataStoreInterface,
        mapping_references: list[DimensionMappingReferenceListModel] | None = None,
        conn: Connection | None = None,
    ):
        """Load a dataset from a store.

        Parameters
        ----------
        config : DatasetConfig
        dimension_mgr : DimensionRegistryManager

        Returns
        -------
        Dataset

        """
        return cls(
            make_dataset_schema_handler(
                conn,
                config,
                dimension_mgr,
                dimension_mapping_mgr,
                store=store,
                mapping_references=mapping_references,
            )
        )
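Usage sketch (not part of the wheel): every object passed in below (dataset_config, dimension_mgr, and so on) is a placeholder assumed to come from dsgrid's registry and query machinery; only the methods defined in the class above are used.

# Illustrative only: load a registered dataset and build a project-aligned dataframe.
from dsgrid.dataset.dataset import Dataset

dataset = Dataset.load(
    config=dataset_config,
    dimension_mgr=dimension_mgr,
    dimension_mapping_mgr=dimension_mapping_mgr,
    store=data_store,
    mapping_references=mapping_references,
)
print(dataset.dataset_id)

# Within a project query, the schema handler maps the dataset to project dimensions.
df = dataset.make_project_dataframe(query_context, project_config)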
dsgrid/dataset/dataset_expression_handler.py
ADDED
@@ -0,0 +1,86 @@
import operator

from dsgrid.exceptions import DSGInvalidOperation
from dsgrid.spark.functions import join_multiple_columns
from dsgrid.spark.types import DataFrame
from dsgrid.utils.py_expression_eval import Parser


class DatasetExpressionHandler:
    """Abstracts SQL expressions for dataset combinations with mathematical expressions."""

    def __init__(self, df: DataFrame, dimension_columns: list[str], value_columns: list[str]):
        self.df = df
        self.dimension_columns = dimension_columns
        self.value_columns = value_columns

    def _op(self, other, op):
        orig_self_count = self.df.count()
        orig_other_count = other.df.count()
        if orig_self_count != orig_other_count:
            msg = (
                f"{op=} requires that the datasets have the same length "
                f"{orig_self_count=} {orig_other_count=}"
            )
            raise DSGInvalidOperation(msg)

        def renamed(col):
            return col + "_other"

        other_df = other.df
        for column in self.value_columns:
            other_df = other_df.withColumnRenamed(column, renamed(column))
        df = join_multiple_columns(self.df, other_df, self.dimension_columns)

        for column in self.value_columns:
            other_column = renamed(column)
            df = df.withColumn(column, op(getattr(df, column), getattr(df, other_column)))

        df = df.select(*self.df.columns)
        joined_count = df.count()
        if joined_count != orig_self_count:
            msg = (
                f"join for operation {op=} has a different row count than the original. "
                f"{orig_self_count=} {joined_count=}"
            )
            raise DSGInvalidOperation(msg)

        return DatasetExpressionHandler(df, self.dimension_columns, self.value_columns)

    def __add__(self, other):
        return self._op(other, operator.add)

    def __mul__(self, other):
        return self._op(other, operator.mul)

    def __sub__(self, other):
        return self._op(other, operator.sub)

    def __or__(self, other):
        if self.df.columns != other.df.columns:
            msg = (
                "Union is only allowed when datasets have identical columns: "
                f"{self.df.columns=} vs {other.df.columns=}"
            )
            raise DSGInvalidOperation(msg)
        return DatasetExpressionHandler(
            self.df.union(other.df), self.dimension_columns, self.value_columns
        )


def evaluate_expression(expr: str, dataset_mapping: dict[str, DatasetExpressionHandler]):
    """Evaluates an expression containing dataset IDs.

    Parameters
    ----------
    expr : str
        Dataset combination expression, such as "dataset1 | dataset2"
    dataset_mapping : dict[str, DatasetExpressionHandler]
        Maps dataset ID to dataset. Each dataset_id in expr must be present in the mapping.

    Returns
    -------
    DatasetExpressionHandler

    """
    return Parser().parse(expr).evaluate(dataset_mapping)
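Usage sketch (not part of the wheel): comstock_df and resstock_df below stand in for Spark/DuckDB DataFrames loaded elsewhere with matching dimension and value columns; only the API shown above is used.

# Illustrative only: combine two pre-loaded dataframes with an expression.
from dsgrid.dataset.dataset_expression_handler import (
    DatasetExpressionHandler,
    evaluate_expression,
)

dimension_columns = ["geography", "sector"]
value_columns = ["value"]
comstock = DatasetExpressionHandler(comstock_df, dimension_columns, value_columns)  # placeholder df
resstock = DatasetExpressionHandler(resstock_df, dimension_columns, value_columns)  # placeholder df

# "|" unions rows with identical columns; "+", "-", and "*" join on the dimension
# columns and apply the operator to each value column.
combined = evaluate_expression("comstock | resstock", {"comstock": comstock, "resstock": resstock})
result_df = combined.df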
dsgrid/dataset/dataset_mapping_manager.py
ADDED
@@ -0,0 +1,121 @@
import logging
from pathlib import Path
from typing import Self

from dsgrid.query.dataset_mapping_plan import (
    DatasetMappingPlan,
    MapOperation,
    MapOperationCheckpoint,
)
from dsgrid.spark.types import DataFrame
from dsgrid.utils.files import delete_if_exists
from dsgrid.utils.spark import read_dataframe, write_dataframe
from dsgrid.utils.scratch_dir_context import ScratchDirContext

logger = logging.getLogger(__name__)


class DatasetMappingManager:
    """Manages the mapping operations for a dataset."""

    def __init__(
        self,
        dataset_id: str,
        plan: DatasetMappingPlan,
        scratch_dir_context: ScratchDirContext,
        checkpoint: MapOperationCheckpoint | None = None,
    ):
        self._dataset_id = dataset_id
        self._plan = plan
        self._scratch_dir_context = scratch_dir_context
        self._checkpoint = checkpoint
        self._checkpoint_file: Path | None = None

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args, **kwargs) -> None:
        # Don't cleanup if an exception occurred.
        self.cleanup()

    @property
    def plan(self) -> DatasetMappingPlan:
        """Return the mapping plan for the dataset."""
        return self._plan

    @property
    def scratch_dir_context(self) -> ScratchDirContext:
        """Return the scratch_dir_context."""
        return self._scratch_dir_context

    def try_read_checkpointed_table(self) -> DataFrame | None:
        """Read the checkpointed table for the dataset, if it exists."""
        if self._checkpoint is None:
            return None
        return read_dataframe(self._checkpoint.persisted_table_filename)

    def get_completed_mapping_operations(self) -> set[str]:
        """Return the names of completed mapping operations."""
        return (
            set() if self._checkpoint is None else set(self._checkpoint.completed_operation_names)
        )

    def has_completed_operation(self, op: MapOperation) -> bool:
        """Return True if the mapping operation has been completed."""
        return op.name in self.get_completed_mapping_operations()

    def persist_table(self, df: DataFrame, op: MapOperation) -> DataFrame:
        """Persist the intermediate table to the filesystem and return the persisted DataFrame."""
        persisted_file = self._scratch_dir_context.get_temp_filename(
            suffix=".parquet", add_tracked_path=False
        )
        write_dataframe(df, persisted_file)
        self.save_checkpoint(persisted_file, op)
        logger.info("Persisted mapping operation name=%s to %s", op.name, persisted_file)
        return read_dataframe(persisted_file)

    def save_checkpoint(self, persisted_table: Path, op: MapOperation) -> None:
        """Save a checkpoint after persisting an operation to the filesystem."""
        completed_operation_names: list[str] = []
        for mapping_op in self._plan.list_mapping_operations():
            completed_operation_names.append(mapping_op.name)
            if mapping_op.name == op.name:
                break

        checkpoint = MapOperationCheckpoint(
            dataset_id=self._dataset_id,
            completed_operation_names=completed_operation_names,
            persisted_table_filename=persisted_table,
            mapping_plan_hash=self._plan.compute_hash(),
        )
        checkpoint_filename = self._scratch_dir_context.get_temp_filename(
            suffix=".json", add_tracked_path=False
        )
        checkpoint.to_file(checkpoint_filename)
        if self._checkpoint_file is not None and not self._plan.keep_intermediate_files:
            assert self._checkpoint is not None, self._checkpoint
            logger.info("Remove previous checkpoint: %s", self._checkpoint_file)
            delete_if_exists(self._checkpoint_file)
            delete_if_exists(self._checkpoint.persisted_table_filename)

        self._checkpoint = checkpoint
        self._checkpoint_file = checkpoint_filename
        logger.info("Saved checkpoint in %s", self._checkpoint_file)

    def cleanup(self) -> None:
        """Clean up the intermediate files. Call if the operation completed successfully."""
        if self._plan.keep_intermediate_files:
            logger.info("Keeping intermediate files for dataset %s", self._dataset_id)
            return

        if self._checkpoint_file is not None:
            logger.info("Removing checkpoint filename %s", self._checkpoint_file)
            delete_if_exists(self._checkpoint_file)
            self._checkpoint_file = None
        if self._checkpoint is not None:
            logger.info(
                "Removing persisted intermediate table filename %s",
                self._checkpoint.persisted_table_filename,
            )
            delete_if_exists(self._checkpoint.persisted_table_filename)
            self._checkpoint = None
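Usage sketch (not part of the wheel) of the checkpoint/resume flow the class supports: plan, scratch_context, and checkpoint are assumed to be built elsewhere from DatasetMappingPlan, ScratchDirContext, and MapOperationCheckpoint, and apply_mapping is a placeholder for the real per-operation mapping step.

# Illustrative only: resume a mapping run from a previously saved checkpoint.
from dsgrid.dataset.dataset_mapping_manager import DatasetMappingManager

with DatasetMappingManager("comstock", plan, scratch_context, checkpoint=checkpoint) as mgr:
    df = mgr.try_read_checkpointed_table()  # None when starting fresh
    for op in plan.list_mapping_operations():
        if mgr.has_completed_operation(op):
            continue  # already captured by the checkpoint
        df = apply_mapping(df, op)  # placeholder for the real mapping step
        df = mgr.persist_table(df, op)  # writes Parquet and records a new checkpoint
# On exit, cleanup() removes intermediate files unless plan.keep_intermediate_files is set.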