dsgrid-toolkit 0.3.3__cp313-cp313-win_amd64.whl
- build_backend.py +93 -0
- dsgrid/__init__.py +22 -0
- dsgrid/api/__init__.py +0 -0
- dsgrid/api/api_manager.py +179 -0
- dsgrid/api/app.py +419 -0
- dsgrid/api/models.py +60 -0
- dsgrid/api/response_models.py +116 -0
- dsgrid/apps/__init__.py +0 -0
- dsgrid/apps/project_viewer/app.py +216 -0
- dsgrid/apps/registration_gui.py +444 -0
- dsgrid/chronify.py +32 -0
- dsgrid/cli/__init__.py +0 -0
- dsgrid/cli/common.py +120 -0
- dsgrid/cli/config.py +176 -0
- dsgrid/cli/download.py +13 -0
- dsgrid/cli/dsgrid.py +157 -0
- dsgrid/cli/dsgrid_admin.py +92 -0
- dsgrid/cli/install_notebooks.py +62 -0
- dsgrid/cli/query.py +729 -0
- dsgrid/cli/registry.py +1862 -0
- dsgrid/cloud/__init__.py +0 -0
- dsgrid/cloud/cloud_storage_interface.py +140 -0
- dsgrid/cloud/factory.py +31 -0
- dsgrid/cloud/fake_storage_interface.py +37 -0
- dsgrid/cloud/s3_storage_interface.py +156 -0
- dsgrid/common.py +36 -0
- dsgrid/config/__init__.py +0 -0
- dsgrid/config/annual_time_dimension_config.py +194 -0
- dsgrid/config/common.py +142 -0
- dsgrid/config/config_base.py +148 -0
- dsgrid/config/dataset_config.py +907 -0
- dsgrid/config/dataset_schema_handler_factory.py +46 -0
- dsgrid/config/date_time_dimension_config.py +136 -0
- dsgrid/config/dimension_config.py +54 -0
- dsgrid/config/dimension_config_factory.py +65 -0
- dsgrid/config/dimension_mapping_base.py +350 -0
- dsgrid/config/dimension_mappings_config.py +48 -0
- dsgrid/config/dimensions.py +1025 -0
- dsgrid/config/dimensions_config.py +71 -0
- dsgrid/config/file_schema.py +190 -0
- dsgrid/config/index_time_dimension_config.py +80 -0
- dsgrid/config/input_dataset_requirements.py +31 -0
- dsgrid/config/mapping_tables.py +209 -0
- dsgrid/config/noop_time_dimension_config.py +42 -0
- dsgrid/config/project_config.py +1462 -0
- dsgrid/config/registration_models.py +188 -0
- dsgrid/config/representative_period_time_dimension_config.py +194 -0
- dsgrid/config/simple_models.py +49 -0
- dsgrid/config/supplemental_dimension.py +29 -0
- dsgrid/config/time_dimension_base_config.py +192 -0
- dsgrid/data_models.py +155 -0
- dsgrid/dataset/__init__.py +0 -0
- dsgrid/dataset/dataset.py +123 -0
- dsgrid/dataset/dataset_expression_handler.py +86 -0
- dsgrid/dataset/dataset_mapping_manager.py +121 -0
- dsgrid/dataset/dataset_schema_handler_base.py +945 -0
- dsgrid/dataset/dataset_schema_handler_one_table.py +209 -0
- dsgrid/dataset/dataset_schema_handler_two_table.py +322 -0
- dsgrid/dataset/growth_rates.py +162 -0
- dsgrid/dataset/models.py +51 -0
- dsgrid/dataset/table_format_handler_base.py +257 -0
- dsgrid/dataset/table_format_handler_factory.py +17 -0
- dsgrid/dataset/unpivoted_table.py +121 -0
- dsgrid/dimension/__init__.py +0 -0
- dsgrid/dimension/base_models.py +230 -0
- dsgrid/dimension/dimension_filters.py +308 -0
- dsgrid/dimension/standard.py +252 -0
- dsgrid/dimension/time.py +352 -0
- dsgrid/dimension/time_utils.py +103 -0
- dsgrid/dsgrid_rc.py +88 -0
- dsgrid/exceptions.py +105 -0
- dsgrid/filesystem/__init__.py +0 -0
- dsgrid/filesystem/cloud_filesystem.py +32 -0
- dsgrid/filesystem/factory.py +32 -0
- dsgrid/filesystem/filesystem_interface.py +136 -0
- dsgrid/filesystem/local_filesystem.py +74 -0
- dsgrid/filesystem/s3_filesystem.py +118 -0
- dsgrid/loggers.py +132 -0
- dsgrid/minimal_patterns.cp313-win_amd64.pyd +0 -0
- dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +949 -0
- dsgrid/notebooks/registration.ipynb +48 -0
- dsgrid/notebooks/start_notebook.sh +11 -0
- dsgrid/project.py +451 -0
- dsgrid/query/__init__.py +0 -0
- dsgrid/query/dataset_mapping_plan.py +142 -0
- dsgrid/query/derived_dataset.py +388 -0
- dsgrid/query/models.py +728 -0
- dsgrid/query/query_context.py +287 -0
- dsgrid/query/query_submitter.py +994 -0
- dsgrid/query/report_factory.py +19 -0
- dsgrid/query/report_peak_load.py +70 -0
- dsgrid/query/reports_base.py +20 -0
- dsgrid/registry/__init__.py +0 -0
- dsgrid/registry/bulk_register.py +165 -0
- dsgrid/registry/common.py +287 -0
- dsgrid/registry/config_update_checker_base.py +63 -0
- dsgrid/registry/data_store_factory.py +34 -0
- dsgrid/registry/data_store_interface.py +74 -0
- dsgrid/registry/dataset_config_generator.py +158 -0
- dsgrid/registry/dataset_registry_manager.py +950 -0
- dsgrid/registry/dataset_update_checker.py +16 -0
- dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
- dsgrid/registry/dimension_mapping_update_checker.py +16 -0
- dsgrid/registry/dimension_registry_manager.py +413 -0
- dsgrid/registry/dimension_update_checker.py +16 -0
- dsgrid/registry/duckdb_data_store.py +207 -0
- dsgrid/registry/filesystem_data_store.py +150 -0
- dsgrid/registry/filter_registry_manager.py +123 -0
- dsgrid/registry/project_config_generator.py +57 -0
- dsgrid/registry/project_registry_manager.py +1623 -0
- dsgrid/registry/project_update_checker.py +48 -0
- dsgrid/registry/registration_context.py +223 -0
- dsgrid/registry/registry_auto_updater.py +316 -0
- dsgrid/registry/registry_database.py +667 -0
- dsgrid/registry/registry_interface.py +446 -0
- dsgrid/registry/registry_manager.py +558 -0
- dsgrid/registry/registry_manager_base.py +367 -0
- dsgrid/registry/versioning.py +92 -0
- dsgrid/rust_ext/__init__.py +14 -0
- dsgrid/rust_ext/find_minimal_patterns.py +129 -0
- dsgrid/spark/__init__.py +0 -0
- dsgrid/spark/functions.py +589 -0
- dsgrid/spark/types.py +110 -0
- dsgrid/tests/__init__.py +0 -0
- dsgrid/tests/common.py +140 -0
- dsgrid/tests/make_us_data_registry.py +265 -0
- dsgrid/tests/register_derived_datasets.py +103 -0
- dsgrid/tests/utils.py +25 -0
- dsgrid/time/__init__.py +0 -0
- dsgrid/time/time_conversions.py +80 -0
- dsgrid/time/types.py +67 -0
- dsgrid/units/__init__.py +0 -0
- dsgrid/units/constants.py +113 -0
- dsgrid/units/convert.py +71 -0
- dsgrid/units/energy.py +145 -0
- dsgrid/units/power.py +87 -0
- dsgrid/utils/__init__.py +0 -0
- dsgrid/utils/dataset.py +830 -0
- dsgrid/utils/files.py +179 -0
- dsgrid/utils/filters.py +125 -0
- dsgrid/utils/id_remappings.py +100 -0
- dsgrid/utils/py_expression_eval/LICENSE +19 -0
- dsgrid/utils/py_expression_eval/README.md +8 -0
- dsgrid/utils/py_expression_eval/__init__.py +847 -0
- dsgrid/utils/py_expression_eval/tests.py +283 -0
- dsgrid/utils/run_command.py +70 -0
- dsgrid/utils/scratch_dir_context.py +65 -0
- dsgrid/utils/spark.py +918 -0
- dsgrid/utils/spark_partition.py +98 -0
- dsgrid/utils/timing.py +239 -0
- dsgrid/utils/utilities.py +221 -0
- dsgrid/utils/versioning.py +36 -0
- dsgrid_toolkit-0.3.3.dist-info/METADATA +193 -0
- dsgrid_toolkit-0.3.3.dist-info/RECORD +157 -0
- dsgrid_toolkit-0.3.3.dist-info/WHEEL +4 -0
- dsgrid_toolkit-0.3.3.dist-info/entry_points.txt +4 -0
- dsgrid_toolkit-0.3.3.dist-info/licenses/LICENSE +29 -0
dsgrid/utils/spark.py
ADDED
@@ -0,0 +1,918 @@
"""Spark helper functions"""

import enum
import itertools
import logging
import math
import os
import shutil
import time
from contextlib import contextmanager
from pathlib import Path
from types import UnionType
from typing import Any, Generator, Iterable, Sequence, Type, Union, get_origin, get_args

import duckdb
import pandas as pd

from dsgrid.data_models import DSGBaseModel
from dsgrid.exceptions import (
    DSGInvalidField,
    DSGInvalidFile,
    DSGInvalidOperation,
    DSGInvalidParameter,
)
from dsgrid.utils.files import delete_if_exists, load_data
from dsgrid.utils.scratch_dir_context import ScratchDirContext
from dsgrid.spark.functions import (
    cross_join,
    get_spark_session,
    get_duckdb_spark_session,
    get_current_time_zone,
    set_current_time_zone,
    init_spark,
    is_dataframe_empty,
    read_csv,
    read_json,
    read_parquet,
)
from dsgrid.spark.types import (
    AnalysisException,
    BooleanType,
    DataFrame,
    DoubleType,
    IntegerType,
    SparkSession,
    StringType,
    StructField,
    StructType,
    use_duckdb,
)
from dsgrid.utils.timing import Timer, track_timing, timer_stats_collector


logger = logging.getLogger(__name__)

# Consider using our own database. Would need to manage creation with
# spark.sql(f"CREATE DATABASE IF NOT EXISTS {database}")
# Doing so has caused conflicts in tests with the Derby db.
DSGRID_DB_NAME = "default"

MAX_PARTITION_SIZE_MB = 128

PYTHON_TO_SPARK_TYPES = {
    int: IntegerType,
    float: DoubleType,
    str: StringType,
    bool: BooleanType,
}


def get_active_session(*args) -> SparkSession:
    """Return the active Spark Session."""
    return get_duckdb_spark_session() or init_spark(*args)


def restart_spark(*args, force=False, **kwargs) -> SparkSession:
    """Restart a SparkSession with new config parameters. Refer to init_spark for parameters.

    Parameters
    ----------
    force : bool
        If True, restart the session even if the config parameters haven't changed.
        You might want to do this in order to clear cached tables or start Spark fresh.

    Returns
    -------
    pyspark.sql.SparkSession

    """
    spark = get_duckdb_spark_session()
    if spark is not None:
        return spark

    spark = SparkSession.getActiveSession()
    needs_restart = force
    orig_time_zone = spark.conf.get("spark.sql.session.timeZone")
    conf = kwargs.get("spark_conf", {})
    new_time_zone = conf.get("spark.sql.session.timeZone", orig_time_zone)

    if not force:
        for key, val in conf.items():
            current = spark.conf.get(key, None)
            if isinstance(current, str):
                match current.lower():
                    case "true":
                        current = True
                    case "false":
                        current = False
            if current is not None and current != val:
                logger.info("SparkSession needs restart because of %s = %s", key, val)
                needs_restart = True
                break

    if needs_restart:
        spark.stop()
        logger.info("Stopped the SparkSession so that it can be restarted with a new config.")
        spark = init_spark(*args, **kwargs)
        if spark.conf.get("spark.sql.session.timeZone") != new_time_zone:
            # We set this value in query_submitter.py and that change will get lost
            # when the session is restarted.
            spark.conf.set("spark.sql.session.timeZone", new_time_zone)
    else:
        logger.info("No restart of Spark is needed.")

    return spark


@track_timing(timer_stats_collector)
def create_dataframe(records, table_name=None, require_unique=None) -> DataFrame:
    """Create a spark DataFrame from a list of records.

    Parameters
    ----------
    records : list
        list of spark.sql.Row
    table_name : str | None
        If set, cache the DataFrame in memory with this name. Must be unique.
    require_unique : list
        list of column names (str) to check for uniqueness
    """
    df = get_spark_session().createDataFrame(records)
    _post_process_dataframe(df, table_name=table_name, require_unique=require_unique)
    return df


@track_timing(timer_stats_collector)
def create_dataframe_from_ids(ids: Iterable[str], column: str) -> DataFrame:
    """Create a spark DataFrame from a list of dimension IDs."""
    schema = StructType([StructField(column, StringType())])
    return get_spark_session().createDataFrame([[x] for x in ids], schema)


def create_dataframe_from_pandas(df):
    """Create a spark DataFrame from a pandas DataFrame."""
    return get_spark_session().createDataFrame(df)


def create_dataframe_from_dicts(records: list[dict[str, Any]]) -> DataFrame:
    """Create a spark DataFrame from a list of dictionaries.

    The only purpose is to avoid pyright complaints about the type of the input to
    spark.createDataFrame. This can be removed if pyspark fixes the type annotations.
    """
    if not records:
        msg = "records cannot be empty in create_dataframe_from_dicts"
        raise DSGInvalidParameter(msg)

    data = [tuple(row.values()) for row in records]
    columns = list(records[0].keys())
    return get_spark_session().createDataFrame(data, columns)


def try_read_dataframe(filename: Path, delete_if_invalid=True, **kwargs):
    """Try to read the dataframe.

    Parameters
    ----------
    filename : Path
    delete_if_invalid : bool
        Delete the file if it cannot be read, defaults to true.
    kwargs
        Forwarded to read_dataframe.

    Returns
    -------
    pyspark.sql.DataFrame | None
        Returns None if the file does not exist or is invalid.

    """
    if not filename.exists():
        return None

    try:
        return read_dataframe(filename, **kwargs)
    except DSGInvalidFile:
        if delete_if_invalid:
            if filename.is_dir():
                shutil.rmtree(filename)
            else:
                filename.unlink()
        return None


@track_timing(timer_stats_collector)
def read_dataframe(
    filename: str | Path,
    table_name: str | None = None,
    require_unique: None | bool = None,
    read_with_spark: bool = True,
) -> DataFrame:
    """Create a spark DataFrame from a file.

    Supported formats when read_with_spark=True: .csv, .json, .parquet
    Supported formats when read_with_spark=False: .csv, .json

    When reading CSV files on AWS read_with_spark should be set to False because the
    files would need to be present on local storage for all workers. The master node
    will sync the config files from S3, read them with standard filesystem system calls,
    and then convert the data to Spark dataframes. This could change if we ever decide
    to read CSV files with Spark directly from S3.

    Parameters
    ----------
    filename : str | Path
        path to file
    table_name : str | None
        If set, cache the DataFrame in memory. Must be unique.
    require_unique : list
        list of column names (str) to check for uniqueness
    read_with_spark : bool
        If True, read the file with pyspark.read. Otherwise, read the file into
        a list of dicts, convert to pyspark Rows, and then to a DataFrame.

    Returns
    -------
    spark.sql.DataFrame

    Raises
    ------
    ValueError
        Raised if a require_unique column has duplicate values.
    DSGInvalidFile
        Raised if the file cannot be read. This can happen if a Parquet write operation fails.

    """
    func = _read_with_spark if read_with_spark else _read_natively
    df = func(str(filename))
    _post_process_dataframe(df, table_name=table_name, require_unique=require_unique)
    return df

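# Editorial usage sketch (not part of the packaged module). The path and column
# name are hypothetical; it relies only on names defined or imported above.
def _example_read_dataframe() -> DataFrame | None:
    lookup = read_dataframe(Path("county_lookup.parquet"), require_unique=["county_id"])
    logger.debug("Read %s unique counties", lookup.count())
    # try_read_dataframe returns None instead of raising if the file is missing or unreadable.
    return try_read_dataframe(Path("intermediate_result.parquet"))
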
def _read_with_spark(filename):
    if not os.path.exists(filename):
        msg = f"{filename} does not exist"
        raise FileNotFoundError(msg)
    suffix = Path(filename).suffix
    if suffix == ".csv":
        df = read_csv(filename)
    elif suffix == ".parquet":
        try:
            df = read_parquet(filename)
        except AnalysisException as exc:
            if "Unable to infer schema for Parquet. It must be specified manually." in str(exc):
                logger.exception("Failed to read Parquet file=%s. File may be invalid", filename)
                msg = f"Cannot read {filename=}"
                raise DSGInvalidFile(msg)
            raise
        except duckdb.duckdb.IOException:
            logger.exception("Failed to read Parquet file=%s. File may be invalid", filename)
            msg = f"Cannot read {filename=}"
            raise DSGInvalidFile(msg)

    elif suffix == ".json":
        df = read_json(filename)
    else:
        assert False, f"Unsupported file extension: {filename}"
    return df


def _read_natively(filename):
    suffix = Path(filename).suffix
    if suffix == ".csv":
        # Reading the file is faster with pandas. Converting a list of Row to a spark df
        # is a tiny bit faster. Pandas likely scales better with bigger files.
        # Keep the code in case we ever want to revert.
        # with open(filename, encoding="utf-8-sig") as f_in:
        #     rows = [Row(**x) for x in csv.DictReader(f_in)]
        obj = pd.read_csv(filename)
    elif suffix == ".json":
        obj = load_data(filename)
    else:
        msg = f"Unsupported file extension: {filename}"
        raise NotImplementedError(msg)
    return get_spark_session().createDataFrame(obj)


def _post_process_dataframe(df, table_name=None, require_unique=None):
    if not use_duckdb() and table_name is not None:
        df.createOrReplaceTempView(table_name)
        df.cache()

    if require_unique is not None:
        with Timer(timer_stats_collector, "check_unique"):
            for column in require_unique:
                unique = df.select(column).distinct()
                if unique.count() != df.count():
                    msg = f"DataFrame has duplicate entries for {column}"
                    raise DSGInvalidField(msg)


def cross_join_dfs(dfs: list[DataFrame]) -> DataFrame:
    """Perform a cross join of all dataframes in dfs."""
    if len(dfs) == 1:
        return dfs[0]

    df = dfs[0]
    for other in dfs[1:]:
        df = cross_join(df, other)
    return df


def get_unique_values(df: DataFrame, columns: Sequence[str]) -> set[str]:
    """Return the unique values of a dataframe in one column or a list of columns."""
    dfc = df.select(columns).distinct().collect()
    if isinstance(columns, list):
        values = {tuple(getattr(row, col) for col in columns) for row in dfc}
    else:
        values = {getattr(x, columns) for x in dfc}

    return values


@track_timing(timer_stats_collector)
def models_to_dataframe(models: list[DSGBaseModel], table_name: str | None = None) -> DataFrame:
    """Convert a list of Pydantic models to a Spark DataFrame.

    Parameters
    ----------
    models : list
    table_name : str | None
        If set, a unique ID to use as the cached table name. Return from cache if already stored.
    """
    spark = get_spark_session()
    if not use_duckdb():
        if (
            table_name is not None
            and spark.catalog.tableExists(table_name)
            and spark.catalog.isCached(table_name)
        ):
            return spark.table(table_name)

    assert models
    cls = type(models[0])
    rows = []
    struct_fields = []
    for i, model in enumerate(models):
        dct = {}
        for f in cls.model_fields:
            val = getattr(model, f)
            if isinstance(val, enum.Enum):
                val = val.value
            if i == 0:
                if val is None:
                    python_type = cls.model_fields[f].annotation
                    origin = get_origin(python_type)
                    if origin is Union or origin is UnionType:
                        python_type = get_type_from_union(python_type)
                    # else: will likely fail below.
                    # Need to add more logic to detect the actual type or add to
                    # PYTHON_TO_SPARK_TYPES.
                else:
                    python_type = type(val)
                spark_type = PYTHON_TO_SPARK_TYPES[python_type]()
                struct_fields.append(StructField(f, spark_type, nullable=True))
            dct[f] = val
        rows.append(tuple(dct.values()))

    schema = StructType(struct_fields)
    df = spark.createDataFrame(rows, schema=schema)

    if not use_duckdb() and table_name is not None:
        df.createOrReplaceTempView(table_name)
        df.cache()

    return df

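# Editorial usage sketch (not part of the packaged module). _ExampleCountyModel is a
# hypothetical model; any DSGBaseModel subclass with simple field types works.
class _ExampleCountyModel(DSGBaseModel):
    id: str
    name: str


def _example_models_to_dataframe() -> DataFrame:
    models = [
        _ExampleCountyModel(id="08059", name="Jefferson"),
        _ExampleCountyModel(id="08069", name="Larimer"),
    ]
    # Each model field becomes a nullable column with the Spark type mapped above.
    return models_to_dataframe(models, table_name="example_counties")
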
def get_type_from_union(python_type) -> Type:
    """Return the Python type from a Union.

    Only works if it is a Union of NoneType and something.

    Raises
    ------
    NotImplementedError
        Raised if the code does not know how to determine the type.
    """
    args = get_args(python_type)
    if issubclass(args[0], enum.Enum):
        python_type = type(next(iter(args[0])).value)
    else:
        types = [x for x in args if not issubclass(x, type(None))]
        if not types:
            msg = f"Unhandled Union type: {python_type=} {args=}"
            raise NotImplementedError(msg)
        elif len(types) > 1:
            msg = f"Unhandled Union type: {types=}"
            raise NotImplementedError(msg)
        else:
            python_type = types[0]

    return python_type


@track_timing(timer_stats_collector)
def create_dataframe_from_dimension_ids(records, *dimension_types, cache=True) -> DataFrame:
    """Return a DataFrame created from the IDs of dimension_types.

    Parameters
    ----------
    records : sequence
        Iterable of lists of record IDs
    dimension_types : tuple
    cache : If True, cache the DataFrame.
    """
    schema = StructType()
    for dimension_type in dimension_types:
        schema.add(dimension_type.value, StringType(), nullable=False)
    df = get_spark_session().createDataFrame(records, schema=schema)
    if not use_duckdb() and cache:
        df.cache()
    return df


@track_timing(timer_stats_collector)
def check_for_nulls(df, exclude_columns=None):
    """Check if a DataFrame has null values.

    Parameters
    ----------
    df : spark.sql.DataFrame
    exclude_columns : None or Set

    Raises
    ------
    DSGInvalidField
        Raised if null exists in any column.

    """
    if exclude_columns is None:
        exclude_columns = set()
    cols_to_check = set(df.columns).difference(exclude_columns)
    cols_str = ", ".join(cols_to_check)
    filter_str = " OR ".join((f"{x} IS NULL" for x in cols_to_check))
    df.createOrReplaceTempView("tmp_view")

    try:
        # Avoid iterating with many checks unless we know there is at least one failure.
        nulls = sql(f"SELECT {cols_str} FROM tmp_view WHERE {filter_str}")
        if not is_dataframe_empty(nulls):
            cols_with_null = set()
            for col in cols_to_check:
                if not is_dataframe_empty(nulls.select(col).filter(f"{col} is NULL")):
                    cols_with_null.add(col)
            assert cols_with_null, "Did not find any columns with NULL values"

            msg = f"DataFrame contains NULL value(s) for column(s): {cols_with_null}"
            raise DSGInvalidField(msg)
    finally:
        sql("DROP VIEW tmp_view")

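# Editorial usage sketch (not part of the packaged module); column names are hypothetical.
def _example_check_for_nulls() -> None:
    df = create_dataframe_from_dicts(
        [{"geography": "08059", "value": 1.0}, {"geography": "08069", "value": 2.0}]
    )
    try:
        # Raises DSGInvalidField if any non-excluded column contains a NULL.
        check_for_nulls(df, exclude_columns={"value"})
    except DSGInvalidField:
        logger.error("Found NULL values in the example dataframe")
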
@track_timing(timer_stats_collector)
def overwrite_dataframe_file(filename: Path | str, df: DataFrame) -> DataFrame:
    """Perform an in-place overwrite of a Spark DataFrame, accounting for different file types
    and symlinks.

    Do not attempt to access the original dataframe unless it was fully cached.
    """
    spark = get_spark_session()
    suffix = Path(filename).suffix
    tmp = str(filename) + ".tmp"
    if suffix == ".parquet":
        df.write.parquet(tmp)
        read_method = read_parquet
        kwargs = {}
    elif suffix == ".csv":
        df.write.csv(str(tmp), header=True)
        read_method = spark.read.csv
        kwargs = {"header": True, "schema": df.schema}
    elif suffix == ".json":
        df.write.json(str(tmp))
        read_method = spark.read.json
        kwargs = {}
    delete_if_exists(filename)
    os.rename(tmp, str(filename))
    return read_method(str(filename), **kwargs)


@track_timing(timer_stats_collector)
def persist_intermediate_query(
    df: DataFrame, scratch_dir_context: ScratchDirContext, auto_partition=False
) -> DataFrame:
    """Persist the current query to files and then read it back and return it.

    This is advised when the query has become too complex or when the query might be evaluated
    twice.

    Parameters
    ----------
    df : DataFrame
    scratch_dir_context : ScratchDirContext
    auto_partition : bool
        If True, call write_dataframe_and_auto_partition.

    Returns
    -------
    DataFrame
    """
    spark = get_spark_session()
    tmp_file = scratch_dir_context.get_temp_filename(suffix=".parquet")
    if auto_partition:
        return write_dataframe_and_auto_partition(df, tmp_file)
    df.write.parquet(str(tmp_file))
    return spark.read.parquet(str(tmp_file))


@track_timing(timer_stats_collector)
def write_dataframe_and_auto_partition(
    df: DataFrame,
    filename: Path,
    partition_size_mb: int = MAX_PARTITION_SIZE_MB,
    columns: list[str] | None = None,
    rtol_pct: float = 50,
    min_num_partitions: int = 36,
) -> DataFrame:
    """Write a dataframe to a Parquet file and then automatically coalesce or repartition it if
    needed. If the file already exists, it will be overwritten.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
    filename : Path
    partition_size_mb : int
        Target size in MB for each partition
    columns : None, list
        If not None and repartitioning is needed, partition on these columns.
    rtol_pct : int
        Don't repartition or coalesce if the relative difference between desired and actual
        partitions is within this tolerance as a percentage.
    min_num_partitions : int
        Minimum number of partitions to create. If the number of partitions is less than this,
        do not coalesce/repartition because it will reduce parallelism.

    Raises
    ------
    DSGInvalidParameter
        Raised if a non-Parquet file is passed
    """
    suffix = Path(filename).suffix
    if suffix != ".parquet":
        msg = f"write_dataframe_and_auto_partition only supports Parquet files: {filename=}"
        raise DSGInvalidParameter(msg)

    start_initial_write = time.time()
    if filename.exists():
        df = overwrite_dataframe_file(filename, df)
    else:
        df.write.parquet(str(filename))
        df = read_parquet(filename)

    end_initial_write = time.time()
    duration_first_write = end_initial_write - start_initial_write

    if use_duckdb():
        logger.debug("write_dataframe_and_auto_partition is not optimized for DuckDB")
        return df

    num_partitions = len(list(filename.parent.iterdir()))
    if num_partitions < min_num_partitions:
        logger.info(
            "Not coalescing %s because it has only %s partitions, "
            "which is less than the minimum of %s.",
            filename,
            num_partitions,
            min_num_partitions,
        )
        # TODO: consider repartitioning to increase the number of partitions.
        return df

    partition_size_bytes = partition_size_mb * 1024 * 1024
    total_size = sum((x.stat().st_size for x in filename.glob("*.parquet")))
    desired = math.ceil(total_size / partition_size_bytes)
    actual = len(list(filename.glob("*.parquet")))
    if abs(actual - desired) / desired * 100 < rtol_pct:
        logger.info("No change in number of partitions is needed for %s.", filename)
    elif actual > desired:
        df = df.coalesce(desired)
        df = overwrite_dataframe_file(filename, df)
        duration_second_write = time.time() - end_initial_write
        logger.info(
            "Coalesced %s from partition count %s to %s. "
            "duration_first_write=%s duration_second_write=%s",
            filename,
            actual,
            desired,
            duration_first_write,
            duration_second_write,
        )
    else:
        if columns is None:
            df = df.repartition(desired)
        else:
            df = df.repartition(desired, *columns)
        df = overwrite_dataframe_file(filename, df)
        duration_second_write = time.time() - end_initial_write
        logger.info(
            "Repartitioned %s from partition count %s to %s. "
            "duration_first_write=%s duration_second_write=%s",
            filename,
            actual,
            desired,
            duration_first_write,
            duration_second_write,
        )

    logger.info("Wrote dataframe to %s", filename)
    return df

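# Editorial usage sketch (not part of the packaged module); the paths are hypothetical.
def _example_auto_partition() -> DataFrame:
    df = read_dataframe(Path("load_data.parquet"))
    # Writes the Parquet directory, then coalesces or repartitions it toward
    # ~128 MB files; "geography" is only used if a repartition is required.
    return write_dataframe_and_auto_partition(
        df, Path("load_data_repartitioned.parquet"), columns=["geography"]
    )
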
@track_timing(timer_stats_collector)
def write_dataframe(df: DataFrame, filename: str | Path, overwrite: bool = False) -> None:
    """Write a Spark DataFrame, accounting for different file types.

    Parameters
    ----------
    filename : str
    df : pyspark.sql.DataFrame
    """
    path = Path(filename)
    if overwrite:
        delete_if_exists(path)

    suffix = path.suffix
    name = str(filename)
    if suffix == ".parquet":
        df.write.parquet(name)
    elif suffix == ".csv":
        df.write.csv(name, header=True)
    elif suffix == ".json":
        if use_duckdb():
            new_name = name.replace(".json", ".parquet")
            df.write.parquet(new_name)
        else:
            df.write.json(name)


@track_timing(timer_stats_collector)
def persist_table(df: DataFrame, context: ScratchDirContext, tag=None) -> Path:
    """Persist a table to the scratch directory. This can be helpful to avoid multiple
    evaluations of the same query.
    """
    # Note: This does not use the Spark warehouse because we are not properly configuring or
    # managing it across sessions. And, we are already using the scratch dir for our own files.
    path = context.get_temp_filename(suffix=".parquet")
    logger.info("Start persist_table %s %s", path, tag or "")
    write_dataframe(df, path)
    logger.info("Completed persist_table %s %s", path, tag or "")
    return path


@track_timing(timer_stats_collector)
def save_to_warehouse(df: DataFrame, table_name: str) -> DataFrame:
    """Save a table to the Spark warehouse. Not supported when using DuckDB."""
    if use_duckdb():
        msg = "save_to_warehouse is not supported when using DuckDB"
        raise DSGInvalidOperation(msg)

    logger.info("Start saveAsTable to warehouse %s", table_name)
    df.write.saveAsTable(table_name)
    logger.info("Completed saveAsTable %s", table_name)
    return df.sparkSession.sql(f"select * from {table_name}")


def sql(query: str) -> DataFrame:
    """Run a SQL query with Spark."""
    logger.debug("Run SQL query [%s]", query)
    return get_spark_session().sql(query)


def load_stored_table(table_name: str) -> DataFrame:
    """Return a table stored in the Spark warehouse."""
    spark = get_spark_session()
    return spark.table(table_name)


def try_load_stored_table(
    table_name: str, database: str | None = DSGRID_DB_NAME
) -> DataFrame | None:
    """Return a table if it is stored in the Spark warehouse."""
    spark = get_spark_session()
    full_name = f"{database}.{table_name}"
    if spark.catalog.tableExists(full_name):
        return spark.table(table_name)
    return None


def is_table_stored(table_name, database=DSGRID_DB_NAME):
    spark = get_spark_session()
    full_name = f"{database}.{table_name}"
    return spark.catalog.tableExists(full_name)


def save_table(table, table_name, overwrite=True, database=DSGRID_DB_NAME):
    full_name = f"{database}.{table_name}"
    if overwrite:
        table.write.mode("overwrite").saveAsTable(full_name)
    else:
        table.write.saveAsTable(full_name)


def list_tables(database=DSGRID_DB_NAME):
    spark = get_spark_session()
    return [x.name for x in spark.catalog.listTables(dbName=database)]


def drop_table(table_name, database=DSGRID_DB_NAME):
    spark = get_spark_session()
    if is_table_stored(table_name, database=database):
        spark.sql(f"DROP TABLE {table_name}")
        logger.info("Dropped table %s", table_name)


@track_timing(timer_stats_collector)
def create_dataframe_from_product(
    data: dict[str, list[str]],
    context: ScratchDirContext,
    max_partition_size_mb=MAX_PARTITION_SIZE_MB,
) -> DataFrame:
    """Create a dataframe by taking a product of values/columns in a dict.

    Parameters
    ----------
    data : dict
        Columns on which to perform a cross product.
        {"sector": ["com"], "subsector": ["SmallOffice", "LargeOffice"]}
    context : ScratchDirContext
        Manages temporary files.
    """
    # dthom: 1/29/2024
    # This implementation creates a product of all columns in Python, writes them to temporary
    # CSV files, and then loads that back into Spark.
    # This is the fastest way I've found to pass a large dataframe from the Spark driver (Python
    # app) to the Spark workers on compute nodes.
    # The total size of a table can be large depending on the numbers of dimensions. For example,
    # comstock_conus_2022_projected is 3108 counties * 41 model years * 21 end uses * 14 subsectors * 3 scenarios
    # = 112_391_496 rows. The CSV files are ~7.7 GB.
    # (Note that, due to compression, the same table in Parquet is 7 MB.)
    # This is not ideal because it writes temporary files to the filesystem.
    # Other solutions tried:
    # 1. spark.createDataFrame(spark.sparkContext.parallelize(itertools.product(*(data.values()))), list(data.keys))
    #    Reasonably fast until the data is larger than Spark's max RPC message size. Then it fails.
    # 2. Create an RDD and then call rdd.flatMap with the output of itertools.product. Very slow.
    # 3. Create one Spark DataFrame per column and then cross-join all of them. Extremely slow.
    # 4. Create one pyarrow Table, write to temp Parquet, read back in Spark. ~2x slower
    #    than the CSV implementation.
    # 5. Create the joined table via SQLite and then read the contents into Spark with a JDBC
    #    driver. Much slower.

    # Note: This location must be accessible on all compute nodes.
    csv_dir = context.get_temp_filename(suffix=".csv")
    columns = list(data.keys())
    schema = StructType([StructField(x, StringType()) for x in columns])

    with CsvPartitionWriter(csv_dir, max_partition_size_mb=max_partition_size_mb) as writer:
        for row in itertools.product(*(data.values())):
            writer.add_row(row)

    spark = get_spark_session()
    if use_duckdb():
        df = spark.read.csv(f"{csv_dir}/*.csv", header=False, schema=schema)
    else:
        df = spark.read.csv(str(csv_dir), header=False, schema=schema)
    return df

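# Editorial usage sketch (not part of the packaged module). It assumes a
# ScratchDirContext instance is already available; see that class for how to construct one.
def _example_create_dataframe_from_product(context: ScratchDirContext) -> DataFrame:
    data = {"sector": ["com"], "subsector": ["SmallOffice", "LargeOffice"]}
    # Produces one row per combination: ("com", "SmallOffice") and ("com", "LargeOffice").
    return create_dataframe_from_product(data, context)
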
class CsvPartitionWriter:
    """Writes dataframe rows to partitioned CSV files."""

    def __init__(self, directory: Path, max_partition_size_mb: int = MAX_PARTITION_SIZE_MB):
        self._directory = directory
        self._directory.mkdir(exist_ok=True)
        self._max_size = max_partition_size_mb * 1024 * 1024
        self._size = 0
        self._index = 1
        self._fp = None

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        if self._fp is not None:
            self._fp.close()

    def add_row(self, row: tuple) -> None:
        """Add a row to the CSV files."""
        line = ",".join(row)
        if self._fp is None:
            filename = self._directory / f"part{self._index}.csv"
            self._fp = open(filename, "w", encoding="utf-8")
        self._size += self._fp.write(line)
        self._size += self._fp.write("\n")
        if self._size >= self._max_size:
            self._fp.close()
            self._fp = None
            self._size = 0
            self._index += 1


@contextmanager
def custom_spark_conf(conf):
    """Apply a custom Spark configuration for the duration of a code block.

    Parameters
    ----------
    conf : dict
        Key-value pairs to set on the spark configuration.

    """
    spark = get_duckdb_spark_session()
    if spark is not None:
        yield
        return

    spark = get_spark_session()
    orig_settings = {}

    try:
        for key, val in conf.items():
            orig_settings[key] = spark.conf.get(key)
            spark.conf.set(key, val)
            logger.info("Set %s=%s temporarily", key, val)
        yield
    finally:
        # Note that the user code could have restarted the session.
        # Get the current one.
        spark = get_spark_session()
        for key, val in orig_settings.items():
            spark.conf.set(key, val)

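# Editorial usage sketch (not part of the packaged module). The conf key is a
# standard Spark setting; the original value is restored when the block exits.
def _example_custom_spark_conf() -> DataFrame:
    with custom_spark_conf({"spark.sql.shuffle.partitions": "64"}):
        return sql("SELECT 1 AS one")
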
@contextmanager
def custom_time_zone(time_zone: str):
    """Apply a custom Spark time zone for the duration of a code block."""
    orig_time_zone = get_current_time_zone()
    try:
        set_current_time_zone(time_zone)
        yield
    finally:
        # Note that the user code could have restarted the session.
        # This function will get the current one.
        set_current_time_zone(orig_time_zone)


@contextmanager
def restart_spark_with_custom_conf(conf: dict, force=False):
    """Restart the SparkSession with a custom configuration for the duration of a code block.

    Parameters
    ----------
    conf : dict
        Key-value pairs to set on the spark configuration.
    force : bool
        If True, restart the session even if the config parameters haven't changed.
        You might want to do this in order to clear cached tables or start Spark fresh.
    """
    spark = get_duckdb_spark_session()
    if spark is not None:
        yield spark
        return

    spark = get_spark_session()
    app_name = spark.conf.get("spark.app.name")
    orig_settings = {}

    try:
        for name in conf:
            current = spark.conf.get(name, None)
            if current is not None:
                orig_settings[name] = current
        new_spark = restart_spark(name=app_name, spark_conf=conf, force=force)
        yield new_spark
    finally:
        restart_spark(name=app_name, spark_conf=orig_settings, force=force)


@contextmanager
def set_session_time_zone(time_zone: str) -> Generator[None, None, None]:
    """Set the session time zone for execution of a code block."""
    orig = get_current_time_zone()

    try:
        set_current_time_zone(time_zone)
        yield
    finally:
        set_current_time_zone(orig)


def union(dfs: list[DataFrame]) -> DataFrame:
    """Return a union of the dataframes, ensuring that the columns match."""
    df = dfs[0]
    if len(dfs) > 1:
        for dft in dfs[1:]:
            if df.columns != dft.columns:
                msg = f"columns don't match: {df.columns=} {dft.columns=}"
                raise Exception(msg)
            df = df.union(dft)
    return df
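# Editorial usage sketch (not part of the packaged module); column names are hypothetical.
def _example_union() -> DataFrame:
    df1 = create_dataframe_from_dicts([{"id": "a", "value": 1.0}])
    df2 = create_dataframe_from_dicts([{"id": "b", "value": 2.0}])
    # Raises if the column lists of the two dataframes differ.
    return union([df1, df2])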