dsgrid-toolkit 0.3.3__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. build_backend.py +93 -0
  2. dsgrid/__init__.py +22 -0
  3. dsgrid/api/__init__.py +0 -0
  4. dsgrid/api/api_manager.py +179 -0
  5. dsgrid/api/app.py +419 -0
  6. dsgrid/api/models.py +60 -0
  7. dsgrid/api/response_models.py +116 -0
  8. dsgrid/apps/__init__.py +0 -0
  9. dsgrid/apps/project_viewer/app.py +216 -0
  10. dsgrid/apps/registration_gui.py +444 -0
  11. dsgrid/chronify.py +32 -0
  12. dsgrid/cli/__init__.py +0 -0
  13. dsgrid/cli/common.py +120 -0
  14. dsgrid/cli/config.py +176 -0
  15. dsgrid/cli/download.py +13 -0
  16. dsgrid/cli/dsgrid.py +157 -0
  17. dsgrid/cli/dsgrid_admin.py +92 -0
  18. dsgrid/cli/install_notebooks.py +62 -0
  19. dsgrid/cli/query.py +729 -0
  20. dsgrid/cli/registry.py +1862 -0
  21. dsgrid/cloud/__init__.py +0 -0
  22. dsgrid/cloud/cloud_storage_interface.py +140 -0
  23. dsgrid/cloud/factory.py +31 -0
  24. dsgrid/cloud/fake_storage_interface.py +37 -0
  25. dsgrid/cloud/s3_storage_interface.py +156 -0
  26. dsgrid/common.py +36 -0
  27. dsgrid/config/__init__.py +0 -0
  28. dsgrid/config/annual_time_dimension_config.py +194 -0
  29. dsgrid/config/common.py +142 -0
  30. dsgrid/config/config_base.py +148 -0
  31. dsgrid/config/dataset_config.py +907 -0
  32. dsgrid/config/dataset_schema_handler_factory.py +46 -0
  33. dsgrid/config/date_time_dimension_config.py +136 -0
  34. dsgrid/config/dimension_config.py +54 -0
  35. dsgrid/config/dimension_config_factory.py +65 -0
  36. dsgrid/config/dimension_mapping_base.py +350 -0
  37. dsgrid/config/dimension_mappings_config.py +48 -0
  38. dsgrid/config/dimensions.py +1025 -0
  39. dsgrid/config/dimensions_config.py +71 -0
  40. dsgrid/config/file_schema.py +190 -0
  41. dsgrid/config/index_time_dimension_config.py +80 -0
  42. dsgrid/config/input_dataset_requirements.py +31 -0
  43. dsgrid/config/mapping_tables.py +209 -0
  44. dsgrid/config/noop_time_dimension_config.py +42 -0
  45. dsgrid/config/project_config.py +1462 -0
  46. dsgrid/config/registration_models.py +188 -0
  47. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  48. dsgrid/config/simple_models.py +49 -0
  49. dsgrid/config/supplemental_dimension.py +29 -0
  50. dsgrid/config/time_dimension_base_config.py +192 -0
  51. dsgrid/data_models.py +155 -0
  52. dsgrid/dataset/__init__.py +0 -0
  53. dsgrid/dataset/dataset.py +123 -0
  54. dsgrid/dataset/dataset_expression_handler.py +86 -0
  55. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  56. dsgrid/dataset/dataset_schema_handler_base.py +945 -0
  57. dsgrid/dataset/dataset_schema_handler_one_table.py +209 -0
  58. dsgrid/dataset/dataset_schema_handler_two_table.py +322 -0
  59. dsgrid/dataset/growth_rates.py +162 -0
  60. dsgrid/dataset/models.py +51 -0
  61. dsgrid/dataset/table_format_handler_base.py +257 -0
  62. dsgrid/dataset/table_format_handler_factory.py +17 -0
  63. dsgrid/dataset/unpivoted_table.py +121 -0
  64. dsgrid/dimension/__init__.py +0 -0
  65. dsgrid/dimension/base_models.py +230 -0
  66. dsgrid/dimension/dimension_filters.py +308 -0
  67. dsgrid/dimension/standard.py +252 -0
  68. dsgrid/dimension/time.py +352 -0
  69. dsgrid/dimension/time_utils.py +103 -0
  70. dsgrid/dsgrid_rc.py +88 -0
  71. dsgrid/exceptions.py +105 -0
  72. dsgrid/filesystem/__init__.py +0 -0
  73. dsgrid/filesystem/cloud_filesystem.py +32 -0
  74. dsgrid/filesystem/factory.py +32 -0
  75. dsgrid/filesystem/filesystem_interface.py +136 -0
  76. dsgrid/filesystem/local_filesystem.py +74 -0
  77. dsgrid/filesystem/s3_filesystem.py +118 -0
  78. dsgrid/loggers.py +132 -0
  79. dsgrid/minimal_patterns.cp313-win_amd64.pyd +0 -0
  80. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +949 -0
  81. dsgrid/notebooks/registration.ipynb +48 -0
  82. dsgrid/notebooks/start_notebook.sh +11 -0
  83. dsgrid/project.py +451 -0
  84. dsgrid/query/__init__.py +0 -0
  85. dsgrid/query/dataset_mapping_plan.py +142 -0
  86. dsgrid/query/derived_dataset.py +388 -0
  87. dsgrid/query/models.py +728 -0
  88. dsgrid/query/query_context.py +287 -0
  89. dsgrid/query/query_submitter.py +994 -0
  90. dsgrid/query/report_factory.py +19 -0
  91. dsgrid/query/report_peak_load.py +70 -0
  92. dsgrid/query/reports_base.py +20 -0
  93. dsgrid/registry/__init__.py +0 -0
  94. dsgrid/registry/bulk_register.py +165 -0
  95. dsgrid/registry/common.py +287 -0
  96. dsgrid/registry/config_update_checker_base.py +63 -0
  97. dsgrid/registry/data_store_factory.py +34 -0
  98. dsgrid/registry/data_store_interface.py +74 -0
  99. dsgrid/registry/dataset_config_generator.py +158 -0
  100. dsgrid/registry/dataset_registry_manager.py +950 -0
  101. dsgrid/registry/dataset_update_checker.py +16 -0
  102. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  103. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  104. dsgrid/registry/dimension_registry_manager.py +413 -0
  105. dsgrid/registry/dimension_update_checker.py +16 -0
  106. dsgrid/registry/duckdb_data_store.py +207 -0
  107. dsgrid/registry/filesystem_data_store.py +150 -0
  108. dsgrid/registry/filter_registry_manager.py +123 -0
  109. dsgrid/registry/project_config_generator.py +57 -0
  110. dsgrid/registry/project_registry_manager.py +1623 -0
  111. dsgrid/registry/project_update_checker.py +48 -0
  112. dsgrid/registry/registration_context.py +223 -0
  113. dsgrid/registry/registry_auto_updater.py +316 -0
  114. dsgrid/registry/registry_database.py +667 -0
  115. dsgrid/registry/registry_interface.py +446 -0
  116. dsgrid/registry/registry_manager.py +558 -0
  117. dsgrid/registry/registry_manager_base.py +367 -0
  118. dsgrid/registry/versioning.py +92 -0
  119. dsgrid/rust_ext/__init__.py +14 -0
  120. dsgrid/rust_ext/find_minimal_patterns.py +129 -0
  121. dsgrid/spark/__init__.py +0 -0
  122. dsgrid/spark/functions.py +589 -0
  123. dsgrid/spark/types.py +110 -0
  124. dsgrid/tests/__init__.py +0 -0
  125. dsgrid/tests/common.py +140 -0
  126. dsgrid/tests/make_us_data_registry.py +265 -0
  127. dsgrid/tests/register_derived_datasets.py +103 -0
  128. dsgrid/tests/utils.py +25 -0
  129. dsgrid/time/__init__.py +0 -0
  130. dsgrid/time/time_conversions.py +80 -0
  131. dsgrid/time/types.py +67 -0
  132. dsgrid/units/__init__.py +0 -0
  133. dsgrid/units/constants.py +113 -0
  134. dsgrid/units/convert.py +71 -0
  135. dsgrid/units/energy.py +145 -0
  136. dsgrid/units/power.py +87 -0
  137. dsgrid/utils/__init__.py +0 -0
  138. dsgrid/utils/dataset.py +830 -0
  139. dsgrid/utils/files.py +179 -0
  140. dsgrid/utils/filters.py +125 -0
  141. dsgrid/utils/id_remappings.py +100 -0
  142. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  143. dsgrid/utils/py_expression_eval/README.md +8 -0
  144. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  145. dsgrid/utils/py_expression_eval/tests.py +283 -0
  146. dsgrid/utils/run_command.py +70 -0
  147. dsgrid/utils/scratch_dir_context.py +65 -0
  148. dsgrid/utils/spark.py +918 -0
  149. dsgrid/utils/spark_partition.py +98 -0
  150. dsgrid/utils/timing.py +239 -0
  151. dsgrid/utils/utilities.py +221 -0
  152. dsgrid/utils/versioning.py +36 -0
  153. dsgrid_toolkit-0.3.3.dist-info/METADATA +193 -0
  154. dsgrid_toolkit-0.3.3.dist-info/RECORD +157 -0
  155. dsgrid_toolkit-0.3.3.dist-info/WHEEL +4 -0
  156. dsgrid_toolkit-0.3.3.dist-info/entry_points.txt +4 -0
  157. dsgrid_toolkit-0.3.3.dist-info/licenses/LICENSE +29 -0
dsgrid/spark/functions.py ADDED
@@ -0,0 +1,589 @@
+import logging
+import os
+from datetime import datetime
+from pathlib import Path
+from tempfile import NamedTemporaryFile
+from typing import Any, Iterable
+from uuid import uuid4
+from zoneinfo import ZoneInfo
+
+import duckdb
+
+import dsgrid
+from dsgrid.dsgrid_rc import DsgridRuntimeConfig
+from dsgrid.exceptions import DSGInvalidDataset
+from dsgrid.loggers import disable_console_logging
+from dsgrid.spark.types import (
+    DataFrame,
+    F,
+    SparkConf,
+    SparkSession,
+    use_duckdb,
+)
+from dsgrid.utils.files import load_line_delimited_json, dump_data
+
+logger = logging.getLogger(__name__)
+
+
+if use_duckdb():
+    g_duckdb_spark = SparkSession.builder.getOrCreate()
+else:
+    g_duckdb_spark = None
+
+
+TEMP_TABLE_PREFIX = "tmp_dsgrid"
+
+
+def aggregate(df: DataFrame, agg_func: str, column: str, alias: str) -> DataFrame:
+    """Run an aggregate function on the dataframe."""
+    if use_duckdb():
+        relation = df.relation.aggregate(f"{agg_func}({column}) as {alias}")
+        return DataFrame(relation.set_alias(make_temp_view_name()), df.session)
+    return df.agg(getattr(F, agg_func)(column).alias(alias))
+
+
+def aggregate_single_value(df: DataFrame, agg_func: str, column: str) -> Any:
+    """Run an aggregate function on the dataframe that will produce a single value, such as max.
+    Return that single value.
+    """
+    alias = "__tmp__"
+    if use_duckdb():
+        return df.relation.aggregate(f"{agg_func}({column}) as {alias}").df().values[0][0]
+    return df.agg(getattr(F, agg_func)(column).alias(alias)).collect()[0][alias]
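
For illustration only (not part of the package): a minimal sketch of how these two aggregation helpers are used. It assumes dsgrid is installed, pandas is available, and the active backend can build a dataframe from a pandas DataFrame.

    import pandas as pd
    from dsgrid.spark.functions import aggregate, aggregate_single_value, get_spark_session

    spark = get_spark_session()
    df = spark.createDataFrame(pd.DataFrame({"county": ["a", "a", "b"], "value": [1.0, 2.0, 5.0]}))
    totals = aggregate(df, "sum", "value", "total")  # one-row dataframe with the alias "total"
    peak = aggregate_single_value(df, "max", "value")  # a plain Python value: 5.0
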
+
+
+def cache(df: DataFrame) -> DataFrame:
+    """Cache the dataframe. This is a no-op for DuckDB."""
+    if use_duckdb():
+        return df
+    return df.cache()
+
+
+def unpersist(df: DataFrame) -> None:
+    """Unpersist a dataframe that was previously cached. This is a no-op for DuckDB."""
+    if not use_duckdb():
+        df.unpersist()
+
+
+def coalesce(df: DataFrame, num_partitions: int) -> DataFrame:
+    """Coalesce the dataframe into num_partitions partitions. This is a no-op for DuckDB."""
+    if use_duckdb():
+        return df
+    return df.coalesce(num_partitions)
+
+
+def collect_list(df: DataFrame, column: str) -> list:
+    """Collect the dataframe column into a list."""
+    if use_duckdb():
+        return [x[column] for x in df.collect()]
+
+    return next(iter(df.select(F.collect_list(column)).first()))
+
+
+def count_distinct_on_group_by(
+    df: DataFrame, group_by_columns: list[str], agg_column: str, alias: str
+) -> DataFrame:
+    """Perform a count distinct on one column after grouping."""
+    if use_duckdb():
+        view = create_temp_view(df)
+        cols = ",".join([f'"{x}"' for x in group_by_columns])
+        query = f"""
+        SELECT {cols}, COUNT(DISTINCT "{agg_column}") AS "{alias}"
+        FROM {view}
+        GROUP BY {cols}
+        """
+        return get_spark_session().sql(query)
+
+    return df.groupBy(*group_by_columns).agg(F.count_distinct(agg_column).alias(alias))
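
For illustration only: a sketch of count_distinct_on_group_by, which builds a GROUP BY ... COUNT(DISTINCT ...) query on DuckDB and the equivalent groupBy/agg call on Spark. The example data is hypothetical.

    import pandas as pd
    from dsgrid.spark.functions import count_distinct_on_group_by, get_spark_session

    spark = get_spark_session()
    df = spark.createDataFrame(
        pd.DataFrame({"county": ["a", "a", "b"], "end_use": ["heating", "cooling", "heating"]})
    )
    # One row per county with the number of distinct end uses.
    counts = count_distinct_on_group_by(df, ["county"], "end_use", "num_end_uses")
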
+
+
+def create_temp_view(df: DataFrame) -> str:
+    """Create a temporary view with a random name and return the name."""
+    view1 = make_temp_view_name()
+    df.createOrReplaceTempView(view1)
+    return view1
+
+
+def make_temp_view_name() -> str:
+    """Make a random name to be used as a view."""
+    return f"{TEMP_TABLE_PREFIX}_{uuid4().hex}"
+
+
+def drop_temp_tables_and_views() -> None:
+    """Drop all temporary views and tables."""
+    drop_temp_views()
+    drop_temp_tables()
+
+
+def drop_temp_tables() -> None:
+    """Drop all temporary tables."""
+    spark = get_spark_session()
+    if use_duckdb():
+        query = f"SELECT * FROM pg_tables WHERE tablename LIKE '%{TEMP_TABLE_PREFIX}%'"
+        for row in spark.sql(query).collect():
+            spark.sql(f"DROP TABLE {row.tablename}")
+    else:
+        for row in spark.sql(f"SHOW TABLES LIKE '*{TEMP_TABLE_PREFIX}*'").collect():
+            spark.sql(f"DROP TABLE {row.tableName}")
+
+
+def drop_temp_views() -> None:
+    """Drop all temporary views."""
+    spark = get_spark_session()
+    if use_duckdb():
+        query = f"""
+        SELECT view_name FROM duckdb_views()
+        WHERE NOT internal AND view_name LIKE '%{TEMP_TABLE_PREFIX}%'
+        """
+        for row in spark.sql(query).collect():
+            spark.sql(f"DROP VIEW {row.view_name}")
+    else:
+        for row in spark.sql(f"SHOW VIEWS LIKE '*{TEMP_TABLE_PREFIX}*'").collect():
+            spark.sql(f"DROP VIEW {row.viewName}")
+
+
+def cross_join(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    """Return a cross join of the two dataframes."""
+    if use_duckdb():
+        view1 = create_temp_view(df1)
+        view2 = create_temp_view(df2)
+        spark = get_spark_session()
+        return spark.sql(f"SELECT * from {view1} CROSS JOIN {view2}")
+
+    return df1.crossJoin(df2)
+
+
+def except_all(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    """Return a dataframe with all rows in df1 that are not in df2."""
+    method = _except_all_duckdb if use_duckdb() else _except_all_spark
+    return method(df1, df2)
+
+
+def _except_all_duckdb(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    view1 = create_temp_view(df1)
+    view2 = create_temp_view(df2)
+    query = f"""
+    SELECT * FROM {view1}
+    EXCEPT ALL
+    SELECT * FROM {view2}
+    """
+    spark = get_spark_session()
+    return spark.sql(query)
+
+
+def _except_all_spark(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    return df1.exceptAll(df2)
+
+
+def handle_column_spaces(column: str) -> str:
+    """Return a column string suitable for the backend."""
+    if use_duckdb():
+        return f'"{column}"'
+    return f"`{column}`"
+
+
+def intersect(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    """Return an intersection of rows. Duplicates are not returned."""
+    # Could add intersect all if duplicates are needed.
+    method = _intersect_duckdb if use_duckdb() else _intersect_spark
+    return method(df1, df2)
+
+
+def _intersect_duckdb(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    view1 = create_temp_view(df1)
+    view2 = create_temp_view(df2)
+    query = f"""
+    SELECT * FROM {view1}
+    INTERSECT
+    SELECT * FROM {view2}
+    """
+    spark = get_spark_session()
+    return spark.sql(query)
+
+
+def _intersect_spark(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    return df1.intersect(df2)
+
+
+def get_duckdb_spark_session() -> SparkSession | None:
+    """Return the active DuckDB Spark Session if it is set."""
+    return g_duckdb_spark
+
+
+def get_spark_session() -> SparkSession:
+    """Return the active SparkSession or create a new one if none is active."""
+    spark = get_duckdb_spark_session()
+    if spark is not None:
+        return spark
+
+    spark = SparkSession.getActiveSession()
+    if spark is None:
+        logger.warning("Could not find a SparkSession. Creating a new one.")
+        spark = SparkSession.builder.getOrCreate()
+        log_spark_conf(spark)
+    return spark
+
+
+def get_spark_warehouse_dir() -> Path:
+    """Return the Spark warehouse directory. Not valid with DuckDB."""
+    assert not use_duckdb()
+    val = get_spark_session().conf.get("spark.sql.warehouse.dir")
+    assert isinstance(val, str)
+    if not val:
+        msg = "Bug: spark.sql.warehouse.dir is not set"
+        raise Exception(msg)
+    if not val.startswith("file:"):
+        msg = f"get_spark_warehouse_dir only supports local file paths currently: {val}"
+        raise NotImplementedError(msg)
+    return Path(val.split("file:")[1])
+
+
+def get_current_time_zone() -> str:
+    """Return the current time zone."""
+    spark = get_spark_session()
+    if use_duckdb():
+        res = spark.sql("SELECT * FROM duckdb_settings() WHERE name = 'TimeZone'").collect()
+        assert len(res) == 1
+        return res[0].value
+
+    tz = spark.conf.get("spark.sql.session.timeZone")
+    assert tz is not None
+    return tz
+
+
+def set_current_time_zone(time_zone: str) -> None:
+    """Set the current time zone."""
+    spark = get_spark_session()
+    if use_duckdb():
+        spark.sql(f"SET TimeZone='{time_zone}'")
+        return
+
+    spark.conf.set("spark.sql.session.timeZone", time_zone)
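
For illustration only: the getter/setter pair above makes it straightforward to change the session time zone temporarily and restore it afterward. The time zone name here is an arbitrary example.

    from dsgrid.spark.functions import get_current_time_zone, set_current_time_zone

    original = get_current_time_zone()
    set_current_time_zone("America/Denver")
    try:
        pass  # run time-sensitive queries here
    finally:
        set_current_time_zone(original)
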
+
+
+def init_spark(name="dsgrid", check_env=True, spark_conf=None) -> SparkSession:
+    """Initialize a SparkSession.
+
+    Parameters
+    ----------
+    name : str
+    check_env : bool
+        If True, which is default, check for the SPARK_CLUSTER environment variable and attach to
+        it. Otherwise, create a local-mode cluster or attach to the SparkSession that was created
+        by pyspark/spark-submit prior to starting the current process.
+    spark_conf : dict | None, defaults to None
+        If set, Spark configuration parameters.
+
+    """
+    if use_duckdb():
+        logger.info("Using DuckDB as the backend engine.")
+        return g_duckdb_spark
+
+    logger.info("Using Spark as the backend engine.")
+    cluster = os.environ.get("SPARK_CLUSTER")
+    conf = SparkConf().setAppName(name)
+    if spark_conf is not None:
+        for key, val in spark_conf.items():
+            conf.set(key, val)
+
+    out_ts_type = conf.get("spark.sql.parquet.outputTimestampType")
+    if out_ts_type is None:
+        conf.set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS")
+    elif out_ts_type != "TIMESTAMP_MICROS":
+        logger.warning(
+            "spark.sql.parquet.outputTimestampType is set to %s. Writing parquet files may "
+            "produce undesired results.",
+            out_ts_type,
+        )
+        conf.set("spark.sql.legacy.parquet.nanosAsLong", "true")
+
+    if check_env and cluster is not None:
+        logger.info("Create SparkSession %s on existing cluster %s", name, cluster)
+        conf.setMaster(cluster)
+
+    config = SparkSession.builder.config(conf=conf)
+    if dsgrid.runtime_config.use_hive_metastore:
+        config = config.enableHiveSupport()
+    spark = config.getOrCreate()
+
+    with disable_console_logging():
+        log_spark_conf(spark)
+        logger.info("Custom configuration settings: %s", spark_conf)
+
+    return spark
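
For illustration only: a sketch of calling init_spark with a configuration override. The app name and setting shown here are arbitrary examples; with DuckDB as the backend the call simply returns the module-level session, while with Spark it builds a SparkConf, optionally attaches to $SPARK_CLUSTER, and applies the overrides.

    from dsgrid.spark.functions import init_spark

    spark = init_spark(
        name="my-dsgrid-job",
        spark_conf={"spark.sql.shuffle.partitions": "200"},
    )
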
+
+
+def is_dataframe_empty(df: DataFrame) -> bool:
+    """Return True if the DataFrame is empty."""
+    if use_duckdb():
+        view = create_temp_view(df)
+        spark = get_spark_session()
+        col = df.columns[0]
+        return spark.sql(f'SELECT "{col}" FROM {view} LIMIT 1').count() == 0
+    return df.rdd.isEmpty()
+
+
+def perform_interval_op(
+    df: DataFrame, time_column, op: str, val: Any, unit: str, alias: str
+) -> DataFrame:
+    """Perform an interval operation ('-' or '+') on a time column."""
+    if use_duckdb():
+        view = create_temp_view(df)
+        cols = df.columns[:]
+        if alias == time_column:
+            cols.remove(time_column)
+        cols_str = ",".join([f'"{x}"' for x in cols])
+        query = (
+            f'SELECT "{time_column}" {op} INTERVAL {val} {unit} AS {alias}, {cols_str} from {view}'
+        )
+        return get_spark_session().sql(query)
+
+    interval_expr = F.expr(f"INTERVAL {val} SECONDS")
+    match op:
+        case "-":
+            expr = F.col(time_column) - interval_expr
+        case "+":
+            expr = F.col(time_column) + interval_expr
+        case _:
+            msg = f"{op=} is not supported"
+            raise NotImplementedError(msg)
+    return df.withColumn(alias, expr)
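
For illustration only: a sketch of shifting a time column with perform_interval_op. Note that the Spark branch above always applies the interval in seconds regardless of unit, so passing seconds keeps the two backends consistent. The example data is hypothetical.

    from datetime import datetime
    import pandas as pd
    from dsgrid.spark.functions import get_spark_session, perform_interval_op

    spark = get_spark_session()
    df = spark.createDataFrame(
        pd.DataFrame({"timestamp": [datetime(2018, 1, 1, 1)], "value": [1.0]})
    )
    # Subtract one hour (3600 seconds) and store the result in a new column.
    shifted = perform_interval_op(df, "timestamp", "-", 3600, "SECONDS", "prev_hour")
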
+
+
+def join(df1: DataFrame, df2: DataFrame, column1: str, column2: str, how="inner") -> DataFrame:
+    """Join two dataframes on one column. Use this method whenever the result may be joined
+    with another dataframe in order to work around a DuckDB issue.
+    """
+    df = df1.join(df2, on=df1[column1] == df2[column2], how=how)
+    if use_duckdb():
+        # DuckDB sets the relation alias to "relation", which causes problems with future
+        # joins. They declined to address this in https://github.com/duckdb/duckdb/issues/12959
+        df.relation = df.relation.set_alias(f"relation_{uuid4()}")
+
+    return df
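
For illustration only: joining two dataframes whose key columns have different names; the column names and contents are hypothetical. On DuckDB, the unique relation alias set above is what lets the result participate in further joins.

    import pandas as pd
    from dsgrid.spark.functions import get_spark_session, join

    spark = get_spark_session()
    load = spark.createDataFrame(pd.DataFrame({"geo_id": ["a", "b"], "value": [1.0, 2.0]}))
    names = spark.createDataFrame(pd.DataFrame({"id": ["a", "b"], "name": ["Adams", "Boulder"]}))
    joined = join(load, names, "geo_id", "id", how="inner")
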
+
+
+def join_multiple_columns(
+    df1: DataFrame, df2: DataFrame, columns: list[str], how="inner"
+) -> DataFrame:
+    """Join two dataframes on multiple columns."""
+    if use_duckdb():
+        view1 = create_temp_view(df1)
+        view2 = create_temp_view(df2)
+        view2_columns = ",".join((f'{view2}."{x}"' for x in df2.columns if x not in columns))
+        on_str = " AND ".join((f'{view1}."{x}" = {view2}."{x}"' for x in columns))
+        query = f"""
+        SELECT {view1}.*, {view2_columns}
+        FROM {view1}
+        {how} JOIN {view2}
+        ON {on_str}
+        """
+        # This does not have the alias="relation" issue discussed above.
+        return get_spark_session().sql(query)
+
+    return df1.join(df2, columns, how=how)
+
+
+def log_spark_conf(spark: SparkSession):
+    """Log the Spark configuration details."""
+    if not use_duckdb():
+        conf = spark.sparkContext.getConf().getAll()
+        conf.sort(key=lambda x: x[0])
+        logger.info("Spark conf: %s", "\n".join([f"{x} = {y}" for x, y in conf]))
+
+
+def prepare_timestamps_for_dataframe(timestamps: Iterable[datetime]) -> Iterable[datetime]:
+    """Apply necessary conversions of the timestamps for dataframe creation."""
+    if use_duckdb():
+        return [x.astimezone(ZoneInfo("UTC")) for x in timestamps]
+    return timestamps
+
+
+def read_csv(path: Path | str, schema: dict[str, str] | None = None) -> DataFrame:
+    """Return a DataFrame from a CSV file, handling special cases with duckdb."""
+    func = read_csv_duckdb if use_duckdb() else _read_csv_spark
+    df = func(path, schema)
+    if schema is not None:
+        if set(df.columns).symmetric_difference(schema.keys()):
+            msg = (
+                f"Mismatch in CSV schema ({sorted(schema.keys())}) "
+                f"vs DataFrame columns ({df.columns})"
+            )
+            raise DSGInvalidDataset(msg)
+
+    return df
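
For illustration only: reading a small CSV with an explicit schema. The type strings are passed straight to the active backend (duckdb.type on DuckDB, a DDL schema string on Spark), so they must be understood by whichever engine is configured; INTEGER and DOUBLE should work on both. The file written here is a throwaway example.

    from pathlib import Path
    from dsgrid.spark.functions import read_csv

    Path("example.csv").write_text("year,value\n2018,1.5\n2019,2.0\n")
    df = read_csv("example.csv", schema={"year": "INTEGER", "value": "DOUBLE"})
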
+
+
+def _read_csv_spark(path: Path | str, schema: dict[str, str] | None) -> DataFrame:
+    spark = get_spark_session()
+    if schema is None:
+        return spark.read.csv(str(path), header=True, inferSchema=True)
+
+    schema_str = ",".join([f"{key} {val}" for key, val in schema.items()])
+    return spark.read.csv(str(path), header=True, schema=schema_str)
+
+
+def read_csv_duckdb(path_or_str: Path | str, schema: dict[str, str] | None) -> DataFrame:
+    """Read a CSV file using DuckDB and return a Spark DataFrame.
+
+    Parameters
+    ----------
+    path_or_str : Path | str
+        Path to the CSV file or directory containing CSV files.
+    schema : dict[str, str] | None
+        Mapping of column names to DuckDB data types.
+    """
+    path = Path(path_or_str)
+    if path.is_dir():
+        path_str = str(path) + "**/*.csv"
+    else:
+        path_str = str(path)
+
+    spark = get_spark_session()
+    if not schema:
+        return spark.read.csv(path_str, header=True)
+
+    dtypes = {k: duckdb.type(v) for k, v in schema.items()}
+    rel = duckdb.read_csv(path_str, header=True, dtype=dtypes)
+    if use_duckdb():
+        return spark.createDataFrame(rel.to_df())
+
+    # DT 12/1/2025
+    # This obnoxious code block provides the only way I've found to read a CSV file into Spark
+    # while allowing these behaviors:
+    # - Preserve NULL values. DuckDB -> Pandas -> Spark converts NULLs to NaNs.
+    # - Allow the user to specify a subset of columns with data types. The native Spark CSV
+    #   reader will drop columns not specified in the schema.
+    # This shouldn't matter much because Spark + CSV should never happen with large datasets.
+    scratch_dir = DsgridRuntimeConfig().get_scratch_dir()
+    with NamedTemporaryFile(suffix=".parquet", dir=scratch_dir) as f:
+        f.close()
+        rel.write_parquet(f.name)
+        df = spark.read.parquet(f.name)
+        # Bring the entire table into memory so that we can delete the file.
+        df.cache()
+        df.count()
+    return df
+
+
+def read_json(path: Path | str) -> DataFrame:
+    """Return a DataFrame from a JSON file, handling special cases with duckdb.
+
+    Warning: Use of this function with DuckDB is not efficient because it requires that we
+    convert line-delimited JSON to standard JSON.
+    """
+    spark = get_spark_session()
+    filename = str(path)
+    if use_duckdb():
+        with NamedTemporaryFile(suffix=".json") as f:
+            f.close()
+            # TODO duckdb: look for something more efficient. Not a big deal right now.
+            data = load_line_delimited_json(path)
+            dump_data(data, f.name)
+            return spark.read.json(f.name)
+    return spark.read.json(filename, mode="FAILFAST")
+
+
+def read_parquet(path: Path | str) -> DataFrame:
+    path = Path(path) if isinstance(path, str) else path
+    spark = get_spark_session()
+    if path.is_file() or not use_duckdb():
+        df = spark.read.parquet(str(path))
+    else:
+        df = spark.read.parquet(f"{path}/**/*.parquet")
+    return df
+
+
+def select_expr(df: DataFrame, exprs: list[str]) -> DataFrame:
+    """Execute the SQL SELECT expression. It is the caller's responsibility to handle column
+    names with spaces or special characters.
+    """
+    if use_duckdb():
+        view = create_temp_view(df)
+        spark = get_spark_session()
+        cols = ",".join(exprs)
+        return spark.sql(f"SELECT {cols} FROM {view}")
+    return df.selectExpr(*exprs)
+
+
+def sql_from_df(df: DataFrame, query: str) -> DataFrame:
+    """Run a SQL query on a dataframe with Spark."""
+    logger.debug("Run SQL query [%s]", query)
+    spark = get_spark_session()
+    if use_duckdb():
+        view = create_temp_view(df)
+        query += f" FROM {view}"
+        return spark.sql(query)
+
+    query += " FROM {df}"
+    return spark.sql(query, df=df)
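
For illustration only: because sql_from_df appends the FROM clause itself, the query argument should contain only the part of the statement that precedes FROM (typically the SELECT list); clauses that must follow FROM, such as WHERE or GROUP BY, cannot be expressed this way. The example data is hypothetical.

    import pandas as pd
    from dsgrid.spark.functions import get_spark_session, sql_from_df

    spark = get_spark_session()
    df = spark.createDataFrame(pd.DataFrame({"county": ["a", "b"], "value": [1.0, 2.0]}))
    # Becomes "SELECT county, value * 1000 AS value_wh FROM <view or df>".
    result = sql_from_df(df, "SELECT county, value * 1000 AS value_wh")
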
+
+
+def pivot(df: DataFrame, name_column: str, value_column: str) -> DataFrame:
+    """Pivot the dataframe."""
+    method = _pivot_duckdb if use_duckdb() else _pivot_spark
+    return method(df, name_column, value_column)
+
+
+def _pivot_duckdb(df: DataFrame, name_column: str, value_column: str) -> DataFrame:
+    view = create_temp_view(df)
+    query = f"""
+    PIVOT {view}
+    ON "{name_column}"
+    USING SUM({value_column})
+    """
+    return get_spark_session().sql(query)
+
+
+def _pivot_spark(df: DataFrame, name_column: str, value_column: str) -> DataFrame:
+    ids = [x for x in df.columns if x not in {name_column, value_column}]
+    return df.groupBy(*ids).pivot(name_column).sum(value_column)
+
+
+def unpivot(df: DataFrame, pivoted_columns, name_column: str, value_column: str) -> DataFrame:
+    """Unpivot the dataframe."""
+    method = _unpivot_duckdb if use_duckdb() else _unpivot_spark
+    return method(df, pivoted_columns, name_column, value_column)
+
+
+def _unpivot_duckdb(
+    df: DataFrame, pivoted_columns, name_column: str, value_column: str
+) -> DataFrame:
+    view = create_temp_view(df)
+    cols = ",".join([f'"{x}"' for x in pivoted_columns])
+    query = f"""
+    SELECT * FROM {view}
+    UNPIVOT INCLUDE NULLS (
+        "{value_column}"
+        FOR "{name_column}" in ({cols})
+    )
+    """
+    spark = get_spark_session()
+    df = spark.sql(query)
+    return df
+
+
+def _unpivot_spark(
+    df: DataFrame, pivoted_columns, name_column: str, value_column: str
+) -> DataFrame:
+    ids = list(set(df.columns) - {value_column, *pivoted_columns})
+    return df.unpivot(
+        ids,
+        pivoted_columns,
+        name_column,
+        value_column,
+    )
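
For illustration only: a round trip between long and wide layouts using the pivot/unpivot pair above; the column names and values are hypothetical.

    import pandas as pd
    from dsgrid.spark.functions import get_spark_session, pivot, unpivot

    spark = get_spark_session()
    long_df = spark.createDataFrame(
        pd.DataFrame(
            {
                "county": ["a", "a", "b", "b"],
                "end_use": ["heating", "cooling", "heating", "cooling"],
                "value": [1.0, 2.0, 3.0, 4.0],
            }
        )
    )
    wide_df = pivot(long_df, "end_use", "value")  # wide layout: one column per end use
    restored = unpivot(wide_df, ["heating", "cooling"], "end_use", "value")
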
+
+
+def write_csv(
+    df: DataFrame, path: Path | str, header: bool = True, overwrite: bool = False
+) -> None:
+    """Write a DataFrame to a CSV file, handling special cases with duckdb."""
+    path_str = path if isinstance(path, str) else str(path)
+    if use_duckdb():
+        df.relation.write_csv(path_str, header=header, overwrite=overwrite)
+    else:
+        if overwrite:
+            df.write.options(header=header).mode("overwrite").csv(path_str)
+        else:
+            df.write.options(header=header).csv(path_str)
dsgrid/spark/types.py ADDED
@@ -0,0 +1,110 @@
+# flake8: noqa
+
+import dsgrid
+from dsgrid.common import BackendEngine
+
+
+def use_duckdb() -> bool:
+    """Return True if the environment is set to use DuckDB instead of Spark."""
+    return dsgrid.runtime_config.backend_engine == BackendEngine.DUCKDB
+
+
+if use_duckdb():
+    import duckdb.experimental.spark.sql.functions as F
+    from duckdb.experimental.spark.conf import SparkConf
+    from duckdb.experimental.spark.sql import DataFrame, SparkSession
+    from duckdb.experimental.spark.sql.types import (
+        ByteType,
+        StructField,
+        StructType,
+        StringType,
+        BooleanType,
+        IntegerType,
+        ShortType,
+        LongType,
+        DoubleType,
+        FloatType,
+        TimestampType,
+        TimestampNTZType,
+        Row,
+    )
+    from duckdb.experimental.spark.errors import AnalysisException
+else:
+    import pyspark.sql.functions as F
+    from pyspark.sql import DataFrame, Row, SparkSession
+    from pyspark.sql.types import (
+        ByteType,
+        FloatType,
+        StructType,
+        StructField,
+        StringType,
+        DoubleType,
+        IntegerType,
+        LongType,
+        ShortType,
+        BooleanType,
+        TimestampType,
+        TimestampNTZType,
+    )
+    from pyspark.errors import AnalysisException
+    from pyspark import SparkConf
+
+
+SUPPORTED_TYPES = set(
+    (
+        "BOOLEAN",
+        "INT",
+        "INTEGER",
+        "TINYINT",
+        "SMALLINT",
+        "BIGINT",
+        "FLOAT",
+        "DOUBLE",
+        "TIMESTAMP_TZ",
+        "TIMESTAMP_NTZ",
+        "STRING",
+        "TEXT",
+        "VARCHAR",
+    )
+)
+
+DUCKDB_COLUMN_TYPES = {
+    "BOOLEAN": "BOOLEAN",
+    "INT": "INTEGER",
+    "INTEGER": "INTEGER",
+    "TINYINT": "TINYINT",
+    "SMALLINT": "INTEGER",
+    "BIGINT": "BIGINT",
+    "FLOAT": "FLOAT",
+    "DOUBLE": "DOUBLE",
+    "TIMESTAMP_TZ": "TIMESTAMP WITH TIME ZONE",
+    "TIMESTAMP_NTZ": "TIMESTAMP",
+    "STRING": "VARCHAR",
+    "TEXT": "VARCHAR",
+    "VARCHAR": "VARCHAR",
+}
+
+SPARK_COLUMN_TYPES = {
+    "BOOLEAN": "BOOLEAN",
+    "INT": "INT",
+    "INTEGER": "INT",
+    "TINYINT": "TINYINT",
+    "SMALLINT": "SMALLINT",
+    "BIGINT": "BIGINT",
+    "FLOAT": "FLOAT",
+    "DOUBLE": "DOUBLE",
+    "STRING": "STRING",
+    "TEXT": "STRING",
+    "VARCHAR": "STRING",
+    "TIMESTAMP_TZ": "TIMESTAMP",
+    "TIMESTAMP_NTZ": "TIMESTAMP_NTZ",
+}
+
+assert sorted(DUCKDB_COLUMN_TYPES.keys()) == sorted(SPARK_COLUMN_TYPES.keys())
+assert not SUPPORTED_TYPES.difference(DUCKDB_COLUMN_TYPES.keys())
+
+
+def get_str_type() -> str:
+    """Return the string type used by the current database system."""
+    types = DUCKDB_COLUMN_TYPES if use_duckdb() else SPARK_COLUMN_TYPES
+    return types["STRING"]
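
For illustration only: get_str_type is convenient when building a schema that must work on either backend, for example with read_csv from dsgrid.spark.functions.

    from dsgrid.spark.types import get_str_type

    # "VARCHAR" under DuckDB, "STRING" under Spark.
    schema = {"geography": get_str_type(), "value": "DOUBLE"}
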