dsgrid-toolkit 0.3.3__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- build_backend.py +93 -0
- dsgrid/__init__.py +22 -0
- dsgrid/api/__init__.py +0 -0
- dsgrid/api/api_manager.py +179 -0
- dsgrid/api/app.py +419 -0
- dsgrid/api/models.py +60 -0
- dsgrid/api/response_models.py +116 -0
- dsgrid/apps/__init__.py +0 -0
- dsgrid/apps/project_viewer/app.py +216 -0
- dsgrid/apps/registration_gui.py +444 -0
- dsgrid/chronify.py +32 -0
- dsgrid/cli/__init__.py +0 -0
- dsgrid/cli/common.py +120 -0
- dsgrid/cli/config.py +176 -0
- dsgrid/cli/download.py +13 -0
- dsgrid/cli/dsgrid.py +157 -0
- dsgrid/cli/dsgrid_admin.py +92 -0
- dsgrid/cli/install_notebooks.py +62 -0
- dsgrid/cli/query.py +729 -0
- dsgrid/cli/registry.py +1862 -0
- dsgrid/cloud/__init__.py +0 -0
- dsgrid/cloud/cloud_storage_interface.py +140 -0
- dsgrid/cloud/factory.py +31 -0
- dsgrid/cloud/fake_storage_interface.py +37 -0
- dsgrid/cloud/s3_storage_interface.py +156 -0
- dsgrid/common.py +36 -0
- dsgrid/config/__init__.py +0 -0
- dsgrid/config/annual_time_dimension_config.py +194 -0
- dsgrid/config/common.py +142 -0
- dsgrid/config/config_base.py +148 -0
- dsgrid/config/dataset_config.py +907 -0
- dsgrid/config/dataset_schema_handler_factory.py +46 -0
- dsgrid/config/date_time_dimension_config.py +136 -0
- dsgrid/config/dimension_config.py +54 -0
- dsgrid/config/dimension_config_factory.py +65 -0
- dsgrid/config/dimension_mapping_base.py +350 -0
- dsgrid/config/dimension_mappings_config.py +48 -0
- dsgrid/config/dimensions.py +1025 -0
- dsgrid/config/dimensions_config.py +71 -0
- dsgrid/config/file_schema.py +190 -0
- dsgrid/config/index_time_dimension_config.py +80 -0
- dsgrid/config/input_dataset_requirements.py +31 -0
- dsgrid/config/mapping_tables.py +209 -0
- dsgrid/config/noop_time_dimension_config.py +42 -0
- dsgrid/config/project_config.py +1462 -0
- dsgrid/config/registration_models.py +188 -0
- dsgrid/config/representative_period_time_dimension_config.py +194 -0
- dsgrid/config/simple_models.py +49 -0
- dsgrid/config/supplemental_dimension.py +29 -0
- dsgrid/config/time_dimension_base_config.py +192 -0
- dsgrid/data_models.py +155 -0
- dsgrid/dataset/__init__.py +0 -0
- dsgrid/dataset/dataset.py +123 -0
- dsgrid/dataset/dataset_expression_handler.py +86 -0
- dsgrid/dataset/dataset_mapping_manager.py +121 -0
- dsgrid/dataset/dataset_schema_handler_base.py +945 -0
- dsgrid/dataset/dataset_schema_handler_one_table.py +209 -0
- dsgrid/dataset/dataset_schema_handler_two_table.py +322 -0
- dsgrid/dataset/growth_rates.py +162 -0
- dsgrid/dataset/models.py +51 -0
- dsgrid/dataset/table_format_handler_base.py +257 -0
- dsgrid/dataset/table_format_handler_factory.py +17 -0
- dsgrid/dataset/unpivoted_table.py +121 -0
- dsgrid/dimension/__init__.py +0 -0
- dsgrid/dimension/base_models.py +230 -0
- dsgrid/dimension/dimension_filters.py +308 -0
- dsgrid/dimension/standard.py +252 -0
- dsgrid/dimension/time.py +352 -0
- dsgrid/dimension/time_utils.py +103 -0
- dsgrid/dsgrid_rc.py +88 -0
- dsgrid/exceptions.py +105 -0
- dsgrid/filesystem/__init__.py +0 -0
- dsgrid/filesystem/cloud_filesystem.py +32 -0
- dsgrid/filesystem/factory.py +32 -0
- dsgrid/filesystem/filesystem_interface.py +136 -0
- dsgrid/filesystem/local_filesystem.py +74 -0
- dsgrid/filesystem/s3_filesystem.py +118 -0
- dsgrid/loggers.py +132 -0
- dsgrid/minimal_patterns.cp313-win_amd64.pyd +0 -0
- dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +949 -0
- dsgrid/notebooks/registration.ipynb +48 -0
- dsgrid/notebooks/start_notebook.sh +11 -0
- dsgrid/project.py +451 -0
- dsgrid/query/__init__.py +0 -0
- dsgrid/query/dataset_mapping_plan.py +142 -0
- dsgrid/query/derived_dataset.py +388 -0
- dsgrid/query/models.py +728 -0
- dsgrid/query/query_context.py +287 -0
- dsgrid/query/query_submitter.py +994 -0
- dsgrid/query/report_factory.py +19 -0
- dsgrid/query/report_peak_load.py +70 -0
- dsgrid/query/reports_base.py +20 -0
- dsgrid/registry/__init__.py +0 -0
- dsgrid/registry/bulk_register.py +165 -0
- dsgrid/registry/common.py +287 -0
- dsgrid/registry/config_update_checker_base.py +63 -0
- dsgrid/registry/data_store_factory.py +34 -0
- dsgrid/registry/data_store_interface.py +74 -0
- dsgrid/registry/dataset_config_generator.py +158 -0
- dsgrid/registry/dataset_registry_manager.py +950 -0
- dsgrid/registry/dataset_update_checker.py +16 -0
- dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
- dsgrid/registry/dimension_mapping_update_checker.py +16 -0
- dsgrid/registry/dimension_registry_manager.py +413 -0
- dsgrid/registry/dimension_update_checker.py +16 -0
- dsgrid/registry/duckdb_data_store.py +207 -0
- dsgrid/registry/filesystem_data_store.py +150 -0
- dsgrid/registry/filter_registry_manager.py +123 -0
- dsgrid/registry/project_config_generator.py +57 -0
- dsgrid/registry/project_registry_manager.py +1623 -0
- dsgrid/registry/project_update_checker.py +48 -0
- dsgrid/registry/registration_context.py +223 -0
- dsgrid/registry/registry_auto_updater.py +316 -0
- dsgrid/registry/registry_database.py +667 -0
- dsgrid/registry/registry_interface.py +446 -0
- dsgrid/registry/registry_manager.py +558 -0
- dsgrid/registry/registry_manager_base.py +367 -0
- dsgrid/registry/versioning.py +92 -0
- dsgrid/rust_ext/__init__.py +14 -0
- dsgrid/rust_ext/find_minimal_patterns.py +129 -0
- dsgrid/spark/__init__.py +0 -0
- dsgrid/spark/functions.py +589 -0
- dsgrid/spark/types.py +110 -0
- dsgrid/tests/__init__.py +0 -0
- dsgrid/tests/common.py +140 -0
- dsgrid/tests/make_us_data_registry.py +265 -0
- dsgrid/tests/register_derived_datasets.py +103 -0
- dsgrid/tests/utils.py +25 -0
- dsgrid/time/__init__.py +0 -0
- dsgrid/time/time_conversions.py +80 -0
- dsgrid/time/types.py +67 -0
- dsgrid/units/__init__.py +0 -0
- dsgrid/units/constants.py +113 -0
- dsgrid/units/convert.py +71 -0
- dsgrid/units/energy.py +145 -0
- dsgrid/units/power.py +87 -0
- dsgrid/utils/__init__.py +0 -0
- dsgrid/utils/dataset.py +830 -0
- dsgrid/utils/files.py +179 -0
- dsgrid/utils/filters.py +125 -0
- dsgrid/utils/id_remappings.py +100 -0
- dsgrid/utils/py_expression_eval/LICENSE +19 -0
- dsgrid/utils/py_expression_eval/README.md +8 -0
- dsgrid/utils/py_expression_eval/__init__.py +847 -0
- dsgrid/utils/py_expression_eval/tests.py +283 -0
- dsgrid/utils/run_command.py +70 -0
- dsgrid/utils/scratch_dir_context.py +65 -0
- dsgrid/utils/spark.py +918 -0
- dsgrid/utils/spark_partition.py +98 -0
- dsgrid/utils/timing.py +239 -0
- dsgrid/utils/utilities.py +221 -0
- dsgrid/utils/versioning.py +36 -0
- dsgrid_toolkit-0.3.3.dist-info/METADATA +193 -0
- dsgrid_toolkit-0.3.3.dist-info/RECORD +157 -0
- dsgrid_toolkit-0.3.3.dist-info/WHEEL +4 -0
- dsgrid_toolkit-0.3.3.dist-info/entry_points.txt +4 -0
- dsgrid_toolkit-0.3.3.dist-info/licenses/LICENSE +29 -0
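A minimal sketch (not part of the package) for cross-checking the listing above against the wheel's RECORD manifest, using only the Python standard library; the local wheel filename is an assumption inferred from the header and should be replaced with the actual downloaded path.

import zipfile

# Assumed local filename for the wheel named in the header; adjust to the downloaded path.
WHEEL = "dsgrid_toolkit-0.3.3-cp313-cp313-win_amd64.whl"

with zipfile.ZipFile(WHEEL) as whl:
    # RECORD lists every installed path with its hash and size, one entry per line.
    record = whl.read("dsgrid_toolkit-0.3.3.dist-info/RECORD").decode()
    paths = {line.split(",")[0] for line in record.splitlines() if line}
    print(len(paths), "entries in RECORD")
    print("dsgrid/spark/functions.py" in paths)  # expected True, per the listing above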
dsgrid/spark/functions.py
ADDED
@@ -0,0 +1,589 @@
+import logging
+import os
+from datetime import datetime
+from pathlib import Path
+from tempfile import NamedTemporaryFile
+from typing import Any, Iterable
+from uuid import uuid4
+from zoneinfo import ZoneInfo
+
+import duckdb
+
+import dsgrid
+from dsgrid.dsgrid_rc import DsgridRuntimeConfig
+from dsgrid.exceptions import DSGInvalidDataset
+from dsgrid.loggers import disable_console_logging
+from dsgrid.spark.types import (
+    DataFrame,
+    F,
+    SparkConf,
+    SparkSession,
+    use_duckdb,
+)
+from dsgrid.utils.files import load_line_delimited_json, dump_data
+
+logger = logging.getLogger(__name__)
+
+
+if use_duckdb():
+    g_duckdb_spark = SparkSession.builder.getOrCreate()
+else:
+    g_duckdb_spark = None
+
+
+TEMP_TABLE_PREFIX = "tmp_dsgrid"
+
+
+def aggregate(df: DataFrame, agg_func: str, column: str, alias: str) -> DataFrame:
+    """Run an aggregate function on the dataframe."""
+    if use_duckdb():
+        relation = df.relation.aggregate(f"{agg_func}({column}) as {alias}")
+        return DataFrame(relation.set_alias(make_temp_view_name()), df.session)
+    return df.agg(getattr(F, agg_func)(column).alias(alias))
+
+
+def aggregate_single_value(df: DataFrame, agg_func: str, column: str) -> Any:
+    """Run an aggregate function on the dataframe that will produce a single value, such as max.
+    Return that single value.
+    """
+    alias = "__tmp__"
+    if use_duckdb():
+        return df.relation.aggregate(f"{agg_func}({column}) as {alias}").df().values[0][0]
+    return df.agg(getattr(F, agg_func)(column).alias(alias)).collect()[0][alias]
+
+
+def cache(df: DataFrame) -> DataFrame:
+    """Cache the dataframe. This is a no-op for DuckDB."""
+    if use_duckdb():
+        return df
+    return df.cache()
+
+
+def unpersist(df: DataFrame) -> None:
+    """Unpersist a dataframe that was previously cached. This is a no-op for DuckDB."""
+    if not use_duckdb():
+        df.unpersist()
+
+
+def coalesce(df: DataFrame, num_partitions: int) -> DataFrame:
+    """Coalesce the dataframe into num_partitions partitions. This is a no-op for DuckDB."""
+    if use_duckdb():
+        return df
+    return df.coalesce(num_partitions)
+
+
+def collect_list(df: DataFrame, column: str) -> list:
+    """Collect the dataframe column into a list."""
+    if use_duckdb():
+        return [x[column] for x in df.collect()]
+
+    return next(iter(df.select(F.collect_list(column)).first()))
+
+
+def count_distinct_on_group_by(
+    df: DataFrame, group_by_columns: list[str], agg_column: str, alias: str
+) -> DataFrame:
+    """Perform a count distinct on one column after grouping."""
+    if use_duckdb():
+        view = create_temp_view(df)
+        cols = ",".join([f'"{x}"' for x in group_by_columns])
+        query = f"""
+            SELECT {cols}, COUNT(DISTINCT "{agg_column}") AS "{alias}"
+            FROM {view}
+            GROUP BY {cols}
+        """
+        return get_spark_session().sql(query)
+
+    return df.groupBy(*group_by_columns).agg(F.count_distinct(agg_column).alias(alias))
+
+
+def create_temp_view(df: DataFrame) -> str:
+    """Create a temporary view with a random name and return the name."""
+    view1 = make_temp_view_name()
+    df.createOrReplaceTempView(view1)
+    return view1
+
+
+def make_temp_view_name() -> str:
+    """Make a random name to be used as a view."""
+    return f"{TEMP_TABLE_PREFIX}_{uuid4().hex}"
+
+
+def drop_temp_tables_and_views() -> None:
+    """Drop all temporary views and tables."""
+    drop_temp_views()
+    drop_temp_tables()
+
+
+def drop_temp_tables() -> None:
+    """Drop all temporary tables."""
+    spark = get_spark_session()
+    if use_duckdb():
+        query = f"SELECT * FROM pg_tables WHERE tablename LIKE '%{TEMP_TABLE_PREFIX}%'"
+        for row in spark.sql(query).collect():
+            spark.sql(f"DROP TABLE {row.tablename}")
+    else:
+        for row in spark.sql(f"SHOW TABLES LIKE '*{TEMP_TABLE_PREFIX}*'").collect():
+            spark.sql(f"DROP TABLE {row.tableName}")
+
+
+def drop_temp_views() -> None:
+    """Drop all temporary views."""
+    spark = get_spark_session()
+    if use_duckdb():
+        query = f"""
+            SELECT view_name FROM duckdb_views()
+            WHERE NOT internal AND view_name LIKE '%{TEMP_TABLE_PREFIX}%'
+        """
+        for row in spark.sql(query).collect():
+            spark.sql(f"DROP VIEW {row.view_name}")
+    else:
+        for row in spark.sql(f"SHOW VIEWS LIKE '*{TEMP_TABLE_PREFIX}*'").collect():
+            spark.sql(f"DROP VIEW {row.viewName}")
+
+
+def cross_join(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    """Return a cross join of the two dataframes."""
+    if use_duckdb():
+        view1 = create_temp_view(df1)
+        view2 = create_temp_view(df2)
+        spark = get_spark_session()
+        return spark.sql(f"SELECT * from {view1} CROSS JOIN {view2}")
+
+    return df1.crossJoin(df2)
+
+
+def except_all(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    """Return a dataframe with all rows in df1 that are not in df2."""
+    method = _except_all_duckdb if use_duckdb() else _except_all_spark
+    return method(df1, df2)
+
+
+def _except_all_duckdb(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    view1 = create_temp_view(df1)
+    view2 = create_temp_view(df2)
+    query = f"""
+        SELECT * FROM {view1}
+        EXCEPT ALL
+        SELECT * FROM {view2}
+    """
+    spark = get_spark_session()
+    return spark.sql(query)
+
+
+def _except_all_spark(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    return df1.exceptAll(df2)
+
+
+def handle_column_spaces(column: str) -> str:
+    """Return a column string suitable for the backend."""
+    if use_duckdb():
+        return f'"{column}"'
+    return f"`{column}`"
+
+
+def intersect(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    """Return an intersection of rows. Duplicates are not returned."""
+    # Could add intersect all if duplicates are needed.
+    method = _intersect_duckdb if use_duckdb() else _intersect_spark
+    return method(df1, df2)
+
+
+def _intersect_duckdb(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    view1 = create_temp_view(df1)
+    view2 = create_temp_view(df2)
+    query = f"""
+        SELECT * FROM {view1}
+        INTERSECT
+        SELECT * FROM {view2}
+    """
+    spark = get_spark_session()
+    return spark.sql(query)
+
+
+def _intersect_spark(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    return df1.intersect(df2)
+
+
+def get_duckdb_spark_session() -> SparkSession | None:
+    """Return the active DuckDB Spark Session if it is set."""
+    return g_duckdb_spark
+
+
+def get_spark_session() -> SparkSession:
+    """Return the active SparkSession or create a new one if none is active."""
+    spark = get_duckdb_spark_session()
+    if spark is not None:
+        return spark
+
+    spark = SparkSession.getActiveSession()
+    if spark is None:
+        logger.warning("Could not find a SparkSession. Create a new one.")
+        spark = SparkSession.builder.getOrCreate()
+        log_spark_conf(spark)
+    return spark
+
+
+def get_spark_warehouse_dir() -> Path:
+    """Return the Spark warehouse directory. Not valid with DuckDB."""
+    assert not use_duckdb()
+    val = get_spark_session().conf.get("spark.sql.warehouse.dir")
+    assert isinstance(val, str)
+    if not val:
+        msg = "Bug: spark.sql.warehouse.dir is not set"
+        raise Exception(msg)
+    if not val.startswith("file:"):
+        msg = f"get_spark_warehouse_dir only supports local file paths currently: {val}"
+        raise NotImplementedError(msg)
+    return Path(val.split("file:")[1])
+
+
+def get_current_time_zone() -> str:
+    """Return the current time zone."""
+    spark = get_spark_session()
+    if use_duckdb():
+        res = spark.sql("SELECT * FROM duckdb_settings() WHERE name = 'TimeZone'").collect()
+        assert len(res) == 1
+        return res[0].value
+
+    tz = spark.conf.get("spark.sql.session.timeZone")
+    assert tz is not None
+    return tz
+
+
+def set_current_time_zone(time_zone: str) -> None:
+    """Set the current time zone."""
+    spark = get_spark_session()
+    if use_duckdb():
+        spark.sql(f"SET TimeZone='{time_zone}'")
+        return
+
+    spark.conf.set("spark.sql.session.timeZone", time_zone)
+
+
+def init_spark(name="dsgrid", check_env=True, spark_conf=None) -> SparkSession:
+    """Initialize a SparkSession.
+
+    Parameters
+    ----------
+    name : str
+    check_env : bool
+        If True, which is the default, check for the SPARK_CLUSTER environment variable and
+        attach to it. Otherwise, create a local-mode cluster or attach to the SparkSession that
+        was created by pyspark/spark-submit prior to starting the current process.
+    spark_conf : dict | None, defaults to None
+        If set, Spark configuration parameters.
+
+    """
+    if use_duckdb():
+        logger.info("Using DuckDB as the backend engine.")
+        return g_duckdb_spark
+
+    logger.info("Using Spark as the backend engine.")
+    cluster = os.environ.get("SPARK_CLUSTER")
+    conf = SparkConf().setAppName(name)
+    if spark_conf is not None:
+        for key, val in spark_conf.items():
+            conf.set(key, val)
+
+    out_ts_type = conf.get("spark.sql.parquet.outputTimestampType")
+    if out_ts_type is None:
+        conf.set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS")
+    elif out_ts_type != "TIMESTAMP_MICROS":
+        logger.warning(
+            "spark.sql.parquet.outputTimestampType is set to %s. Writing parquet files may "
+            "produce undesired results.",
+            out_ts_type,
+        )
+    conf.set("spark.sql.legacy.parquet.nanosAsLong", "true")
+
+    if check_env and cluster is not None:
+        logger.info("Create SparkSession %s on existing cluster %s", name, cluster)
+        conf.setMaster(cluster)
+
+    config = SparkSession.builder.config(conf=conf)
+    if dsgrid.runtime_config.use_hive_metastore:
+        config = config.enableHiveSupport()
+    spark = config.getOrCreate()
+
+    with disable_console_logging():
+        log_spark_conf(spark)
+        logger.info("Custom configuration settings: %s", spark_conf)
+
+    return spark
+
+
+def is_dataframe_empty(df: DataFrame) -> bool:
+    """Return True if the DataFrame is empty."""
+    if use_duckdb():
+        view = create_temp_view(df)
+        spark = get_spark_session()
+        col = df.columns[0]
+        return spark.sql(f'SELECT "{col}" FROM {view} LIMIT 1').count() == 0
+    return df.rdd.isEmpty()
+
+
+def perform_interval_op(
+    df: DataFrame, time_column, op: str, val: Any, unit: str, alias: str
+) -> DataFrame:
+    """Perform an interval operation ('-' or '+') on a time column."""
+    if use_duckdb():
+        view = create_temp_view(df)
+        cols = df.columns[:]
+        if alias == time_column:
+            cols.remove(time_column)
+        cols_str = ",".join([f'"{x}"' for x in cols])
+        query = (
+            f'SELECT "{time_column}" {op} INTERVAL {val} {unit} AS {alias}, {cols_str} from {view}'
+        )
+        return get_spark_session().sql(query)
+
+    interval_expr = F.expr(f"INTERVAL {val} SECONDS")
+    match op:
+        case "-":
+            expr = F.col(time_column) - interval_expr
+        case "+":
+            expr = F.col(time_column) + interval_expr
+        case _:
+            msg = f"{op=} is not supported"
+            raise NotImplementedError(msg)
+    return df.withColumn(alias, expr)
+
+
+def join(df1: DataFrame, df2: DataFrame, column1: str, column2: str, how="inner") -> DataFrame:
+    """Join two dataframes on one column. Use this method whenever the result may be joined
+    with another dataframe in order to work around a DuckDB issue.
+    """
+    df = df1.join(df2, on=df1[column1] == df2[column2], how=how)
+    if use_duckdb():
+        # DuckDB sets the relation alias to "relation", which causes problems with future
+        # joins. They declined to address this in https://github.com/duckdb/duckdb/issues/12959
+        df.relation = df.relation.set_alias(f"relation_{uuid4()}")

+    return df
+
+
+def join_multiple_columns(
+    df1: DataFrame, df2: DataFrame, columns: list[str], how="inner"
+) -> DataFrame:
+    """Join two dataframes on multiple columns."""
+    if use_duckdb():
+        view1 = create_temp_view(df1)
+        view2 = create_temp_view(df2)
+        view2_columns = ",".join((f'{view2}."{x}"' for x in df2.columns if x not in columns))
+        on_str = " AND ".join((f'{view1}."{x}" = {view2}."{x}"' for x in columns))
+        query = f"""
+            SELECT {view1}.*, {view2_columns}
+            FROM {view1}
+            {how} JOIN {view2}
+            ON {on_str}
+        """
+        # This does not have the alias="relation" issue discussed above.
+        return get_spark_session().sql(query)
+
+    return df1.join(df2, columns, how=how)
+
+
+def log_spark_conf(spark: SparkSession):
+    """Log the Spark configuration details."""
+    if not use_duckdb():
+        conf = spark.sparkContext.getConf().getAll()
+        conf.sort(key=lambda x: x[0])
+        logger.info("Spark conf: %s", "\n".join([f"{x} = {y}" for x, y in conf]))
+
+
+def prepare_timestamps_for_dataframe(timestamps: Iterable[datetime]) -> Iterable[datetime]:
+    """Apply necessary conversions of the timestamps for dataframe creation."""
+    if use_duckdb():
+        return [x.astimezone(ZoneInfo("UTC")) for x in timestamps]
+    return timestamps
+
+
+def read_csv(path: Path | str, schema: dict[str, str] | None = None) -> DataFrame:
+    """Return a DataFrame from a CSV file, handling special cases with duckdb."""
+    func = read_csv_duckdb if use_duckdb() else _read_csv_spark
+    df = func(path, schema)
+    if schema is not None:
+        if set(df.columns).symmetric_difference(schema.keys()):
+            msg = (
+                f"Mismatch in CSV schema ({sorted(schema.keys())}) "
+                f"vs DataFrame columns ({df.columns})"
+            )
+            raise DSGInvalidDataset(msg)
+
+    return df
+
+
+def _read_csv_spark(path: Path | str, schema: dict[str, str] | None) -> DataFrame:
+    spark = get_spark_session()
+    if schema is None:
+        return spark.read.csv(str(path), header=True, inferSchema=True)
+
+    schema_str = ",".join([f"{key} {val}" for key, val in schema.items()])
+    return spark.read.csv(str(path), header=True, schema=schema_str)
+
+
+def read_csv_duckdb(path_or_str: Path | str, schema: dict[str, str] | None) -> DataFrame:
+    """Read a CSV file using DuckDB and return a Spark DataFrame.
+
+    Parameters
+    ----------
+    path_or_str : Path | str
+        Path to the CSV file or directory containing CSV files.
+    schema : dict[str, str] | None
+        Mapping of column names to DuckDB data types.
+    """
+    path = Path(path_or_str)
+    if path.is_dir():
+        path_str = str(path) + "**/*.csv"
+    else:
+        path_str = str(path)
+
+    spark = get_spark_session()
+    if not schema:
+        return spark.read.csv(path_str, header=True)
+
+    dtypes = {k: duckdb.type(v) for k, v in schema.items()}
+    rel = duckdb.read_csv(path_str, header=True, dtype=dtypes)
+    if use_duckdb():
+        return spark.createDataFrame(rel.to_df())
+
+    # DT 12/1/2025
+    # This obnoxious code block provides the only way I've found to read a CSV file into Spark
+    # while allowing these behaviors:
+    # - Preserve NULL values. DuckDB -> Pandas -> Spark converts NULLs to NaNs.
+    # - Allow the user to specify a subset of columns with data types. The native Spark CSV
+    #   reader will drop columns not specified in the schema.
+    # This shouldn't matter much because Spark + CSV should never happen with large datasets.
+    scratch_dir = DsgridRuntimeConfig().get_scratch_dir()
+    with NamedTemporaryFile(suffix=".parquet", dir=scratch_dir) as f:
+        f.close()
+        rel.write_parquet(f.name)
+        df = spark.read.parquet(f.name)
+        # Bring the entire table into memory so that we can delete the file.
+        df.cache()
+        df.count()
+    return df
+
+
+def read_json(path: Path | str) -> DataFrame:
+    """Return a DataFrame from a JSON file, handling special cases with duckdb.
+
+    Warning: Use of this function with DuckDB is not efficient because it requires that we
+    convert line-delimited JSON to standard JSON.
+    """
+    spark = get_spark_session()
+    filename = str(path)
+    if use_duckdb():
+        with NamedTemporaryFile(suffix=".json") as f:
+            f.close()
+            # TODO duckdb: look for something more efficient. Not a big deal right now.
+            data = load_line_delimited_json(path)
+            dump_data(data, f.name)
+            return spark.read.json(f.name)
+    return spark.read.json(filename, mode="FAILFAST")
+
+
+def read_parquet(path: Path | str) -> DataFrame:
+    path = Path(path) if isinstance(path, str) else path
+    spark = get_spark_session()
+    if path.is_file() or not use_duckdb():
+        df = spark.read.parquet(str(path))
+    else:
+        df = spark.read.parquet(f"{path}/**/*.parquet")
+    return df
+
+
+def select_expr(df: DataFrame, exprs: list[str]) -> DataFrame:
+    """Execute the SQL SELECT expression. It is the caller's responsibility to handle column
+    names with spaces or special characters.
+    """
+    if use_duckdb():
+        view = create_temp_view(df)
+        spark = get_spark_session()
+        cols = ",".join(exprs)
+        return spark.sql(f"SELECT {cols} FROM {view}")
+    return df.selectExpr(*exprs)
+
+
+def sql_from_df(df: DataFrame, query: str) -> DataFrame:
+    """Run a SQL query on a dataframe with Spark."""
+    logger.debug("Run SQL query [%s]", query)
+    spark = get_spark_session()
+    if use_duckdb():
+        view = create_temp_view(df)
+        query += f" FROM {view}"
+        return spark.sql(query)
+
+    query += " FROM {df}"
+    return spark.sql(query, df=df)
+
+
+def pivot(df: DataFrame, name_column: str, value_column: str) -> DataFrame:
+    """Pivot the dataframe."""
+    method = _pivot_duckdb if use_duckdb() else _pivot_spark
+    return method(df, name_column, value_column)
+
+
+def _pivot_duckdb(df: DataFrame, name_column: str, value_column: str) -> DataFrame:
+    view = create_temp_view(df)
+    query = f"""
+        PIVOT {view}
+        ON "{name_column}"
+        USING SUM({value_column})
+    """
+    return get_spark_session().sql(query)
+
+
+def _pivot_spark(df: DataFrame, name_column: str, value_column: str) -> DataFrame:
+    ids = [x for x in df.columns if x not in {name_column, value_column}]
+    return df.groupBy(*ids).pivot(name_column).sum(value_column)
+
+
+def unpivot(df: DataFrame, pivoted_columns, name_column: str, value_column: str) -> DataFrame:
+    """Unpivot the dataframe."""
+    method = _unpivot_duckdb if use_duckdb() else _unpivot_spark
+    return method(df, pivoted_columns, name_column, value_column)
+
+
+def _unpivot_duckdb(
+    df: DataFrame, pivoted_columns, name_column: str, value_column: str
+) -> DataFrame:
+    view = create_temp_view(df)
+    cols = ",".join([f'"{x}"' for x in pivoted_columns])
+    query = f"""
+        SELECT * FROM {view}
+        UNPIVOT INCLUDE NULLS (
+            "{value_column}"
+            FOR "{name_column}" in ({cols})
+        )
+    """
+    spark = get_spark_session()
+    df = spark.sql(query)
+    return df
+
+
+def _unpivot_spark(
+    df: DataFrame, pivoted_columns, name_column: str, value_column: str
+) -> DataFrame:
+    ids = list(set(df.columns) - {value_column, *pivoted_columns})
+    return df.unpivot(
+        ids,
+        pivoted_columns,
+        name_column,
+        value_column,
+    )
+
+
+def write_csv(
+    df: DataFrame, path: Path | str, header: bool = True, overwrite: bool = False
+) -> None:
+    """Write a DataFrame to a CSV file, handling special cases with duckdb."""
+    path_str = path if isinstance(path, str) else str(path)
+    if use_duckdb():
+        df.relation.write_csv(path_str, header=header, overwrite=overwrite)
+    else:
+        if overwrite:
+            df.write.options(header=True).mode("overwrite").csv(path_str)
+        else:
+            df.write.options(header=True).csv(path_str)
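A minimal usage sketch (not from the package) of the backend-agnostic helpers above; the CSV path and column names are hypothetical, and the schema must name every column in the file because read_csv raises DSGInvalidDataset on a mismatch.

from dsgrid.spark.functions import (
    aggregate_single_value,
    count_distinct_on_group_by,
    init_spark,
    read_csv,
)

# Under DuckDB this returns the module-level session; under Spark it builds a real SparkSession.
init_spark(name="example")

# Hypothetical CSV with exactly these two columns.
df = read_csv("load_data.csv", schema={"county": "STRING", "load_mwh": "DOUBLE"})

peak = aggregate_single_value(df, "max", "load_mwh")  # scalar result, same call on either engine
counts = count_distinct_on_group_by(df, ["county"], "load_mwh", "num_values")
counts.show()
print("peak load:", peak)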
dsgrid/spark/types.py
ADDED
@@ -0,0 +1,110 @@
+# flake8: noqa
+
+import dsgrid
+from dsgrid.common import BackendEngine
+
+
+def use_duckdb() -> bool:
+    """Return True if the environment is set to use DuckDB instead of Spark."""
+    return dsgrid.runtime_config.backend_engine == BackendEngine.DUCKDB
+
+
+if use_duckdb():
+    import duckdb.experimental.spark.sql.functions as F
+    from duckdb.experimental.spark.conf import SparkConf
+    from duckdb.experimental.spark.sql import DataFrame, SparkSession
+    from duckdb.experimental.spark.sql.types import (
+        ByteType,
+        StructField,
+        StructType,
+        StringType,
+        BooleanType,
+        IntegerType,
+        ShortType,
+        LongType,
+        DoubleType,
+        FloatType,
+        TimestampType,
+        TimestampNTZType,
+        Row,
+    )
+    from duckdb.experimental.spark.errors import AnalysisException
+else:
+    import pyspark.sql.functions as F
+    from pyspark.sql import DataFrame, Row, SparkSession
+    from pyspark.sql.types import (
+        ByteType,
+        FloatType,
+        StructType,
+        StructField,
+        StringType,
+        DoubleType,
+        IntegerType,
+        LongType,
+        ShortType,
+        BooleanType,
+        TimestampType,
+        TimestampNTZType,
+    )
+    from pyspark.errors import AnalysisException
+    from pyspark import SparkConf
+
+
+SUPPORTED_TYPES = set(
+    (
+        "BOOLEAN",
+        "INT",
+        "INTEGER",
+        "TINYINT",
+        "SMALLINT",
+        "BIGINT",
+        "FLOAT",
+        "DOUBLE",
+        "TIMESTAMP_TZ",
+        "TIMESTAMP_NTZ",
+        "STRING",
+        "TEXT",
+        "VARCHAR",
+    )
+)
+
+DUCKDB_COLUMN_TYPES = {
+    "BOOLEAN": "BOOLEAN",
+    "INT": "INTEGER",
+    "INTEGER": "INTEGER",
+    "TINYINT": "TINYINT",
+    "SMALLINT": "INTEGER",
+    "BIGINT": "BIGINT",
+    "FLOAT": "FLOAT",
+    "DOUBLE": "DOUBLE",
+    "TIMESTAMP_TZ": "TIMESTAMP WITH TIME ZONE",
+    "TIMESTAMP_NTZ": "TIMESTAMP",
+    "STRING": "VARCHAR",
+    "TEXT": "VARCHAR",
+    "VARCHAR": "VARCHAR",
+}
+
+SPARK_COLUMN_TYPES = {
+    "BOOLEAN": "BOOLEAN",
+    "INT": "INT",
+    "INTEGER": "INT",
+    "TINYINT": "TINYINT",
+    "SMALLINT": "SMALLINT",
+    "BIGINT": "BIGINT",
+    "FLOAT": "FLOAT",
+    "DOUBLE": "DOUBLE",
+    "STRING": "STRING",
+    "TEXT": "STRING",
+    "VARCHAR": "STRING",
+    "TIMESTAMP_TZ": "TIMESTAMP",
+    "TIMESTAMP_NTZ": "TIMESTAMP_NTZ",
+}
+
+assert sorted(DUCKDB_COLUMN_TYPES.keys()) == sorted(SPARK_COLUMN_TYPES.keys())
+assert not SUPPORTED_TYPES.difference(DUCKDB_COLUMN_TYPES.keys())
+
+
+def get_str_type() -> str:
+    """Return the string type used by the current database system."""
+    types = DUCKDB_COLUMN_TYPES if use_duckdb() else SPARK_COLUMN_TYPES
+    return types["STRING"]
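The two mapping tables above translate one set of generic type names into whichever engine dsgrid selected at import time. A small sketch (not from the package) of how a caller might consult them; the generic name chosen here is arbitrary.

from dsgrid.spark.types import (
    DUCKDB_COLUMN_TYPES,
    SPARK_COLUMN_TYPES,
    get_str_type,
    use_duckdb,
)

generic = "TIMESTAMP_NTZ"  # any key in SUPPORTED_TYPES
native = (DUCKDB_COLUMN_TYPES if use_duckdb() else SPARK_COLUMN_TYPES)[generic]
print(generic, "->", native)                   # TIMESTAMP (DuckDB) or TIMESTAMP_NTZ (Spark)
print("string columns use:", get_str_type())   # VARCHAR (DuckDB) or STRING (Spark)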
dsgrid/tests/__init__.py
ADDED
File without changes