dsgrid-toolkit 0.2.0 (dsgrid_toolkit-0.2.0-py3-none-any.whl)

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.

Potentially problematic release.


This version of dsgrid-toolkit might be problematic.

Files changed (152)
  1. dsgrid/__init__.py +22 -0
  2. dsgrid/api/__init__.py +0 -0
  3. dsgrid/api/api_manager.py +179 -0
  4. dsgrid/api/app.py +420 -0
  5. dsgrid/api/models.py +60 -0
  6. dsgrid/api/response_models.py +116 -0
  7. dsgrid/apps/__init__.py +0 -0
  8. dsgrid/apps/project_viewer/app.py +216 -0
  9. dsgrid/apps/registration_gui.py +444 -0
  10. dsgrid/chronify.py +22 -0
  11. dsgrid/cli/__init__.py +0 -0
  12. dsgrid/cli/common.py +120 -0
  13. dsgrid/cli/config.py +177 -0
  14. dsgrid/cli/download.py +13 -0
  15. dsgrid/cli/dsgrid.py +142 -0
  16. dsgrid/cli/dsgrid_admin.py +349 -0
  17. dsgrid/cli/install_notebooks.py +62 -0
  18. dsgrid/cli/query.py +711 -0
  19. dsgrid/cli/registry.py +1773 -0
  20. dsgrid/cloud/__init__.py +0 -0
  21. dsgrid/cloud/cloud_storage_interface.py +140 -0
  22. dsgrid/cloud/factory.py +31 -0
  23. dsgrid/cloud/fake_storage_interface.py +37 -0
  24. dsgrid/cloud/s3_storage_interface.py +156 -0
  25. dsgrid/common.py +35 -0
  26. dsgrid/config/__init__.py +0 -0
  27. dsgrid/config/annual_time_dimension_config.py +187 -0
  28. dsgrid/config/common.py +131 -0
  29. dsgrid/config/config_base.py +148 -0
  30. dsgrid/config/dataset_config.py +684 -0
  31. dsgrid/config/dataset_schema_handler_factory.py +41 -0
  32. dsgrid/config/date_time_dimension_config.py +108 -0
  33. dsgrid/config/dimension_config.py +54 -0
  34. dsgrid/config/dimension_config_factory.py +65 -0
  35. dsgrid/config/dimension_mapping_base.py +349 -0
  36. dsgrid/config/dimension_mappings_config.py +48 -0
  37. dsgrid/config/dimensions.py +775 -0
  38. dsgrid/config/dimensions_config.py +71 -0
  39. dsgrid/config/index_time_dimension_config.py +76 -0
  40. dsgrid/config/input_dataset_requirements.py +31 -0
  41. dsgrid/config/mapping_tables.py +209 -0
  42. dsgrid/config/noop_time_dimension_config.py +42 -0
  43. dsgrid/config/project_config.py +1457 -0
  44. dsgrid/config/registration_models.py +199 -0
  45. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  46. dsgrid/config/simple_models.py +49 -0
  47. dsgrid/config/supplemental_dimension.py +29 -0
  48. dsgrid/config/time_dimension_base_config.py +200 -0
  49. dsgrid/data_models.py +155 -0
  50. dsgrid/dataset/__init__.py +0 -0
  51. dsgrid/dataset/dataset.py +123 -0
  52. dsgrid/dataset/dataset_expression_handler.py +86 -0
  53. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  54. dsgrid/dataset/dataset_schema_handler_base.py +899 -0
  55. dsgrid/dataset/dataset_schema_handler_one_table.py +196 -0
  56. dsgrid/dataset/dataset_schema_handler_standard.py +303 -0
  57. dsgrid/dataset/growth_rates.py +162 -0
  58. dsgrid/dataset/models.py +44 -0
  59. dsgrid/dataset/table_format_handler_base.py +257 -0
  60. dsgrid/dataset/table_format_handler_factory.py +17 -0
  61. dsgrid/dataset/unpivoted_table.py +121 -0
  62. dsgrid/dimension/__init__.py +0 -0
  63. dsgrid/dimension/base_models.py +218 -0
  64. dsgrid/dimension/dimension_filters.py +308 -0
  65. dsgrid/dimension/standard.py +213 -0
  66. dsgrid/dimension/time.py +531 -0
  67. dsgrid/dimension/time_utils.py +88 -0
  68. dsgrid/dsgrid_rc.py +88 -0
  69. dsgrid/exceptions.py +105 -0
  70. dsgrid/filesystem/__init__.py +0 -0
  71. dsgrid/filesystem/cloud_filesystem.py +32 -0
  72. dsgrid/filesystem/factory.py +32 -0
  73. dsgrid/filesystem/filesystem_interface.py +136 -0
  74. dsgrid/filesystem/local_filesystem.py +74 -0
  75. dsgrid/filesystem/s3_filesystem.py +118 -0
  76. dsgrid/loggers.py +132 -0
  77. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +950 -0
  78. dsgrid/notebooks/registration.ipynb +48 -0
  79. dsgrid/notebooks/start_notebook.sh +11 -0
  80. dsgrid/project.py +451 -0
  81. dsgrid/query/__init__.py +0 -0
  82. dsgrid/query/dataset_mapping_plan.py +142 -0
  83. dsgrid/query/derived_dataset.py +384 -0
  84. dsgrid/query/models.py +726 -0
  85. dsgrid/query/query_context.py +287 -0
  86. dsgrid/query/query_submitter.py +847 -0
  87. dsgrid/query/report_factory.py +19 -0
  88. dsgrid/query/report_peak_load.py +70 -0
  89. dsgrid/query/reports_base.py +20 -0
  90. dsgrid/registry/__init__.py +0 -0
  91. dsgrid/registry/bulk_register.py +161 -0
  92. dsgrid/registry/common.py +287 -0
  93. dsgrid/registry/config_update_checker_base.py +63 -0
  94. dsgrid/registry/data_store_factory.py +34 -0
  95. dsgrid/registry/data_store_interface.py +69 -0
  96. dsgrid/registry/dataset_config_generator.py +156 -0
  97. dsgrid/registry/dataset_registry_manager.py +734 -0
  98. dsgrid/registry/dataset_update_checker.py +16 -0
  99. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  100. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  101. dsgrid/registry/dimension_registry_manager.py +413 -0
  102. dsgrid/registry/dimension_update_checker.py +16 -0
  103. dsgrid/registry/duckdb_data_store.py +185 -0
  104. dsgrid/registry/filesystem_data_store.py +141 -0
  105. dsgrid/registry/filter_registry_manager.py +123 -0
  106. dsgrid/registry/project_config_generator.py +57 -0
  107. dsgrid/registry/project_registry_manager.py +1616 -0
  108. dsgrid/registry/project_update_checker.py +48 -0
  109. dsgrid/registry/registration_context.py +223 -0
  110. dsgrid/registry/registry_auto_updater.py +316 -0
  111. dsgrid/registry/registry_database.py +662 -0
  112. dsgrid/registry/registry_interface.py +446 -0
  113. dsgrid/registry/registry_manager.py +544 -0
  114. dsgrid/registry/registry_manager_base.py +367 -0
  115. dsgrid/registry/versioning.py +92 -0
  116. dsgrid/spark/__init__.py +0 -0
  117. dsgrid/spark/functions.py +545 -0
  118. dsgrid/spark/types.py +50 -0
  119. dsgrid/tests/__init__.py +0 -0
  120. dsgrid/tests/common.py +139 -0
  121. dsgrid/tests/make_us_data_registry.py +204 -0
  122. dsgrid/tests/register_derived_datasets.py +103 -0
  123. dsgrid/tests/utils.py +25 -0
  124. dsgrid/time/__init__.py +0 -0
  125. dsgrid/time/time_conversions.py +80 -0
  126. dsgrid/time/types.py +67 -0
  127. dsgrid/units/__init__.py +0 -0
  128. dsgrid/units/constants.py +113 -0
  129. dsgrid/units/convert.py +71 -0
  130. dsgrid/units/energy.py +145 -0
  131. dsgrid/units/power.py +87 -0
  132. dsgrid/utils/__init__.py +0 -0
  133. dsgrid/utils/dataset.py +612 -0
  134. dsgrid/utils/files.py +179 -0
  135. dsgrid/utils/filters.py +125 -0
  136. dsgrid/utils/id_remappings.py +100 -0
  137. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  138. dsgrid/utils/py_expression_eval/README.md +8 -0
  139. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  140. dsgrid/utils/py_expression_eval/tests.py +283 -0
  141. dsgrid/utils/run_command.py +70 -0
  142. dsgrid/utils/scratch_dir_context.py +64 -0
  143. dsgrid/utils/spark.py +918 -0
  144. dsgrid/utils/spark_partition.py +98 -0
  145. dsgrid/utils/timing.py +239 -0
  146. dsgrid/utils/utilities.py +184 -0
  147. dsgrid/utils/versioning.py +36 -0
  148. dsgrid_toolkit-0.2.0.dist-info/METADATA +216 -0
  149. dsgrid_toolkit-0.2.0.dist-info/RECORD +152 -0
  150. dsgrid_toolkit-0.2.0.dist-info/WHEEL +4 -0
  151. dsgrid_toolkit-0.2.0.dist-info/entry_points.txt +4 -0
  152. dsgrid_toolkit-0.2.0.dist-info/licenses/LICENSE +29 -0
dsgrid/spark/functions.py ADDED
@@ -0,0 +1,545 @@
+import logging
+import os
+from datetime import datetime
+from pathlib import Path
+from tempfile import NamedTemporaryFile
+from typing import Any, Iterable
+from uuid import uuid4
+from zoneinfo import ZoneInfo
+
+import pandas as pd
+
+import dsgrid
+from dsgrid.exceptions import DSGInvalidDimension
+from dsgrid.loggers import disable_console_logging
+from dsgrid.spark.types import (
+    DataFrame,
+    F,
+    SparkConf,
+    SparkSession,
+    TimestampType,
+    use_duckdb,
+)
+from dsgrid.utils.files import load_line_delimited_json, dump_data
+
+logger = logging.getLogger(__name__)
+
+
+if use_duckdb():
+    g_duckdb_spark = SparkSession.builder.getOrCreate()
+else:
+    g_duckdb_spark = None
+
+
+TEMP_TABLE_PREFIX = "tmp_dsgrid"
+
+
+def aggregate(df: DataFrame, agg_func: str, column: str, alias: str) -> DataFrame:
+    """Run an aggregate function on the dataframe."""
+    if use_duckdb():
+        relation = df.relation.aggregate(f"{agg_func}({column}) as {alias}")
+        return DataFrame(relation.set_alias(make_temp_view_name()), df.session)
+    return df.agg(getattr(F, agg_func)(column).alias(alias))
+
+
+def aggregate_single_value(df: DataFrame, agg_func: str, column: str) -> Any:
+    """Run an aggregate function on the dataframe that will produce a single value, such as max.
+    Return that single value.
+    """
+    alias = "__tmp__"
+    if use_duckdb():
+        return df.relation.aggregate(f"{agg_func}({column}) as {alias}").df().values[0][0]
+    return df.agg(getattr(F, agg_func)(column).alias(alias)).collect()[0][alias]
+
+
+def cache(df: DataFrame) -> DataFrame:
+    """Cache the dataframe. This is a no-op for DuckDB."""
+    if use_duckdb():
+        return df
+    return df.cache()
+
+
+def unpersist(df: DataFrame) -> None:
+    """Unpersist a dataframe that was previously cached. This is a no-op for DuckDB."""
+    if not use_duckdb():
+        df.unpersist()
+
+
+def coalesce(df: DataFrame, num_partitions: int) -> DataFrame:
+    """Coalesce the dataframe into num_partitions partitions. This is a no-op for DuckDB."""
+    if use_duckdb():
+        return df
+    return df.coalesce(num_partitions)
+
+
+def collect_list(df: DataFrame, column: str) -> list:
+    """Collect the dataframe column into a list."""
+    if use_duckdb():
+        return [x[column] for x in df.collect()]
+
+    return next(iter(df.select(F.collect_list(column)).first()))
+
+
+def count_distinct_on_group_by(
+    df: DataFrame, group_by_columns: list[str], agg_column: str, alias: str
+) -> DataFrame:
+    """Perform a count distinct on one column after grouping."""
+    if use_duckdb():
+        view = create_temp_view(df)
+        cols = ",".join([f'"{x}"' for x in group_by_columns])
+        query = f"""
+            SELECT {cols}, COUNT(DISTINCT "{agg_column}") AS "{alias}"
+            FROM {view}
+            GROUP BY {cols}
+        """
+        return get_spark_session().sql(query)
+
+    return df.groupBy(*group_by_columns).agg(F.count_distinct(agg_column).alias(alias))
+
+
+def create_temp_view(df: DataFrame) -> str:
+    """Create a temporary view with a random name and return the name."""
+    view1 = make_temp_view_name()
+    df.createOrReplaceTempView(view1)
+    return view1
+
+
+def make_temp_view_name() -> str:
+    """Make a random name to be used as a view."""
+    return f"{TEMP_TABLE_PREFIX}_{uuid4().hex}"
+
+
+def drop_temp_tables_and_views() -> None:
+    """Drop all temporary views and tables."""
+    drop_temp_views()
+    drop_temp_tables()
+
+
+def drop_temp_tables() -> None:
+    """Drop all temporary tables."""
+    spark = get_spark_session()
+    if use_duckdb():
+        query = f"SELECT * FROM pg_tables WHERE tablename LIKE '%{TEMP_TABLE_PREFIX}%'"
+        for row in spark.sql(query).collect():
+            spark.sql(f"DROP TABLE {row.tablename}")
+    else:
+        for row in spark.sql(f"SHOW TABLES LIKE '*{TEMP_TABLE_PREFIX}*'").collect():
+            spark.sql(f"DROP TABLE {row.tableName}")
+
+
+def drop_temp_views() -> None:
+    """Drop all temporary views."""
+    spark = get_spark_session()
+    if use_duckdb():
+        query = f"""
+            SELECT view_name FROM duckdb_views()
+            WHERE NOT internal AND view_name LIKE '%{TEMP_TABLE_PREFIX}%'
+        """
+        for row in spark.sql(query).collect():
+            spark.sql(f"DROP VIEW {row.view_name}")
+    else:
+        for row in spark.sql(f"SHOW VIEWS LIKE '*{TEMP_TABLE_PREFIX}*'").collect():
+            spark.sql(f"DROP VIEW {row.viewName}")
+
+
+def cross_join(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    """Return a cross join of the two dataframes."""
+    if use_duckdb():
+        view1 = create_temp_view(df1)
+        view2 = create_temp_view(df2)
+        spark = get_spark_session()
+        return spark.sql(f"SELECT * from {view1} CROSS JOIN {view2}")
+
+    return df1.crossJoin(df2)
+
+
+def except_all(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    """Return a dataframe with all rows in df1 that are not in df2."""
+    method = _except_all_duckdb if use_duckdb() else _except_all_spark
+    return method(df1, df2)
+
+
+def _except_all_duckdb(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    view1 = create_temp_view(df1)
+    view2 = create_temp_view(df2)
+    query = f"""
+        SELECT * FROM {view1}
+        EXCEPT ALL
+        SELECT * FROM {view2}
+    """
+    spark = get_spark_session()
+    return spark.sql(query)
+
+
+def _except_all_spark(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    return df1.exceptAll(df2)
+
+
+def handle_column_spaces(column: str) -> str:
+    """Return a column string suitable for the backend."""
+    if use_duckdb():
+        return f'"{column}"'
+    return f"`{column}`"
+
+
+def intersect(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    """Return an intersection of rows. Duplicates are not returned."""
+    # Could add intersect all if duplicates are needed.
+    method = _intersect_duckdb if use_duckdb() else _intersect_spark
+    return method(df1, df2)
+
+
+def _intersect_duckdb(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    view1 = create_temp_view(df1)
+    view2 = create_temp_view(df2)
+    query = f"""
+        SELECT * FROM {view1}
+        INTERSECT
+        SELECT * FROM {view2}
+    """
+    spark = get_spark_session()
+    return spark.sql(query)
+
+
+def _intersect_spark(df1: DataFrame, df2: DataFrame) -> DataFrame:
+    return df1.intersect(df2)
+
+
+def get_duckdb_spark_session() -> SparkSession | None:
+    """Return the active DuckDB Spark Session if it is set."""
+    return g_duckdb_spark
+
+
+def get_spark_session() -> SparkSession:
+    """Return the active SparkSession or create a new one if none is active."""
+    spark = get_duckdb_spark_session()
+    if spark is not None:
+        return spark
+
+    spark = SparkSession.getActiveSession()
+    if spark is None:
+        logger.warning("Could not find a SparkSession. Create a new one.")
+        spark = SparkSession.builder.getOrCreate()
+        log_spark_conf(spark)
+    return spark
+
+
+def get_spark_warehouse_dir() -> Path:
+    """Return the Spark warehouse directory. Not valid with DuckDB."""
+    assert not use_duckdb()
+    val = get_spark_session().conf.get("spark.sql.warehouse.dir")
+    assert isinstance(val, str)
+    if not val:
+        msg = "Bug: spark.sql.warehouse.dir is not set"
+        raise Exception(msg)
+    if not val.startswith("file:"):
+        msg = f"get_spark_warehouse_dir only supports local file paths currently: {val}"
+        raise NotImplementedError(msg)
+    return Path(val.split("file:")[1])
+
+
+def get_current_time_zone() -> str:
+    """Return the current time zone."""
+    spark = get_spark_session()
+    if use_duckdb():
+        res = spark.sql("SELECT * FROM duckdb_settings() WHERE name = 'TimeZone'").collect()
+        assert len(res) == 1
+        return res[0].value
+
+    tz = spark.conf.get("spark.sql.session.timeZone")
+    assert tz is not None
+    return tz
+
+
+def set_current_time_zone(time_zone: str) -> None:
+    """Set the current time zone."""
+    spark = get_spark_session()
+    if use_duckdb():
+        spark.sql(f"SET TimeZone='{time_zone}'")
+        return
+
+    spark.conf.set("spark.sql.session.timeZone", time_zone)
+
+
+def init_spark(name="dsgrid", check_env=True, spark_conf=None) -> SparkSession:
+    """Initialize a SparkSession.
+
+    Parameters
+    ----------
+    name : str
+    check_env : bool
+        If True, which is the default, check for the SPARK_CLUSTER environment variable and
+        attach to it. Otherwise, create a local-mode cluster or attach to the SparkSession that
+        was created by pyspark/spark-submit prior to starting the current process.
+    spark_conf : dict | None, defaults to None
+        If set, Spark configuration parameters.
+
+    """
+    if use_duckdb():
+        logger.info("Using DuckDB as the backend engine.")
+        return g_duckdb_spark
+
+    logger.info("Using Spark as the backend engine.")
+    cluster = os.environ.get("SPARK_CLUSTER")
+    conf = SparkConf().setAppName(name)
+    if spark_conf is not None:
+        for key, val in spark_conf.items():
+            conf.set(key, val)
+
+    out_ts_type = conf.get("spark.sql.parquet.outputTimestampType")
+    if out_ts_type is None:
+        conf.set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS")
+    elif out_ts_type != "TIMESTAMP_MICROS":
+        logger.warning(
+            "spark.sql.parquet.outputTimestampType is set to %s. Writing parquet files may "
+            "produce undesired results.",
+            out_ts_type,
+        )
+
+    if check_env and cluster is not None:
+        logger.info("Create SparkSession %s on existing cluster %s", name, cluster)
+        conf.setMaster(cluster)
+
+    config = SparkSession.builder.config(conf=conf)
+    if dsgrid.runtime_config.use_hive_metastore:
+        config = config.enableHiveSupport()
+    spark = config.getOrCreate()
+
+    with disable_console_logging():
+        log_spark_conf(spark)
+        logger.info("Custom configuration settings: %s", spark_conf)
+
+    return spark
+
+
+def is_dataframe_empty(df: DataFrame) -> bool:
+    """Return True if the DataFrame is empty."""
+    if use_duckdb():
+        view = create_temp_view(df)
+        spark = get_spark_session()
+        col = df.columns[0]
+        return spark.sql(f'SELECT "{col}" FROM {view} LIMIT 1').count() == 0
+    return df.rdd.isEmpty()
+
+
+def perform_interval_op(
+    df: DataFrame, time_column, op: str, val: Any, unit: str, alias: str
+) -> DataFrame:
+    """Perform an interval operation ('-' or '+') on a time column."""
+    if use_duckdb():
+        view = create_temp_view(df)
+        cols = df.columns[:]
+        if alias == time_column:
+            cols.remove(time_column)
+        cols_str = ",".join([f'"{x}"' for x in cols])
+        query = (
+            f'SELECT "{time_column}" {op} INTERVAL {val} {unit} AS {alias}, {cols_str} from {view}'
+        )
+        return get_spark_session().sql(query)
+
+    interval_expr = F.expr(f"INTERVAL {val} SECONDS")
+    match op:
+        case "-":
+            expr = F.col(time_column) - interval_expr
+        case "+":
+            expr = F.col(time_column) + interval_expr
+        case _:
+            msg = f"{op=} is not supported"
+            raise NotImplementedError(msg)
+    return df.withColumn(alias, expr)
+
+
+def join(df1: DataFrame, df2: DataFrame, column1: str, column2: str, how="inner") -> DataFrame:
+    """Join two dataframes on one column. Use this method whenever the result may be joined
+    with another dataframe in order to work around a DuckDB issue.
+    """
+    df = df1.join(df2, on=df1[column1] == df2[column2], how=how)
+    if use_duckdb():
+        # DuckDB sets the relation alias to "relation", which causes problems with future
+        # joins. They declined to address this in https://github.com/duckdb/duckdb/issues/12959
+        df.relation = df.relation.set_alias(f"relation_{uuid4()}")
+
+    return df
+
+
+def join_multiple_columns(
+    df1: DataFrame, df2: DataFrame, columns: list[str], how="inner"
+) -> DataFrame:
+    """Join two dataframes on multiple columns."""
+    if use_duckdb():
+        view1 = create_temp_view(df1)
+        view2 = create_temp_view(df2)
+        view2_columns = ",".join((f'{view2}."{x}"' for x in df2.columns if x not in columns))
+        on_str = " AND ".join((f'{view1}."{x}" = {view2}."{x}"' for x in columns))
+        query = f"""
+            SELECT {view1}.*, {view2_columns}
+            FROM {view1}
+            {how} JOIN {view2}
+            ON {on_str}
+        """
+        # This does not have the alias="relation" issue discussed above.
+        return get_spark_session().sql(query)
+
+    return df1.join(df2, columns, how=how)
+
+
+def log_spark_conf(spark: SparkSession):
+    """Log the Spark configuration details."""
+    if not use_duckdb():
+        conf = spark.sparkContext.getConf().getAll()
+        conf.sort(key=lambda x: x[0])
+        logger.info("Spark conf: %s", "\n".join([f"{x} = {y}" for x, y in conf]))
+
+
+def prepare_timestamps_for_dataframe(timestamps: Iterable[datetime]) -> Iterable[datetime]:
+    """Apply necessary conversions of the timestamps for dataframe creation."""
+    if use_duckdb():
+        return [x.astimezone(ZoneInfo("UTC")) for x in timestamps]
+    return timestamps


+def read_csv(path: Path | str, cast_timestamp: bool = True) -> DataFrame:
+    """Return a DataFrame from a CSV file, handling special cases with duckdb."""
+    spark = get_spark_session()
+    if use_duckdb():
+        path_ = path if isinstance(path, Path) else Path(path)
+        if path_.is_dir():
+            path_str = str(path_) + "**/*.csv"
+        else:
+            path_str = str(path_)
+        df = spark.createDataFrame(pd.read_csv(path_str))
+        if cast_timestamp and "timestamp" in df.columns:
+            if "PYTEST_VERSION" not in os.environ:
+                msg = f"cast_timestamp in read_csv can only be set in a test environment: {path=}"
+                raise Exception(msg)
+            # TODO: need a better way of guessing and setting the correct type.
+            df = df.withColumn("timestamp", F.col("timestamp").cast(TimestampType()))
+        dup_cols = [x for x in df.columns if x.endswith(".1")]
+        if dup_cols:
+            msg = f"Detected a duplicate column in the dataset: {dup_cols}"
+            raise DSGInvalidDimension(msg)
+        return df
+    return spark.read.csv(str(path), header=True, inferSchema=True)
+
+
+def read_json(path: Path | str) -> DataFrame:
+    """Return a DataFrame from a JSON file, handling special cases with duckdb.
+
+    Warning: Use of this function with DuckDB is not efficient because it requires that we
+    convert line-delimited JSON to standard JSON.
+    """
+    spark = get_spark_session()
+    filename = str(path)
+    if use_duckdb():
+        with NamedTemporaryFile(suffix=".json") as f:
+            f.close()
+            # TODO duckdb: look for something more efficient. Not a big deal right now.
+            data = load_line_delimited_json(path)
+            dump_data(data, f.name)
+            return spark.read.json(f.name)
+    return spark.read.json(filename, mode="FAILFAST")
+
+
+def read_parquet(path: Path | str) -> DataFrame:
+    path = Path(path) if isinstance(path, str) else path
+    spark = get_spark_session()
+    if path.is_file() or not use_duckdb():
+        df = spark.read.parquet(str(path))
+    else:
+        df = spark.read.parquet(f"{path}/**/*.parquet")
+    return df
+
+
+def select_expr(df: DataFrame, exprs: list[str]) -> DataFrame:
+    """Execute the SQL SELECT expression. It is the caller's responsibility to handle column
+    names with spaces or special characters.
+    """
+    if use_duckdb():
+        view = create_temp_view(df)
+        spark = get_spark_session()
+        cols = ",".join(exprs)
+        return spark.sql(f"SELECT {cols} FROM {view}")
+    return df.selectExpr(*exprs)
+
+
+def sql_from_df(df: DataFrame, query: str) -> DataFrame:
+    """Run a SQL query on a dataframe with Spark."""
+    logger.debug("Run SQL query [%s]", query)
+    spark = get_spark_session()
+    if use_duckdb():
+        view = create_temp_view(df)
+        query += f" FROM {view}"
+        return spark.sql(query)
+
+    query += " FROM {df}"
+    return spark.sql(query, df=df)
+
+
+def pivot(df: DataFrame, name_column: str, value_column: str) -> DataFrame:
+    """Pivot the dataframe."""
+    method = _pivot_duckdb if use_duckdb() else _pivot_spark
+    return method(df, name_column, value_column)
+
+
+def _pivot_duckdb(df: DataFrame, name_column: str, value_column: str) -> DataFrame:
+    view = create_temp_view(df)
+    query = f"""
+        PIVOT {view}
+        ON "{name_column}"
+        USING SUM({value_column})
+    """
+    return get_spark_session().sql(query)
+
+
+def _pivot_spark(df: DataFrame, name_column: str, value_column: str) -> DataFrame:
+    ids = [x for x in df.columns if x not in {name_column, value_column}]
+    return df.groupBy(*ids).pivot(name_column).sum(value_column)
+
+
+def unpivot(df: DataFrame, pivoted_columns, name_column: str, value_column: str) -> DataFrame:
+    """Unpivot the dataframe."""
+    method = _unpivot_duckdb if use_duckdb() else _unpivot_spark
+    return method(df, pivoted_columns, name_column, value_column)
+
+
+def _unpivot_duckdb(
+    df: DataFrame, pivoted_columns, name_column: str, value_column: str
+) -> DataFrame:
+    view = create_temp_view(df)
+    cols = ",".join([f'"{x}"' for x in pivoted_columns])
+    query = f"""
+        SELECT * FROM {view}
+        UNPIVOT INCLUDE NULLS (
+            "{value_column}"
+            FOR "{name_column}" in ({cols})
+        )
+    """
+    spark = get_spark_session()
+    df = spark.sql(query)
+    return df
+
+
+def _unpivot_spark(
+    df: DataFrame, pivoted_columns, name_column: str, value_column: str
+) -> DataFrame:
+    ids = list(set(df.columns) - {value_column, *pivoted_columns})
+    return df.unpivot(
+        ids,
+        pivoted_columns,
+        name_column,
+        value_column,
+    )
+
+
+def write_csv(
+    df: DataFrame, path: Path | str, header: bool = True, overwrite: bool = False
+) -> None:
+    """Write a DataFrame to a CSV file, handling special cases with duckdb."""
+    path_str = path if isinstance(path, str) else str(path)
+    if use_duckdb():
+        df.relation.write_csv(path_str, header=header, overwrite=overwrite)
+    else:
+        if overwrite:
+            df.write.options(header=True).mode("overwrite").csv(path_str)
+        else:
+            df.write.options(header=True).csv(path_str)
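
The module above is dsgrid's single backend-agnostic layer over PySpark and DuckDB's experimental Spark API. The snippet below is a usage sketch, not part of the released package: the table and column names are invented, and it assumes the PySpark backend is the one selected by the runtime configuration.

from dsgrid.spark.functions import (
    aggregate_single_value,
    get_spark_session,
    join_multiple_columns,
    unpivot,
)

spark = get_spark_session()

# Hypothetical load data in a pivoted layout: one column per end use.
load = spark.createDataFrame(
    [("CO", 1.2, 3.4), ("WA", 0.8, 2.1)],
    ["state", "com_lighting", "com_cooling"],
)
names = spark.createDataFrame(
    [("CO", "Colorado"), ("WA", "Washington")], ["state", "name"]
)

# Reshape to long form, join on the shared column, and reduce to one value,
# without caring whether Spark or DuckDB executes the work.
long_df = unpivot(load, ["com_lighting", "com_cooling"], "end_use", "value")
joined = join_multiple_columns(long_df, names, ["state"])
peak = aggregate_single_value(joined, "max", "value")  # 3.4
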
dsgrid/spark/types.py ADDED
@@ -0,0 +1,50 @@
+# flake8: noqa
+
+import dsgrid
+from dsgrid.common import BackendEngine
+
+
+def use_duckdb() -> bool:
+    """Return True if the environment is set to use DuckDB instead of Spark."""
+    return dsgrid.runtime_config.backend_engine == BackendEngine.DUCKDB
+
+
+if use_duckdb():
+    import duckdb.experimental.spark.sql.functions as F
+    from duckdb.experimental.spark.conf import SparkConf
+    from duckdb.experimental.spark.sql import DataFrame, SparkSession
+    from duckdb.experimental.spark.sql.types import (
+        ByteType,
+        StructField,
+        StructType,
+        StringType,
+        BooleanType,
+        IntegerType,
+        ShortType,
+        LongType,
+        DoubleType,
+        FloatType,
+        TimestampType,
+        TimestampNTZType,
+        Row,
+    )
+    from duckdb.experimental.spark.errors import AnalysisException
+else:
+    import pyspark.sql.functions as F
+    from pyspark.sql import DataFrame, Row, SparkSession
+    from pyspark.sql.types import (
+        ByteType,
+        FloatType,
+        StructType,
+        StructField,
+        StringType,
+        DoubleType,
+        IntegerType,
+        LongType,
+        ShortType,
+        BooleanType,
+        TimestampType,
+        TimestampNTZType,
+    )
+    from pyspark.errors import AnalysisException
+    from pyspark import SparkConf
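
dsgrid/spark/types.py is the switch point: the rest of the package imports DataFrame, SparkSession, the type classes, and use_duckdb() from here, so the DuckDB-versus-Spark choice is made once at import time from dsgrid.runtime_config.backend_engine. A minimal sketch of backend-neutral downstream code, assuming these imports resolve in the active environment:

from dsgrid.spark.types import (
    DoubleType,
    StringType,
    StructField,
    StructType,
    use_duckdb,
)

# The same schema definition works whether the duckdb.experimental.spark
# classes or the pyspark classes were imported above.
schema = StructType(
    [
        StructField("geography", StringType(), False),
        StructField("value", DoubleType(), True),
    ]
)

backend = "DuckDB" if use_duckdb() else "Spark"
print(f"{backend} backend selected; value schema: {schema}")
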