dsgrid-toolkit 0.3.3__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. build_backend.py +93 -0
  2. dsgrid/__init__.py +22 -0
  3. dsgrid/api/__init__.py +0 -0
  4. dsgrid/api/api_manager.py +179 -0
  5. dsgrid/api/app.py +419 -0
  6. dsgrid/api/models.py +60 -0
  7. dsgrid/api/response_models.py +116 -0
  8. dsgrid/apps/__init__.py +0 -0
  9. dsgrid/apps/project_viewer/app.py +216 -0
  10. dsgrid/apps/registration_gui.py +444 -0
  11. dsgrid/chronify.py +32 -0
  12. dsgrid/cli/__init__.py +0 -0
  13. dsgrid/cli/common.py +120 -0
  14. dsgrid/cli/config.py +176 -0
  15. dsgrid/cli/download.py +13 -0
  16. dsgrid/cli/dsgrid.py +157 -0
  17. dsgrid/cli/dsgrid_admin.py +92 -0
  18. dsgrid/cli/install_notebooks.py +62 -0
  19. dsgrid/cli/query.py +729 -0
  20. dsgrid/cli/registry.py +1862 -0
  21. dsgrid/cloud/__init__.py +0 -0
  22. dsgrid/cloud/cloud_storage_interface.py +140 -0
  23. dsgrid/cloud/factory.py +31 -0
  24. dsgrid/cloud/fake_storage_interface.py +37 -0
  25. dsgrid/cloud/s3_storage_interface.py +156 -0
  26. dsgrid/common.py +36 -0
  27. dsgrid/config/__init__.py +0 -0
  28. dsgrid/config/annual_time_dimension_config.py +194 -0
  29. dsgrid/config/common.py +142 -0
  30. dsgrid/config/config_base.py +148 -0
  31. dsgrid/config/dataset_config.py +907 -0
  32. dsgrid/config/dataset_schema_handler_factory.py +46 -0
  33. dsgrid/config/date_time_dimension_config.py +136 -0
  34. dsgrid/config/dimension_config.py +54 -0
  35. dsgrid/config/dimension_config_factory.py +65 -0
  36. dsgrid/config/dimension_mapping_base.py +350 -0
  37. dsgrid/config/dimension_mappings_config.py +48 -0
  38. dsgrid/config/dimensions.py +1025 -0
  39. dsgrid/config/dimensions_config.py +71 -0
  40. dsgrid/config/file_schema.py +190 -0
  41. dsgrid/config/index_time_dimension_config.py +80 -0
  42. dsgrid/config/input_dataset_requirements.py +31 -0
  43. dsgrid/config/mapping_tables.py +209 -0
  44. dsgrid/config/noop_time_dimension_config.py +42 -0
  45. dsgrid/config/project_config.py +1462 -0
  46. dsgrid/config/registration_models.py +188 -0
  47. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  48. dsgrid/config/simple_models.py +49 -0
  49. dsgrid/config/supplemental_dimension.py +29 -0
  50. dsgrid/config/time_dimension_base_config.py +192 -0
  51. dsgrid/data_models.py +155 -0
  52. dsgrid/dataset/__init__.py +0 -0
  53. dsgrid/dataset/dataset.py +123 -0
  54. dsgrid/dataset/dataset_expression_handler.py +86 -0
  55. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  56. dsgrid/dataset/dataset_schema_handler_base.py +945 -0
  57. dsgrid/dataset/dataset_schema_handler_one_table.py +209 -0
  58. dsgrid/dataset/dataset_schema_handler_two_table.py +322 -0
  59. dsgrid/dataset/growth_rates.py +162 -0
  60. dsgrid/dataset/models.py +51 -0
  61. dsgrid/dataset/table_format_handler_base.py +257 -0
  62. dsgrid/dataset/table_format_handler_factory.py +17 -0
  63. dsgrid/dataset/unpivoted_table.py +121 -0
  64. dsgrid/dimension/__init__.py +0 -0
  65. dsgrid/dimension/base_models.py +230 -0
  66. dsgrid/dimension/dimension_filters.py +308 -0
  67. dsgrid/dimension/standard.py +252 -0
  68. dsgrid/dimension/time.py +352 -0
  69. dsgrid/dimension/time_utils.py +103 -0
  70. dsgrid/dsgrid_rc.py +88 -0
  71. dsgrid/exceptions.py +105 -0
  72. dsgrid/filesystem/__init__.py +0 -0
  73. dsgrid/filesystem/cloud_filesystem.py +32 -0
  74. dsgrid/filesystem/factory.py +32 -0
  75. dsgrid/filesystem/filesystem_interface.py +136 -0
  76. dsgrid/filesystem/local_filesystem.py +74 -0
  77. dsgrid/filesystem/s3_filesystem.py +118 -0
  78. dsgrid/loggers.py +132 -0
  79. dsgrid/minimal_patterns.cp313-win_amd64.pyd +0 -0
  80. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +949 -0
  81. dsgrid/notebooks/registration.ipynb +48 -0
  82. dsgrid/notebooks/start_notebook.sh +11 -0
  83. dsgrid/project.py +451 -0
  84. dsgrid/query/__init__.py +0 -0
  85. dsgrid/query/dataset_mapping_plan.py +142 -0
  86. dsgrid/query/derived_dataset.py +388 -0
  87. dsgrid/query/models.py +728 -0
  88. dsgrid/query/query_context.py +287 -0
  89. dsgrid/query/query_submitter.py +994 -0
  90. dsgrid/query/report_factory.py +19 -0
  91. dsgrid/query/report_peak_load.py +70 -0
  92. dsgrid/query/reports_base.py +20 -0
  93. dsgrid/registry/__init__.py +0 -0
  94. dsgrid/registry/bulk_register.py +165 -0
  95. dsgrid/registry/common.py +287 -0
  96. dsgrid/registry/config_update_checker_base.py +63 -0
  97. dsgrid/registry/data_store_factory.py +34 -0
  98. dsgrid/registry/data_store_interface.py +74 -0
  99. dsgrid/registry/dataset_config_generator.py +158 -0
  100. dsgrid/registry/dataset_registry_manager.py +950 -0
  101. dsgrid/registry/dataset_update_checker.py +16 -0
  102. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  103. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  104. dsgrid/registry/dimension_registry_manager.py +413 -0
  105. dsgrid/registry/dimension_update_checker.py +16 -0
  106. dsgrid/registry/duckdb_data_store.py +207 -0
  107. dsgrid/registry/filesystem_data_store.py +150 -0
  108. dsgrid/registry/filter_registry_manager.py +123 -0
  109. dsgrid/registry/project_config_generator.py +57 -0
  110. dsgrid/registry/project_registry_manager.py +1623 -0
  111. dsgrid/registry/project_update_checker.py +48 -0
  112. dsgrid/registry/registration_context.py +223 -0
  113. dsgrid/registry/registry_auto_updater.py +316 -0
  114. dsgrid/registry/registry_database.py +667 -0
  115. dsgrid/registry/registry_interface.py +446 -0
  116. dsgrid/registry/registry_manager.py +558 -0
  117. dsgrid/registry/registry_manager_base.py +367 -0
  118. dsgrid/registry/versioning.py +92 -0
  119. dsgrid/rust_ext/__init__.py +14 -0
  120. dsgrid/rust_ext/find_minimal_patterns.py +129 -0
  121. dsgrid/spark/__init__.py +0 -0
  122. dsgrid/spark/functions.py +589 -0
  123. dsgrid/spark/types.py +110 -0
  124. dsgrid/tests/__init__.py +0 -0
  125. dsgrid/tests/common.py +140 -0
  126. dsgrid/tests/make_us_data_registry.py +265 -0
  127. dsgrid/tests/register_derived_datasets.py +103 -0
  128. dsgrid/tests/utils.py +25 -0
  129. dsgrid/time/__init__.py +0 -0
  130. dsgrid/time/time_conversions.py +80 -0
  131. dsgrid/time/types.py +67 -0
  132. dsgrid/units/__init__.py +0 -0
  133. dsgrid/units/constants.py +113 -0
  134. dsgrid/units/convert.py +71 -0
  135. dsgrid/units/energy.py +145 -0
  136. dsgrid/units/power.py +87 -0
  137. dsgrid/utils/__init__.py +0 -0
  138. dsgrid/utils/dataset.py +830 -0
  139. dsgrid/utils/files.py +179 -0
  140. dsgrid/utils/filters.py +125 -0
  141. dsgrid/utils/id_remappings.py +100 -0
  142. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  143. dsgrid/utils/py_expression_eval/README.md +8 -0
  144. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  145. dsgrid/utils/py_expression_eval/tests.py +283 -0
  146. dsgrid/utils/run_command.py +70 -0
  147. dsgrid/utils/scratch_dir_context.py +65 -0
  148. dsgrid/utils/spark.py +918 -0
  149. dsgrid/utils/spark_partition.py +98 -0
  150. dsgrid/utils/timing.py +239 -0
  151. dsgrid/utils/utilities.py +221 -0
  152. dsgrid/utils/versioning.py +36 -0
  153. dsgrid_toolkit-0.3.3.dist-info/METADATA +193 -0
  154. dsgrid_toolkit-0.3.3.dist-info/RECORD +157 -0
  155. dsgrid_toolkit-0.3.3.dist-info/WHEEL +4 -0
  156. dsgrid_toolkit-0.3.3.dist-info/entry_points.txt +4 -0
  157. dsgrid_toolkit-0.3.3.dist-info/licenses/LICENSE +29 -0
dsgrid/utils/spark.py ADDED
@@ -0,0 +1,918 @@
1
+ """Spark helper functions"""
2
+
3
+ import enum
4
+ import itertools
5
+ import logging
6
+ import math
7
+ import os
8
+ import shutil
9
+ import time
10
+ from contextlib import contextmanager
11
+ from pathlib import Path
12
+ from types import UnionType
13
+ from typing import Any, Generator, Iterable, Sequence, Type, Union, get_origin, get_args
14
+
15
+ import duckdb
16
+ import pandas as pd
17
+
18
+ from dsgrid.data_models import DSGBaseModel
19
+ from dsgrid.exceptions import (
20
+ DSGInvalidField,
21
+ DSGInvalidFile,
22
+ DSGInvalidOperation,
23
+ DSGInvalidParameter,
24
+ )
25
+ from dsgrid.utils.files import delete_if_exists, load_data
26
+ from dsgrid.utils.scratch_dir_context import ScratchDirContext
27
+ from dsgrid.spark.functions import (
28
+ cross_join,
29
+ get_spark_session,
30
+ get_duckdb_spark_session,
31
+ get_current_time_zone,
32
+ set_current_time_zone,
33
+ init_spark,
34
+ is_dataframe_empty,
35
+ read_csv,
36
+ read_json,
37
+ read_parquet,
38
+ )
39
+ from dsgrid.spark.types import (
40
+ AnalysisException,
41
+ BooleanType,
42
+ DataFrame,
43
+ DoubleType,
44
+ IntegerType,
45
+ SparkSession,
46
+ StringType,
47
+ StructField,
48
+ StructType,
49
+ use_duckdb,
50
+ )
51
+ from dsgrid.utils.timing import Timer, track_timing, timer_stats_collector
52
+
53
+
54
+ logger = logging.getLogger(__name__)
55
+
56
+ # Consider using our own database. Would need to manage creation with
57
+ # spark.sql(f"CREATE DATABASE IF NOT EXISTS {database}")
58
+ # Doing so has caused conflicts in tests with the Derby db.
59
+ DSGRID_DB_NAME = "default"
60
+
61
+ MAX_PARTITION_SIZE_MB = 128
62
+
63
+ PYTHON_TO_SPARK_TYPES = {
64
+ int: IntegerType,
65
+ float: DoubleType,
66
+ str: StringType,
67
+ bool: BooleanType,
68
+ }
69
+
70
+
71
+ def get_active_session(*args) -> SparkSession:
72
+ """Return the active Spark Session."""
73
+ return get_duckdb_spark_session() or init_spark(*args)
74
+
75
+
76
+ def restart_spark(*args, force=False, **kwargs) -> SparkSession:
77
+ """Restart a SparkSession with new config parameters. Refer to init_spark for parameters.
78
+
79
+ Parameters
80
+ ----------
81
+ force : bool
82
+ If True, restart the session even if the config parameters haven't changed.
83
+ You might want to do this in order to clear cached tables or start Spark fresh.
84
+
85
+ Returns
86
+ -------
87
+ pyspark.sql.SparkSession
88
+
89
+ """
90
+ spark = get_duckdb_spark_session()
91
+ if spark is not None:
92
+ return spark
93
+
94
+ spark = SparkSession.getActiveSession()
95
+ needs_restart = force
96
+ orig_time_zone = spark.conf.get("spark.sql.session.timeZone")
97
+ conf = kwargs.get("spark_conf", {})
98
+ new_time_zone = conf.get("spark.sql.session.timeZone", orig_time_zone)
99
+
100
+ if not force:
101
+ for key, val in conf.items():
102
+ current = spark.conf.get(key, None)
103
+ if isinstance(current, str):
104
+ match current.lower():
105
+ case "true":
106
+ current = True
107
+ case "false":
108
+ current = False
109
+ if current is not None and current != val:
110
+ logger.info("SparkSession needs restart because of %s = %s", key, val)
111
+ needs_restart = True
112
+ break
113
+
114
+ if needs_restart:
115
+ spark.stop()
116
+ logger.info("Stopped the SparkSession so that it can be restarted with a new config.")
117
+ spark = init_spark(*args, **kwargs)
118
+ if spark.conf.get("spark.sql.session.timeZone") != new_time_zone:
119
+ # We set this value in query_submitter.py and that change will get lost
120
+ # when the session is restarted.
121
+ spark.conf.set("spark.sql.session.timeZone", new_time_zone)
122
+ else:
123
+ logger.info("No restart of Spark is needed.")
124
+
125
+ return spark
126
+
127
+
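For illustration, a minimal sketch of calling restart_spark directly, assuming a SparkSession is already active and Spark (not DuckDB) is the backend. The name and spark_conf keyword arguments mirror the call made in restart_spark_with_custom_conf later in this file; the shuffle-partition value is purely illustrative:

from dsgrid.utils.spark import restart_spark

# Restarts the session only if a requested value differs from the live one.
spark = restart_spark(
    name="dsgrid",  # assumed app name
    spark_conf={"spark.sql.shuffle.partitions": "200"},  # illustrative setting
)
print(spark.conf.get("spark.sql.shuffle.partitions"))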
128
+ @track_timing(timer_stats_collector)
129
+ def create_dataframe(records, table_name=None, require_unique=None) -> DataFrame:
130
+ """Create a spark DataFrame from a list of records.
131
+
132
+ Parameters
133
+ ----------
134
+ records : list
135
+ list of spark.sql.Row
136
+ table_name : str | None
137
+ If set, cache the DataFrame in memory with this name. Must be unique.
138
+ require_unique : list
139
+ list of column names (str) to check for uniqueness
140
+ """
141
+ df = get_spark_session().createDataFrame(records)
142
+ _post_process_dataframe(df, table_name=table_name, require_unique=require_unique)
143
+ return df
144
+
145
+
146
+ @track_timing(timer_stats_collector)
147
+ def create_dataframe_from_ids(ids: Iterable[str], column: str) -> DataFrame:
148
+ """Create a spark DataFrame from a list of dimension IDs."""
149
+ schema = StructType([StructField(column, StringType())])
150
+ return get_spark_session().createDataFrame([[x] for x in ids], schema)
151
+
152
+
153
+ def create_dataframe_from_pandas(df):
154
+ """Create a spark DataFrame from a pandas DataFrame."""
155
+ return get_spark_session().createDataFrame(df)
156
+
157
+
158
+ def create_dataframe_from_dicts(records: list[dict[str, Any]]) -> DataFrame:
159
+ """Create a spark DataFrame from a list of dictionaries.
160
+
161
+ The only purpose is to avoid pyright complaints about the type of the input to
162
+ spark.createDataFrame. This can be removed if pyspark fixes the type annotations.
163
+ """
164
+ if not records:
165
+ msg = "records cannot be empty in create_dataframe_from_dicts"
166
+ raise DSGInvalidParameter(msg)
167
+
168
+ data = [tuple(row.values()) for row in records]
169
+ columns = list(records[0].keys())
170
+ return get_spark_session().createDataFrame(data, columns)
171
+
172
+
173
+ def try_read_dataframe(filename: Path, delete_if_invalid=True, **kwargs):
174
+ """Try to read the dataframe.
175
+
176
+ Parameters
177
+ ----------
178
+ filename : Path
179
+ delete_if_invalid : bool
180
+ Delete the file if it cannot be read. Defaults to True.
181
+ kwargs
182
+ Forwarded to read_dataframe.
183
+
184
+ Returns
185
+ -------
186
+ pyspark.sql.DataFrame | None
187
+ Returns None if the file does not exist or is invalid.
188
+
189
+ """
190
+ if not filename.exists():
191
+ return None
192
+
193
+ try:
194
+ return read_dataframe(filename, **kwargs)
195
+ except DSGInvalidFile:
196
+ if delete_if_invalid:
197
+ if filename.is_dir():
198
+ shutil.rmtree(filename)
199
+ else:
200
+ filename.unlink()
201
+ return None
202
+
203
+
204
+ @track_timing(timer_stats_collector)
205
+ def read_dataframe(
206
+ filename: str | Path,
207
+ table_name: str | None = None,
208
+ require_unique: list[str] | None = None,
209
+ read_with_spark: bool = True,
210
+ ) -> DataFrame:
211
+ """Create a spark DataFrame from a file.
212
+
213
+ Supported formats when read_with_spark=True: .csv, .json, .parquet
214
+ Supported formats when read_with_spark=False: .csv, .json
215
+
216
+ When reading CSV files on AWS, read_with_spark should be set to False because the
217
+ files would need to be present on local storage for all workers. The master node
218
+ will sync the config files from S3, read them with standard filesystem calls,
219
+ and then convert the data to Spark dataframes. This could change if we ever decide
220
+ to read CSV files with Spark directly from S3.
221
+
222
+ Parameters
223
+ ----------
224
+ filename : str | Path
225
+ path to file
226
+ table_name : str | None
227
+ If set, cache the DataFrame in memory. Must be unique.
228
+ require_unique : list
229
+ list of column names (str) to check for uniqueness
230
+ read_with_spark : bool
231
+ If True, read the file with pyspark.read. Otherwise, read the file into
232
+ a list of dicts, convert to pyspark Rows, and then to a DataFrame.
233
+
234
+ Returns
235
+ -------
236
+ spark.sql.DataFrame
237
+
238
+ Raises
239
+ ------
240
+ DSGInvalidField
241
+ Raised if a require_unique column has duplicate values.
242
+ DSGInvalidFile
243
+ Raised if the file cannot be read. This can happen if a Parquet write operation fails.
244
+
245
+ """
246
+ func = _read_with_spark if read_with_spark else _read_natively
247
+ df = func(str(filename))
248
+ _post_process_dataframe(df, table_name=table_name, require_unique=require_unique)
249
+ return df
250
+
251
+
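A minimal usage sketch of the two readers above, assuming local Parquet files exist at these hypothetical paths; require_unique raises DSGInvalidField on duplicate values, and try_read_dataframe returns None for missing or unreadable files:

from pathlib import Path
from dsgrid.utils.spark import read_dataframe, try_read_dataframe

df = read_dataframe(Path("load_data.parquet"), require_unique=["id"])  # hypothetical file and column
maybe_df = try_read_dataframe(Path("possibly_corrupt.parquet"))  # None if absent or invalid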
252
+ def _read_with_spark(filename):
253
+ if not os.path.exists(filename):
254
+ msg = f"{filename} does not exist"
255
+ raise FileNotFoundError(msg)
256
+ suffix = Path(filename).suffix
257
+ if suffix == ".csv":
258
+ df = read_csv(filename)
259
+ elif suffix == ".parquet":
260
+ try:
261
+ df = read_parquet(filename)
262
+ except AnalysisException as exc:
263
+ if "Unable to infer schema for Parquet. It must be specified manually." in str(exc):
264
+ logger.exception("Failed to read Parquet file=%s. File may be invalid", filename)
265
+ msg = f"Cannot read {filename=}"
266
+ raise DSGInvalidFile(msg)
267
+ raise
268
+ except duckdb.duckdb.IOException:
269
+ logger.exception("Failed to read Parquet file=%s. File may be invalid", filename)
270
+ msg = f"Cannot read {filename=}"
271
+ raise DSGInvalidFile(msg)
272
+
273
+ elif suffix == ".json":
274
+ df = read_json(filename)
275
+ else:
276
+ assert False, f"Unsupported file extension: {filename}"
277
+ return df
278
+
279
+
280
+ def _read_natively(filename):
281
+ suffix = Path(filename).suffix
282
+ if suffix == ".csv":
283
+ # Reading the file is faster with pandas. Converting a list of Row to spark df
284
+ # is a tiny bit faster. Pandas likely scales better with bigger files.
285
+ # Keep the code in case we ever want to revert.
286
+ # with open(filename, encoding="utf-8-sig") as f_in:
287
+ # rows = [Row(**x) for x in csv.DictReader(f_in)]
288
+ obj = pd.read_csv(filename)
289
+ elif suffix == ".json":
290
+ obj = load_data(filename)
291
+ else:
292
+ msg = f"Unsupported file extension: {filename}"
293
+ raise NotImplementedError(msg)
294
+ return get_spark_session().createDataFrame(obj)
295
+
296
+
297
+ def _post_process_dataframe(df, table_name=None, require_unique=None):
298
+ if not use_duckdb() and table_name is not None:
299
+ df.createOrReplaceTempView(table_name)
300
+ df.cache()
301
+
302
+ if require_unique is not None:
303
+ with Timer(timer_stats_collector, "check_unique"):
304
+ for column in require_unique:
305
+ unique = df.select(column).distinct()
306
+ if unique.count() != df.count():
307
+ msg = f"DataFrame has duplicate entries for {column}"
308
+ raise DSGInvalidField(msg)
309
+
310
+
311
+ def cross_join_dfs(dfs: list[DataFrame]) -> DataFrame:
312
+ """Perform a cross join of all dataframes in dfs."""
313
+ if len(dfs) == 1:
314
+ return dfs[0]
315
+
316
+ df = dfs[0]
317
+ for other in dfs[1:]:
318
+ df = cross_join(df, other)
319
+ return df
320
+
321
+
322
+ def get_unique_values(df: DataFrame, columns: Sequence[str]) -> set[str]:
323
+ """Return the unique values of a dataframe in one column or a list of columns."""
324
+ dfc = df.select(columns).distinct().collect()
325
+ if isinstance(columns, list):
326
+ values = {tuple(getattr(row, col) for col in columns) for row in dfc}
327
+ else:
328
+ values = {getattr(x, columns) for x in dfc}
329
+
330
+ return values
331
+
332
+
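A small sketch of the two return shapes, using only helpers from this module: a single column name yields a set of scalars, while a list of columns yields a set of tuples.

from dsgrid.spark.functions import get_spark_session
from dsgrid.utils.spark import get_unique_values

spark = get_spark_session()
df = spark.createDataFrame(
    [("com", "SmallOffice"), ("com", "LargeOffice")], ["sector", "subsector"]
)
get_unique_values(df, "sector")                 # {"com"}
get_unique_values(df, ["sector", "subsector"])  # {("com", "SmallOffice"), ("com", "LargeOffice")}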
333
+ @track_timing(timer_stats_collector)
334
+ def models_to_dataframe(models: list[DSGBaseModel], table_name: str | None = None) -> DataFrame:
335
+ """Converts a list of Pydantic models to a Spark DataFrame.
336
+
337
+ Parameters
338
+ ----------
339
+ models : list
340
+ table_name : str | None
341
+ If set, a unique ID to use as the cached table name. Return from cache if already stored.
342
+ """
343
+ spark = get_spark_session()
344
+ if not use_duckdb():
345
+ if (
346
+ table_name is not None
347
+ and spark.catalog.tableExists(table_name)
348
+ and spark.catalog.isCached(table_name)
349
+ ):
350
+ return spark.table(table_name)
351
+
352
+ assert models
353
+ cls = type(models[0])
354
+ rows = []
355
+ struct_fields = []
356
+ for i, model in enumerate(models):
357
+ dct = {}
358
+ for f in cls.model_fields:
359
+ val = getattr(model, f)
360
+ if isinstance(val, enum.Enum):
361
+ val = val.value
362
+ if i == 0:
363
+ if val is None:
364
+ python_type = cls.model_fields[f].annotation
365
+ origin = get_origin(python_type)
366
+ if origin is Union or origin is UnionType:
367
+ python_type = get_type_from_union(python_type)
368
+ # else: will likely fail below
369
+ # Need to add more logic to detect the actual type or add to
370
+ # PYTHON_TO_SPARK_TYPES.
371
+ else:
372
+ python_type = type(val)
373
+ spark_type = PYTHON_TO_SPARK_TYPES[python_type]()
374
+ struct_fields.append(StructField(f, spark_type, nullable=True))
375
+ dct[f] = val
376
+ rows.append(tuple(dct.values()))
377
+
378
+ schema = StructType(struct_fields)
379
+ df = spark.createDataFrame(rows, schema=schema)
380
+
381
+ if not use_duckdb() and table_name is not None:
382
+ df.createOrReplaceTempView(table_name)
383
+ df.cache()
384
+
385
+ return df
386
+
387
+
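A minimal sketch with a throwaway Pydantic model (the model and its fields are hypothetical); scalar field types map through PYTHON_TO_SPARK_TYPES, and an Optional field whose first value is None falls back to get_type_from_union:

from dsgrid.data_models import DSGBaseModel
from dsgrid.utils.spark import models_to_dataframe

class CountyRecord(DSGBaseModel):  # hypothetical model for illustration
    id: str
    population: int | None = None

df = models_to_dataframe(
    [CountyRecord(id="08001", population=500_000), CountyRecord(id="08003")]
)
df.show()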
388
+ def get_type_from_union(python_type) -> Type:
389
+ """Return the Python type from a Union.
390
+
391
+ Only works if it is a Union of NoneType and one other type (or an enum).
392
+
393
+ Raises
394
+ ------
395
+ NotImplementedError
396
+ Raised if the code does not know how to determine the type.
397
+ """
398
+ args = get_args(python_type)
399
+ if issubclass(args[0], enum.Enum):
400
+ python_type = type(next(iter(args[0])).value)
401
+ else:
402
+ types = [x for x in args if not issubclass(x, type(None))]
403
+ if not types:
404
+ msg = f"Unhandled Union type: {python_type=} {args=}"
405
+ raise NotImplementedError(msg)
406
+ elif len(types) > 1:
407
+ msg = f"Unhandled Union type: {types=}"
408
+ raise NotImplementedError(msg)
409
+ else:
410
+ python_type = types[0]
411
+
412
+ return python_type
413
+
414
+
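Two quick checks of the behavior described above; both spellings of an optional int resolve to int:

from typing import Optional
from dsgrid.utils.spark import get_type_from_union

assert get_type_from_union(Optional[int]) is int
assert get_type_from_union(int | None) is int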
415
+ @track_timing(timer_stats_collector)
416
+ def create_dataframe_from_dimension_ids(records, *dimension_types, cache=True) -> DataFrame:
417
+ """Return a DataFrame created from the IDs of dimension_types.
418
+
419
+ Parameters
420
+ ----------
421
+ records : sequence
422
+ Iterable of lists of record IDs
423
+ dimension_types : tuple
424
+ cache : If True, cache the DataFrame.
425
+ """
426
+ schema = StructType()
427
+ for dimension_type in dimension_types:
428
+ schema.add(dimension_type.value, StringType(), nullable=False)
429
+ df = get_spark_session().createDataFrame(records, schema=schema)
430
+ if not use_duckdb() and cache:
431
+ df.cache()
432
+ return df
433
+
434
+
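A sketch using a stand-in enum (dsgrid's real dimension-type enum lives elsewhere in the package); all this function needs is enum members whose .value is the column name:

import enum
from dsgrid.utils.spark import create_dataframe_from_dimension_ids

class Dim(enum.Enum):  # stand-in for the package's dimension-type enum
    GEOGRAPHY = "geography"
    SECTOR = "sector"

records = [("06037", "com"), ("06037", "res")]
df = create_dataframe_from_dimension_ids(records, Dim.GEOGRAPHY, Dim.SECTOR)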
435
+ @track_timing(timer_stats_collector)
436
+ def check_for_nulls(df, exclude_columns=None):
437
+ """Check if a DataFrame has null values.
438
+
439
+ Parameters
440
+ ----------
441
+ df : spark.sql.DataFrame
442
+ exclude_columns : None or Set
443
+
444
+ Raises
445
+ ------
446
+ DSGInvalidField
447
+ Raised if null exists in any column.
448
+
449
+ """
450
+ if exclude_columns is None:
451
+ exclude_columns = set()
452
+ cols_to_check = set(df.columns).difference(exclude_columns)
453
+ cols_str = ", ".join(cols_to_check)
454
+ filter_str = " OR ".join((f"{x} IS NULL" for x in cols_to_check))
455
+ df.createOrReplaceTempView("tmp_view")
456
+
457
+ try:
458
+ # Avoid iterating with many checks unless we know there is at least one failure.
459
+ nulls = sql(f"SELECT {cols_str} FROM tmp_view WHERE {filter_str}")
460
+ if not is_dataframe_empty(nulls):
461
+ cols_with_null = set()
462
+ for col in cols_to_check:
463
+ if not is_dataframe_empty(nulls.select(col).filter(f"{col} is NULL")):
464
+ cols_with_null.add(col)
465
+ assert cols_with_null, "Did not find any columns with NULL values"
466
+
467
+ msg = f"DataFrame contains NULL value(s) for column(s): {cols_with_null}"
468
+ raise DSGInvalidField(msg)
469
+ finally:
470
+ sql("DROP VIEW tmp_view")
471
+
472
+
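A minimal sketch of the failure and exclusion paths, relying only on the helpers in this module and pyspark's schema inference:

from dsgrid.exceptions import DSGInvalidField
from dsgrid.spark.functions import get_spark_session
from dsgrid.utils.spark import check_for_nulls

df = get_spark_session().createDataFrame([("a", 1.0), ("b", None)], ["geography", "value"])
check_for_nulls(df, exclude_columns={"value"})  # passes: the null column is excluded
try:
    check_for_nulls(df)
except DSGInvalidField as exc:
    print(exc)  # reports that "value" contains NULLs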
473
+ @track_timing(timer_stats_collector)
474
+ def overwrite_dataframe_file(filename: Path | str, df: DataFrame) -> DataFrame:
475
+ """Perform an in-place overwrite of a Spark DataFrame, accounting for different file types
476
+ and symlinks.
477
+
478
+ Do not attempt to access the original dataframe unless it was fully cached.
479
+ """
480
+ spark = get_spark_session()
481
+ suffix = Path(filename).suffix
482
+ tmp = str(filename) + ".tmp"
483
+ if suffix == ".parquet":
484
+ df.write.parquet(tmp)
485
+ read_method = read_parquet
486
+ kwargs = {}
487
+ elif suffix == ".csv":
488
+ df.write.csv(str(tmp), header=True)
489
+ read_method = spark.read.csv
490
+ kwargs = {"header": True, "schema": df.schema}
491
+ elif suffix == ".json":
492
+ df.write.json(str(tmp))
493
+ read_method = spark.read.json
494
+ kwargs = {}
495
+ delete_if_exists(filename)
496
+ os.rename(tmp, str(filename))
497
+ return read_method(str(filename), **kwargs)
498
+
499
+
500
+ @track_timing(timer_stats_collector)
501
+ def persist_intermediate_query(
502
+ df: DataFrame, scratch_dir_context: ScratchDirContext, auto_partition=False
503
+ ) -> DataFrame:
504
+ """Persist the current query to files and then read it back and return it.
505
+
506
+ This is advised when the query has become too complex or when the query might be evaluated
507
+ twice.
508
+
509
+ Parameters
510
+ ----------
511
+ df : DataFrame
512
+ scratch_dir_context : ScratchDirContext
513
+ auto_partition : bool
514
+ If True, call write_dataframe_and_auto_partition.
515
+
516
+ Returns
517
+ -------
518
+ DataFrame
519
+ """
520
+ spark = get_spark_session()
521
+ tmp_file = scratch_dir_context.get_temp_filename(suffix=".parquet")
522
+ if auto_partition:
523
+ return write_dataframe_and_auto_partition(df, tmp_file)
524
+ df.write.parquet(str(tmp_file))
525
+ return spark.read.parquet(str(tmp_file))
526
+
527
+
528
+ @track_timing(timer_stats_collector)
529
+ def write_dataframe_and_auto_partition(
530
+ df: DataFrame,
531
+ filename: Path,
532
+ partition_size_mb: int = MAX_PARTITION_SIZE_MB,
533
+ columns: list[str] | None = None,
534
+ rtol_pct: float = 50,
535
+ min_num_partitions: int = 36,
536
+ ) -> DataFrame:
537
+ """Write a dataframe to a Parquet file and then automatically coalesce or repartition it if
538
+ needed. If the file already exists, it will be overwritten.
539
+
540
+ Parameters
541
+ ----------
542
+ df : pyspark.sql.DataFrame
543
+ filename : Path
544
+ partition_size_mb : int
545
+ Target size in MB for each partition
546
+ columns : None, list
547
+ If not None and repartitioning is needed, partition on these columns.
548
+ rtol_pct : float
549
+ Don't repartition or coalesce if the relative difference between desired and actual
550
+ partitions is within this tolerance as a percentage.
551
+ min_num_partitions : int
552
+ Minimum number of partitions. If the existing number of partitions is less than this,
553
+ do not coalesce or repartition because doing so would reduce parallelism.
554
+
555
+ Raises
556
+ ------
557
+ DSGInvalidParameter
558
+ Raised if a non-Parquet file is passed
559
+ """
560
+ suffix = Path(filename).suffix
561
+ if suffix != ".parquet":
562
+ msg = "write_dataframe_and_auto_partition only supports Parquet files: {filename=}"
563
+ raise DSGInvalidParameter(msg)
564
+
565
+ start_initial_write = time.time()
566
+ if filename.exists():
567
+ df = overwrite_dataframe_file(filename, df)
568
+ else:
569
+ df.write.parquet(str(filename))
570
+ df = read_parquet(filename)
571
+
572
+ end_initial_write = time.time()
573
+ duration_first_write = end_initial_write - start_initial_write
574
+
575
+ if use_duckdb():
576
+ logger.debug("write_dataframe_and_auto_partition is not optimized for DuckDB")
577
+ return df
578
+
579
+ num_partitions = len(list(filename.parent.iterdir()))
580
+ if num_partitions < min_num_partitions:
581
+ logger.info(
582
+ "Not coalescing %s because it has only %s partitions, "
583
+ "which is less than the minimum of %s.",
584
+ filename,
585
+ num_partitions,
586
+ min_num_partitions,
587
+ )
588
+ # TODO: consider repartitioning to increase the number of partitions.
589
+ return df
590
+
591
+ partition_size_bytes = partition_size_mb * 1024 * 1024
592
+ total_size = sum((x.stat().st_size for x in filename.glob("*.parquet")))
593
+ desired = math.ceil(total_size / partition_size_bytes)
594
+ actual = len(list(filename.glob("*.parquet")))
595
+ if abs(actual - desired) / desired * 100 < rtol_pct:
596
+ logger.info("No change in number of partitions is needed for %s.", filename)
597
+ elif actual > desired:
598
+ df = df.coalesce(desired)
599
+ df = overwrite_dataframe_file(filename, df)
600
+ duration_second_write = time.time() - end_initial_write
601
+ logger.info(
602
+ "Coalesced %s from partition count %s to %s. "
603
+ "duration_first_write=%s duration_second_write=%s",
604
+ filename,
605
+ actual,
606
+ desired,
607
+ duration_first_write,
608
+ duration_second_write,
609
+ )
610
+ else:
611
+ if columns is None:
612
+ df = df.repartition(desired)
613
+ else:
614
+ df = df.repartition(desired, *columns)
615
+ df = overwrite_dataframe_file(filename, df)
616
+ duration_second_write = time.time() - end_initial_write
617
+ logger.info(
618
+ "Repartitioned %s from partition count %s to %s. "
619
+ "duration_first_write=%s duration_second_write=%s",
620
+ filename,
621
+ actual,
622
+ desired,
623
+ duration_first_write,
624
+ duration_second_write,
625
+ )
626
+
627
+ logger.info("Wrote dataframe to %s", filename)
628
+ return df
629
+
630
+
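The target partition count is plain arithmetic over the written files. A worked example under assumed sizes (300 part files of roughly 16 MB each) showing why the function would coalesce:

import math

partition_size_bytes = 128 * 1024 * 1024                  # MAX_PARTITION_SIZE_MB
total_size = 300 * 16 * 1024 * 1024                       # assumed on-disk size, ~4.7 GB
actual = 300                                              # assumed number of part files
desired = math.ceil(total_size / partition_size_bytes)    # 38
rel_diff_pct = abs(actual - desired) / desired * 100      # ~689%, well above rtol_pct=50
# actual > desired, so the dataframe is coalesced to 38 partitions and rewritten.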
631
+ @track_timing(timer_stats_collector)
632
+ def write_dataframe(df: DataFrame, filename: str | Path, overwrite: bool = False) -> None:
633
+ """Write a Spark DataFrame, accounting for different file types.
634
+
635
+ Parameters
636
+ ----------
637
+ df : pyspark.sql.DataFrame
638
+ filename : str | Path
639
+ """
640
+ path = Path(filename)
641
+ if overwrite:
642
+ delete_if_exists(path)
643
+
644
+ suffix = path.suffix
645
+ name = str(filename)
646
+ if suffix == ".parquet":
647
+ df.write.parquet(name)
648
+ elif suffix == ".csv":
649
+ df.write.csv(name, header=True)
650
+ elif suffix == ".json":
651
+ if use_duckdb():
652
+ new_name = name.replace(".json", ".parquet")
653
+ df.write.parquet(new_name)
654
+ else:
655
+ df.write.json(name)
656
+
657
+
658
+ @track_timing(timer_stats_collector)
659
+ def persist_table(df: DataFrame, context: ScratchDirContext, tag=None) -> Path:
660
+ """Persist a table to the scratch directory. This can be helpful to avoid multiple
661
+ evaluations of the same query.
662
+ """
663
+ # Note: This does not use the Spark warehouse because we are not properly configuring or
664
+ # managing it across sessions. And, we are already using the scratch dir for our own files.
665
+ path = context.get_temp_filename(suffix=".parquet")
666
+ logger.info("Start persist_table %s %s", path, tag or "")
667
+ write_dataframe(df, path)
668
+ logger.info("Completed persist_table %s %s", path, tag or "")
669
+ return path
670
+
671
+
672
+ @track_timing(timer_stats_collector)
673
+ def save_to_warehouse(df: DataFrame, table_name: str) -> DataFrame:
674
+ """Save a table to the Spark warehouse. Not supported when using DuckDB."""
675
+ if use_duckdb():
676
+ msg = "save_to_warehouse is not supported when using DuckDB"
677
+ raise DSGInvalidOperation(msg)
678
+
679
+ logger.info("Start saveAsTable to warehouse %s", table_name)
680
+ df.write.saveAsTable(table_name)
681
+ logger.info("Completed saveAsTable %s", table_name)
682
+ return df.sparkSession.sql(f"select * from {table_name}")
683
+
684
+
685
+ def sql(query: str) -> DataFrame:
686
+ """Run a SQL query with Spark."""
687
+ logger.debug("Run SQL query [%s]", query)
688
+ return get_spark_session().sql(query)
689
+
690
+
691
+ def load_stored_table(table_name: str) -> DataFrame:
692
+ """Return a table stored in the Spark warehouse."""
693
+ spark = get_spark_session()
694
+ return spark.table(table_name)
695
+
696
+
697
+ def try_load_stored_table(
698
+ table_name: str, database: str | None = DSGRID_DB_NAME
699
+ ) -> DataFrame | None:
700
+ """Return a table if it is stored in the Spark warehouse."""
701
+ spark = get_spark_session()
702
+ full_name = f"{database}.{table_name}"
703
+ if spark.catalog.tableExists(full_name):
704
+ return spark.table(table_name)
705
+ return None
706
+
707
+
708
+ def is_table_stored(table_name, database=DSGRID_DB_NAME):
709
+ spark = get_spark_session()
710
+ full_name = f"{database}.{table_name}"
711
+ return spark.catalog.tableExists(full_name)
712
+
713
+
714
+ def save_table(table, table_name, overwrite=True, database=DSGRID_DB_NAME):
715
+ full_name = f"{database}.{table_name}"
716
+ if overwrite:
717
+ table.write.mode("overwrite").saveAsTable(full_name)
718
+ else:
719
+ table.write.saveAsTable(full_name)
720
+
721
+
722
+ def list_tables(database=DSGRID_DB_NAME):
723
+ spark = get_spark_session()
724
+ return [x.name for x in spark.catalog.listTables(dbName=database)]
725
+
726
+
727
+ def drop_table(table_name, database=DSGRID_DB_NAME):
728
+ spark = get_spark_session()
729
+ if is_table_stored(table_name, database=database):
730
+ spark.sql(f"DROP TABLE {table_name}")
731
+ logger.info("Dropped table %s", table_name)
732
+
733
+
734
+ @track_timing(timer_stats_collector)
735
+ def create_dataframe_from_product(
736
+ data: dict[str, list[str]],
737
+ context: ScratchDirContext,
738
+ max_partition_size_mb=MAX_PARTITION_SIZE_MB,
739
+ ) -> DataFrame:
740
+ """Create a dataframe by taking a product of values/columns in a dict.
741
+
742
+ Parameters
743
+ ----------
744
+ data : dict
745
+ Columns on which to perform a cross product.
746
+ {"sector": [com], "subsector": ["SmallOffice", "LargeOffice"]}
747
+ context : ScratchDirContext
748
+ Manages temporary files.
749
+ """
750
+ # dthom: 1/29/2024
751
+ # This implementation creates a product of all columns in Python, writes them to temporary
752
+ # CSV files, and then loads that back into Spark.
753
+ # This is the fastest way I've found to pass a large dataframe from the Spark driver (Python
754
+ # app) to the Spark workers on compute nodes.
755
+ # The total size of a table can be large depending on the numbers of dimensions. For example,
756
+ # comstock_conus_2022_projected is 3108 counties * 41 model years * 21 end uses * 14 subsectors * 3 scenarios
757
+ # = 112_391_496 rows. The CSV files are ~7.7 GB.
758
+ # (Note that, due to compression, the same table in Parquet is 7 MB.)
759
+ # This is not ideal because it writes temporary files to the filesystem.
760
+ # Other solutions tried:
761
+ # 1. spark.createDataFrame(spark.sparkContext.parallelize(itertools.product(*(data.values()))), list(data.keys()))
762
+ # Reasonably fast until the data is larger than Spark's max RPC message size. Then it fails.
763
+ # 2. Create an RDD and then call rdd.flatMap with the output of itertools.product. Very slow.
764
+ # 3. Create one Spark DataFrame per column and then cross-join all of them. Extremely slow.
765
+ # 4. Create one pyarrow Table, write to temp Parquet, read back in Spark. ~2x slower
766
+ # than the CSV implementation.
767
+ # 5. Create the joined table via SQLite and then read the contents into Spark with a JDBC
768
+ # driver. Much slower.
769
+
770
+ # Note: This location must be accessible on all compute nodes.
771
+ csv_dir = context.get_temp_filename(suffix=".csv")
772
+ columns = list(data.keys())
773
+ schema = StructType([StructField(x, StringType()) for x in columns])
774
+
775
+ with CsvPartitionWriter(csv_dir, max_partition_size_mb=max_partition_size_mb) as writer:
776
+ for row in itertools.product(*(data.values())):
777
+ writer.add_row(row)
778
+
779
+ spark = get_spark_session()
780
+ if use_duckdb():
781
+ df = spark.read.csv(f"{csv_dir}/*.csv", header=False, schema=schema)
782
+ else:
783
+ df = spark.read.csv(str(csv_dir), header=False, schema=schema)
784
+ return df
785
+
786
+
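A small sketch of the cross-product helper; the ScratchDirContext constructor and context-manager usage shown here are assumptions based on how it is used in this module (only get_temp_filename is visible above):

from pathlib import Path
from dsgrid.utils.scratch_dir_context import ScratchDirContext
from dsgrid.utils.spark import create_dataframe_from_product

data = {"sector": ["com"], "subsector": ["SmallOffice", "LargeOffice"]}
with ScratchDirContext(Path("scratch")) as context:  # assumed constructor signature
    df = create_dataframe_from_product(data, context)
    df.show()  # 1 x 2 = 2 rows with columns "sector" and "subsector"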
787
+ class CsvPartitionWriter:
788
+ """Writes dataframe rows to partitioned CSV files."""
789
+
790
+ def __init__(self, directory: Path, max_partition_size_mb: int = MAX_PARTITION_SIZE_MB):
791
+ self._directory = directory
792
+ self._directory.mkdir(exist_ok=True)
793
+ self._max_size = max_partition_size_mb * 1024 * 1024
794
+ self._size = 0
795
+ self._index = 1
796
+ self._fp = None
797
+
798
+ def __enter__(self):
799
+ return self
800
+
801
+ def __exit__(self, *args, **kwargs):
802
+ if self._fp is not None:
803
+ self._fp.close()
804
+
805
+ def add_row(self, row: tuple) -> None:
806
+ """Add a row to the CSV files."""
807
+ line = ",".join(row)
808
+ if self._fp is None:
809
+ filename = self._directory / f"part{self._index}.csv"
810
+ self._fp = open(filename, "w", encoding="utf-8")
811
+ self._size += self._fp.write(line)
812
+ self._size += self._fp.write("\n")
813
+ if self._size >= self._max_size:
814
+ self._fp.close()
815
+ self._fp = None
816
+ self._size = 0
817
+ self._index += 1
818
+
819
+
820
+ @contextmanager
821
+ def custom_spark_conf(conf):
822
+ """Apply a custom Spark configuration for the duration of a code block.
823
+
824
+ Parameters
825
+ ----------
826
+ conf : dict
827
+ Key-value pairs to set on the spark configuration.
828
+
829
+ """
830
+ spark = get_duckdb_spark_session()
831
+ if spark is not None:
832
+ yield
833
+ return
834
+
835
+ spark = get_spark_session()
836
+ orig_settings = {}
837
+
838
+ try:
839
+ for key, val in conf.items():
840
+ orig_settings[key] = spark.conf.get(key)
841
+ spark.conf.set(key, val)
842
+ logger.info("Set %s=%s temporarily", key, val)
843
+ yield
844
+ finally:
845
+ # Note that the user code could have restarted the session.
846
+ # Get the current one.
847
+ spark = get_spark_session()
848
+ for key, val in orig_settings.items():
849
+ spark.conf.set(key, val)
850
+
851
+
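A short sketch of the context manager above; the setting and value are illustrative, and on a DuckDB backend the block is a no-op:

from dsgrid.spark.functions import get_spark_session
from dsgrid.utils.spark import custom_spark_conf

with custom_spark_conf({"spark.sql.shuffle.partitions": "8"}):  # illustrative value
    get_spark_session().sql("SELECT 1 AS one").show()
# The original value is restored when the block exits.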
852
+ @contextmanager
853
+ def custom_time_zone(time_zone: str):
854
+ """Apply a custom Spark time zone for the duration of a code block."""
855
+ orig_time_zone = get_current_time_zone()
856
+ try:
857
+ set_current_time_zone(time_zone)
858
+ yield
859
+ finally:
860
+ # Note that the user code could have restarted the session.
861
+ # This function will get the current one.
862
+ set_current_time_zone(orig_time_zone)
863
+
864
+
865
+ @contextmanager
866
+ def restart_spark_with_custom_conf(conf: dict, force=False):
867
+ """Restart the SparkSession with a custom configuration for the duration of a code block.
868
+
869
+ Parameters
870
+ ----------
871
+ conf : dict
872
+ Key-value pairs to set on the spark configuration.
873
+ force : bool
874
+ If True, restart the session even if the config parameters haven't changed.
875
+ You might want to do this in order to clear cached tables or start Spark fresh.
876
+ """
877
+ spark = get_duckdb_spark_session()
878
+ if spark is not None:
879
+ yield spark
880
+ return
881
+
882
+ spark = get_spark_session()
883
+ app_name = spark.conf.get("spark.app.name")
884
+ orig_settings = {}
885
+
886
+ try:
887
+ for name in conf:
888
+ current = spark.conf.get(name, None)
889
+ if current is not None:
890
+ orig_settings[name] = current
891
+ new_spark = restart_spark(name=app_name, spark_conf=conf, force=force)
892
+ yield new_spark
893
+ finally:
894
+ restart_spark(name=app_name, spark_conf=orig_settings, force=force)
895
+
896
+
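A sketch of the restart-based variant, which differs from custom_spark_conf in that it tears the session down instead of mutating the live conf; the time-zone value is illustrative:

from dsgrid.utils.spark import restart_spark_with_custom_conf

with restart_spark_with_custom_conf({"spark.sql.session.timeZone": "UTC"}) as spark:
    print(spark.conf.get("spark.sql.session.timeZone"))
# On exit the session is restarted again with the previously captured settings.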
897
+ @contextmanager
898
+ def set_session_time_zone(time_zone: str) -> Generator[None, None, None]:
899
+ """Set the session time zone for execution of a code block."""
900
+ orig = get_current_time_zone()
901
+
902
+ try:
903
+ set_current_time_zone(time_zone)
904
+ yield
905
+ finally:
906
+ set_current_time_zone(orig)
907
+
908
+
909
+ def union(dfs: list[DataFrame]) -> DataFrame:
910
+ """Return a union of the dataframes, ensuring that the columns match."""
911
+ df = dfs[0]
912
+ if len(dfs) > 1:
913
+ for dft in dfs[1:]:
914
+ if df.columns != dft.columns:
915
+ msg = f"columns don't match: {df.columns=} {dft.columns=}"
916
+ raise Exception(msg)
917
+ df = df.union(dft)
918
+ return df
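
Finally, a minimal sketch of the column-checked union:

from dsgrid.spark.functions import get_spark_session
from dsgrid.utils.spark import union

spark = get_spark_session()
df1 = spark.createDataFrame([("com", 1.0)], ["sector", "value"])
df2 = spark.createDataFrame([("res", 2.0)], ["sector", "value"])
union([df1, df2]).show()  # two rows; mismatched column lists would raise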