dsgrid-toolkit 0.2.0 (dsgrid_toolkit-0.2.0-py3-none-any.whl)

This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.

Potentially problematic release: this version of dsgrid-toolkit might be problematic; see the registry listing for more details.

Files changed (152)
  1. dsgrid/__init__.py +22 -0
  2. dsgrid/api/__init__.py +0 -0
  3. dsgrid/api/api_manager.py +179 -0
  4. dsgrid/api/app.py +420 -0
  5. dsgrid/api/models.py +60 -0
  6. dsgrid/api/response_models.py +116 -0
  7. dsgrid/apps/__init__.py +0 -0
  8. dsgrid/apps/project_viewer/app.py +216 -0
  9. dsgrid/apps/registration_gui.py +444 -0
  10. dsgrid/chronify.py +22 -0
  11. dsgrid/cli/__init__.py +0 -0
  12. dsgrid/cli/common.py +120 -0
  13. dsgrid/cli/config.py +177 -0
  14. dsgrid/cli/download.py +13 -0
  15. dsgrid/cli/dsgrid.py +142 -0
  16. dsgrid/cli/dsgrid_admin.py +349 -0
  17. dsgrid/cli/install_notebooks.py +62 -0
  18. dsgrid/cli/query.py +711 -0
  19. dsgrid/cli/registry.py +1773 -0
  20. dsgrid/cloud/__init__.py +0 -0
  21. dsgrid/cloud/cloud_storage_interface.py +140 -0
  22. dsgrid/cloud/factory.py +31 -0
  23. dsgrid/cloud/fake_storage_interface.py +37 -0
  24. dsgrid/cloud/s3_storage_interface.py +156 -0
  25. dsgrid/common.py +35 -0
  26. dsgrid/config/__init__.py +0 -0
  27. dsgrid/config/annual_time_dimension_config.py +187 -0
  28. dsgrid/config/common.py +131 -0
  29. dsgrid/config/config_base.py +148 -0
  30. dsgrid/config/dataset_config.py +684 -0
  31. dsgrid/config/dataset_schema_handler_factory.py +41 -0
  32. dsgrid/config/date_time_dimension_config.py +108 -0
  33. dsgrid/config/dimension_config.py +54 -0
  34. dsgrid/config/dimension_config_factory.py +65 -0
  35. dsgrid/config/dimension_mapping_base.py +349 -0
  36. dsgrid/config/dimension_mappings_config.py +48 -0
  37. dsgrid/config/dimensions.py +775 -0
  38. dsgrid/config/dimensions_config.py +71 -0
  39. dsgrid/config/index_time_dimension_config.py +76 -0
  40. dsgrid/config/input_dataset_requirements.py +31 -0
  41. dsgrid/config/mapping_tables.py +209 -0
  42. dsgrid/config/noop_time_dimension_config.py +42 -0
  43. dsgrid/config/project_config.py +1457 -0
  44. dsgrid/config/registration_models.py +199 -0
  45. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  46. dsgrid/config/simple_models.py +49 -0
  47. dsgrid/config/supplemental_dimension.py +29 -0
  48. dsgrid/config/time_dimension_base_config.py +200 -0
  49. dsgrid/data_models.py +155 -0
  50. dsgrid/dataset/__init__.py +0 -0
  51. dsgrid/dataset/dataset.py +123 -0
  52. dsgrid/dataset/dataset_expression_handler.py +86 -0
  53. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  54. dsgrid/dataset/dataset_schema_handler_base.py +899 -0
  55. dsgrid/dataset/dataset_schema_handler_one_table.py +196 -0
  56. dsgrid/dataset/dataset_schema_handler_standard.py +303 -0
  57. dsgrid/dataset/growth_rates.py +162 -0
  58. dsgrid/dataset/models.py +44 -0
  59. dsgrid/dataset/table_format_handler_base.py +257 -0
  60. dsgrid/dataset/table_format_handler_factory.py +17 -0
  61. dsgrid/dataset/unpivoted_table.py +121 -0
  62. dsgrid/dimension/__init__.py +0 -0
  63. dsgrid/dimension/base_models.py +218 -0
  64. dsgrid/dimension/dimension_filters.py +308 -0
  65. dsgrid/dimension/standard.py +213 -0
  66. dsgrid/dimension/time.py +531 -0
  67. dsgrid/dimension/time_utils.py +88 -0
  68. dsgrid/dsgrid_rc.py +88 -0
  69. dsgrid/exceptions.py +105 -0
  70. dsgrid/filesystem/__init__.py +0 -0
  71. dsgrid/filesystem/cloud_filesystem.py +32 -0
  72. dsgrid/filesystem/factory.py +32 -0
  73. dsgrid/filesystem/filesystem_interface.py +136 -0
  74. dsgrid/filesystem/local_filesystem.py +74 -0
  75. dsgrid/filesystem/s3_filesystem.py +118 -0
  76. dsgrid/loggers.py +132 -0
  77. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +950 -0
  78. dsgrid/notebooks/registration.ipynb +48 -0
  79. dsgrid/notebooks/start_notebook.sh +11 -0
  80. dsgrid/project.py +451 -0
  81. dsgrid/query/__init__.py +0 -0
  82. dsgrid/query/dataset_mapping_plan.py +142 -0
  83. dsgrid/query/derived_dataset.py +384 -0
  84. dsgrid/query/models.py +726 -0
  85. dsgrid/query/query_context.py +287 -0
  86. dsgrid/query/query_submitter.py +847 -0
  87. dsgrid/query/report_factory.py +19 -0
  88. dsgrid/query/report_peak_load.py +70 -0
  89. dsgrid/query/reports_base.py +20 -0
  90. dsgrid/registry/__init__.py +0 -0
  91. dsgrid/registry/bulk_register.py +161 -0
  92. dsgrid/registry/common.py +287 -0
  93. dsgrid/registry/config_update_checker_base.py +63 -0
  94. dsgrid/registry/data_store_factory.py +34 -0
  95. dsgrid/registry/data_store_interface.py +69 -0
  96. dsgrid/registry/dataset_config_generator.py +156 -0
  97. dsgrid/registry/dataset_registry_manager.py +734 -0
  98. dsgrid/registry/dataset_update_checker.py +16 -0
  99. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  100. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  101. dsgrid/registry/dimension_registry_manager.py +413 -0
  102. dsgrid/registry/dimension_update_checker.py +16 -0
  103. dsgrid/registry/duckdb_data_store.py +185 -0
  104. dsgrid/registry/filesystem_data_store.py +141 -0
  105. dsgrid/registry/filter_registry_manager.py +123 -0
  106. dsgrid/registry/project_config_generator.py +57 -0
  107. dsgrid/registry/project_registry_manager.py +1616 -0
  108. dsgrid/registry/project_update_checker.py +48 -0
  109. dsgrid/registry/registration_context.py +223 -0
  110. dsgrid/registry/registry_auto_updater.py +316 -0
  111. dsgrid/registry/registry_database.py +662 -0
  112. dsgrid/registry/registry_interface.py +446 -0
  113. dsgrid/registry/registry_manager.py +544 -0
  114. dsgrid/registry/registry_manager_base.py +367 -0
  115. dsgrid/registry/versioning.py +92 -0
  116. dsgrid/spark/__init__.py +0 -0
  117. dsgrid/spark/functions.py +545 -0
  118. dsgrid/spark/types.py +50 -0
  119. dsgrid/tests/__init__.py +0 -0
  120. dsgrid/tests/common.py +139 -0
  121. dsgrid/tests/make_us_data_registry.py +204 -0
  122. dsgrid/tests/register_derived_datasets.py +103 -0
  123. dsgrid/tests/utils.py +25 -0
  124. dsgrid/time/__init__.py +0 -0
  125. dsgrid/time/time_conversions.py +80 -0
  126. dsgrid/time/types.py +67 -0
  127. dsgrid/units/__init__.py +0 -0
  128. dsgrid/units/constants.py +113 -0
  129. dsgrid/units/convert.py +71 -0
  130. dsgrid/units/energy.py +145 -0
  131. dsgrid/units/power.py +87 -0
  132. dsgrid/utils/__init__.py +0 -0
  133. dsgrid/utils/dataset.py +612 -0
  134. dsgrid/utils/files.py +179 -0
  135. dsgrid/utils/filters.py +125 -0
  136. dsgrid/utils/id_remappings.py +100 -0
  137. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  138. dsgrid/utils/py_expression_eval/README.md +8 -0
  139. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  140. dsgrid/utils/py_expression_eval/tests.py +283 -0
  141. dsgrid/utils/run_command.py +70 -0
  142. dsgrid/utils/scratch_dir_context.py +64 -0
  143. dsgrid/utils/spark.py +918 -0
  144. dsgrid/utils/spark_partition.py +98 -0
  145. dsgrid/utils/timing.py +239 -0
  146. dsgrid/utils/utilities.py +184 -0
  147. dsgrid/utils/versioning.py +36 -0
  148. dsgrid_toolkit-0.2.0.dist-info/METADATA +216 -0
  149. dsgrid_toolkit-0.2.0.dist-info/RECORD +152 -0
  150. dsgrid_toolkit-0.2.0.dist-info/WHEEL +4 -0
  151. dsgrid_toolkit-0.2.0.dist-info/entry_points.txt +4 -0
  152. dsgrid_toolkit-0.2.0.dist-info/licenses/LICENSE +29 -0
dsgrid/utils/dataset.py
@@ -0,0 +1,612 @@
+ import logging
+ import os
+ from pathlib import Path
+ from typing import Iterable
+
+ import chronify
+ from chronify.models import TableSchema
+
+ import dsgrid
+ from dsgrid.common import SCALING_FACTOR_COLUMN, VALUE_COLUMN
+ from dsgrid.config.dimension_mapping_base import DimensionMappingType
+ from dsgrid.config.time_dimension_base_config import TimeDimensionBaseConfig
+ from dsgrid.dataset.dataset_mapping_manager import DatasetMappingManager
+ from dsgrid.dimension.base_models import DimensionType
+ from dsgrid.dimension.time import (
+     DaylightSavingFallBackType,
+     DaylightSavingSpringForwardType,
+     TimeBasedDataAdjustmentModel,
+     TimeZone,
+ )
+ from dsgrid.exceptions import (
+     DSGInvalidField,
+     DSGInvalidDimensionMapping,
+     DSGInvalidDataset,
+ )
+ from dsgrid.spark.functions import (
+     count_distinct_on_group_by,
+     create_temp_view,
+     handle_column_spaces,
+     make_temp_view_name,
+     read_parquet,
+     is_dataframe_empty,
+     join,
+     join_multiple_columns,
+     unpivot,
+     write_csv,
+ )
+ from dsgrid.spark.functions import except_all, get_spark_session
+ from dsgrid.spark.types import (
+     DataFrame,
+     F,
+     IntegerType,
+     LongType,
+     ShortType,
+     StringType,
+     use_duckdb,
+ )
+ from dsgrid.utils.scratch_dir_context import ScratchDirContext
+ from dsgrid.utils.spark import (
+     check_for_nulls,
+ )
+ from dsgrid.utils.timing import timer_stats_collector, track_timing
+
+ logger = logging.getLogger(__name__)
+
+
+ def map_stacked_dimension(
+     df: DataFrame,
+     records: DataFrame,
+     column: str,
+     drop_column: bool = True,
+     to_column: str | None = None,
+ ) -> DataFrame:
+     to_column_ = to_column or column
+     if "fraction" not in df.columns:
+         df = df.withColumn("fraction", F.lit(1.0))
+     # map and consolidate from_fraction only
+     records = records.filter("to_id IS NOT NULL")
+     df = join(df, records, column, "from_id", how="inner").drop("from_id")
+     if drop_column:
+         df = df.drop(column)
+     df = df.withColumnRenamed("to_id", to_column_)
+     nonfraction_cols = [x for x in df.columns if x not in {"fraction", "from_fraction"}]
+     df = df.select(
+         *nonfraction_cols,
+         (F.col("fraction") * F.col("from_fraction")).alias("fraction"),
+     )
+     return df
+
+
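map_stacked_dimension consolidates mapping fractions by joining the data table to the mapping records on the dimension column. The following is a minimal PySpark sketch of the same pattern; the column names, toy records, and the direct use of DataFrame.join are assumptions for illustration, not dsgrid code.

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
data = spark.createDataFrame([("res", 10.0), ("com", 20.0)], "sector string, value double")
records = spark.createDataFrame(
    [("res", "residential", 1.0), ("com", "commercial", 1.0)],
    "from_id string, to_id string, from_fraction double",
)
# Same idea as the helper above: inner-join on the dimension column, rename to_id back
# to the dimension column, and fold from_fraction into the running row fraction.
mapped = (
    data.withColumn("fraction", F.lit(1.0))
    .join(records.filter("to_id IS NOT NULL"), data["sector"] == records["from_id"], "inner")
    .drop("from_id", "sector")
    .withColumnRenamed("to_id", "sector")
    .withColumn("fraction", F.col("fraction") * F.col("from_fraction"))
    .drop("from_fraction")
)
mapped.show()  # each row now carries the mapped sector id and a consolidated fraction
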
+ def add_time_zone(load_data_df, geography_dim):
+     """Add a time_zone column to a load_data dataframe from a geography dimension.
+
+     Parameters
+     ----------
+     load_data_df : pyspark.sql.DataFrame
+     geography_dim: DimensionConfig
+
+     Returns
+     -------
+     pyspark.sql.DataFrame
+
+     """
+     spark = get_spark_session()
+     dsg_geo_records = geography_dim.get_records_dataframe()
+     tz_map_table = spark.createDataFrame(
+         [(x.value, x.tz_name) for x in TimeZone], ("dsgrid_name", "tz_name")
+     )
+     geo_records = (
+         join(dsg_geo_records, tz_map_table, "time_zone", "dsgrid_name")
+         .drop("time_zone", "dsgrid_name")
+         .withColumnRenamed("tz_name", "time_zone")
+     )
+     assert dsg_geo_records.count() == geo_records.count()
+     geo_name = geography_dim.model.dimension_type.value
+     return add_column_from_records(load_data_df, geo_records, geo_name, "time_zone")
+
+
+ def add_column_from_records(df, dimension_records, dimension_name, column_to_add):
+     df = join(
+         df,
+         dimension_records.select(F.col("id").alias("record_id"), column_to_add),
+         dimension_name,
+         "record_id",
+         how="inner",
+     ).drop("record_id")
+     return df
+
+
+ def add_null_rows_from_load_data_lookup(df: DataFrame, lookup: DataFrame) -> DataFrame:
+     """Add null rows from the nulled load data lookup table to data table.
+
+     Parameters
+     ----------
+     df
+         load data table
+     lookup
+         load data lookup table that has been filtered for nulls.
+     """
+     if not is_dataframe_empty(lookup):
+         intersect_cols = set(lookup.columns).intersection(df.columns)
+         null_rows_to_add = except_all(lookup.select(*intersect_cols), df.select(*intersect_cols))
+         for col in set(df.columns).difference(null_rows_to_add.columns):
+             null_rows_to_add = null_rows_to_add.withColumn(col, F.lit(None))
+         df = df.union(null_rows_to_add.select(*df.columns))
+
+     return df
+
+
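add_null_rows_from_load_data_lookup appends lookup combinations that are absent from the data table, with NULLs for the data-only columns. Here is a short sketch of that pattern using plain PySpark exceptAll in place of dsgrid's except_all helper; the column names and values are hypothetical.

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
data = spark.createDataFrame([(1, "res", 10.0)], "geography int, sector string, value double")
lookup = spark.createDataFrame([(1, "res"), (2, "com")], "geography int, sector string")
shared = ["geography", "sector"]
# Rows present in the lookup but not in the data table get NULL for the data-only columns.
null_rows = (
    lookup.select(*shared)
    .exceptAll(data.select(*shared))
    .withColumn("value", F.lit(None).cast("double"))
)
combined = data.union(null_rows.select(*data.columns))
combined.show()  # includes a (2, "com", NULL) row carried over from the lookup
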
+ def apply_scaling_factor(
+     df: DataFrame,
+     value_column: str,
+     mapping_manager: DatasetMappingManager,
+     scaling_factor_column: str = SCALING_FACTOR_COLUMN,
+ ) -> DataFrame:
+     """Apply the scaling factor to all value columns and then drop the scaling factor column."""
+     op = mapping_manager.plan.apply_scaling_factor_op
+     if mapping_manager.has_completed_operation(op):
+         return df
+
+     func = _apply_scaling_factor_duckdb if use_duckdb() else _apply_scaling_factor_spark
+     df = func(df, value_column, scaling_factor_column)
+     if mapping_manager.plan.apply_scaling_factor_op.persist:
+         df = mapping_manager.persist_intermediate_table(df, op)
+     return df
+
+
+ def _apply_scaling_factor_duckdb(
+     df: DataFrame,
+     value_column: str,
+     scaling_factor_column: str,
+ ):
+     # Workaround for the fact that duckdb doesn't support
+     # F.col(scaling_factor_column).isNotNull()
+     cols = (x for x in df.columns if x not in (value_column, scaling_factor_column))
+     cols_str = ",".join(cols)
+     view = create_temp_view(df)
+     query = f"""
+         SELECT
+             {cols_str},
+             (
+                 CASE WHEN {scaling_factor_column} IS NULL THEN {value_column}
+                 ELSE {value_column} * {scaling_factor_column} END
+             ) AS {value_column}
+         FROM {view}
+     """
+     spark = get_spark_session()
+     return spark.sql(query)
+
+
+ def _apply_scaling_factor_spark(
+     df: DataFrame,
+     value_column: str,
+     scaling_factor_column: str,
+ ):
+     return df.withColumn(
+         value_column,
+         F.when(
+             F.col(scaling_factor_column).isNotNull(),
+             F.col(value_column) * F.col(scaling_factor_column),
+         ).otherwise(F.col(value_column)),
+     ).drop(scaling_factor_column)
+
+
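The two backends above express the same rule: a NULL scaling factor leaves the value unchanged, otherwise the value is multiplied. A minimal sketch of that equivalence on a toy frame (column names assumed for illustration):

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(10.0, 2.0), (10.0, None)], "value double, scaling_factor double")
# DataFrame-API formulation, as in the Spark backend.
spark_style = df.withColumn(
    "value",
    F.when(
        F.col("scaling_factor").isNotNull(), F.col("value") * F.col("scaling_factor")
    ).otherwise(F.col("value")),
).drop("scaling_factor")
# SQL CASE WHEN formulation, as in the DuckDB backend.
sql_style = df.selectExpr(
    "CASE WHEN scaling_factor IS NULL THEN value ELSE value * scaling_factor END AS value"
)
spark_style.show()  # 20.0 for the scaled row, 10.0 for the NULL-factor row
sql_style.show()    # same result
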
+ def check_historical_annual_time_model_year_consistency(
+     df: DataFrame, time_column: str, model_year_column: str
+ ) -> None:
+     """Check that the model year values match the time dimension years for a historical
+     dataset with an annual time dimension.
+     """
+     invalid = (
+         df.select(time_column, model_year_column)
+         .filter(f"{time_column} IS NOT NULL")
+         .distinct()
+         .filter(f"{time_column} != {model_year_column}")
+         .collect()
+     )
+     if invalid:
+         msg = (
+             "A historical dataset with annual time must have rows where the time years match the model years. "
+             f"{invalid}"
+         )
+         raise DSGInvalidDataset(msg)
+
+
+ @track_timing(timer_stats_collector)
+ def check_null_value_in_dimension_rows(dim_table, exclude_columns=None):
+     if os.environ.get("__DSGRID_SKIP_CHECK_NULL_DIMENSION__"):
+         # This has intermittently caused GC-related timeouts for TEMPO.
+         # Leave a backdoor to skip these checks, which may eventually be removed.
+         logger.warning("Skip check_null_value_in_dimension_rows")
+         return
+
+     try:
+         exclude = {"id"}
+         if exclude_columns is not None:
+             exclude.update(exclude_columns)
+         check_for_nulls(dim_table, exclude_columns=exclude)
+     except DSGInvalidField as exc:
+         msg = (
+             "Invalid dimension mapping application. "
+             "Combination of remapped dataset dimensions contain NULL value(s) for "
+             f"dimension(s): \n{str(exc)}"
+         )
+         raise DSGInvalidDimensionMapping(msg)
+
+
+ def handle_dimension_association_errors(
+     diff: DataFrame,
+     dataset_table: DataFrame,
+     dataset_id: str,
+ ) -> None:
+     """Record missing dimension record combinations in a CSV file and log an error."""
+     out_file = f"{dataset_id}__missing_dimension_record_combinations.csv"
+     df = diff
+     changed = False
+     for column in diff.columns:
+         if diff.select(column).distinct().count() == 1:
+             df = df.drop(column)
+             changed = True
+     if changed:
+         df = df.distinct()
+     write_csv(df, out_file, header=True, overwrite=True)
+     logger.error(
+         "Dataset %s is missing required dimension records. Recorded missing records in %s",
+         dataset_id,
+         out_file,
+     )
+     _look_for_error_contributors(df, dataset_table)
+     msg = (
+         f"Dataset {dataset_id} is missing required dimension records. "
+         "Please look in the log file for more information."
+     )
+     raise DSGInvalidDataset(msg)
+
+
+ def _look_for_error_contributors(diff: DataFrame, dataset_table: DataFrame) -> None:
+     diff_counts = {x: diff.select(x).distinct().count() for x in diff.columns}
+     for col in diff.columns:
+         dataset_count = dataset_table.select(col).distinct().count()
+         if dataset_count != diff_counts[col]:
+             logger.error(
+                 "Error contributor: column=%s dataset_distinct_count=%s missing_distinct_count=%s",
+                 col,
+                 dataset_count,
+                 diff_counts[col],
+             )
+
+
+ def is_noop_mapping(records: DataFrame) -> bool:
+     """Return True if the mapping is a no-op."""
+     return is_dataframe_empty(
+         records.filter(
+             "(to_id IS NULL and from_id IS NOT NULL) or "
+             "(to_id IS NOT NULL and from_id IS NULL) or "
+             "(from_id != to_id) or (from_fraction != 1.0)"
+         )
+     )
+
+
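is_noop_mapping declares a mapping trivial only when every record maps an id to itself with from_fraction equal to 1.0 and neither side is NULL. A small sketch with hypothetical records, reusing the exact filter predicate from the function above:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
noop = spark.createDataFrame(
    [("a", "a", 1.0), ("b", "b", 1.0)], "from_id string, to_id string, from_fraction double"
)
not_noop = spark.createDataFrame(
    [("a", "a", 0.5)], "from_id string, to_id string, from_fraction double"
)
predicate = (
    "(to_id IS NULL and from_id IS NOT NULL) or "
    "(to_id IS NOT NULL and from_id IS NULL) or "
    "(from_id != to_id) or (from_fraction != 1.0)"
)
print(noop.filter(predicate).count())      # 0 -> treated as a no-op
print(not_noop.filter(predicate).count())  # 1 -> mapping must be applied
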
+ def map_time_dimension_with_chronify_duckdb(
+     df: DataFrame,
+     value_column: str,
+     from_time_dim: TimeDimensionBaseConfig,
+     to_time_dim: TimeDimensionBaseConfig,
+     scratch_dir_context: ScratchDirContext,
+     wrap_time_allowed: bool = False,
+     time_based_data_adjustment: TimeBasedDataAdjustmentModel | None = None,
+ ) -> DataFrame:
+     """Create a time-mapped table with chronify and DuckDB.
+     All operations are performed in memory.
+     """
+     # This will only work if the source and destination tables will fit in memory.
+     # We could potentially use a file-based DuckDB database for larger-than memory datasets.
+     # However, time checks and unpivot operations have failed with out-of-memory errors,
+     # and so we have never reached this point.
+     # If we solve those problems, this code could be modified.
+     src_schema, dst_schema = _get_mapping_schemas(df, value_column, from_time_dim, to_time_dim)
+     store = chronify.Store.create_in_memory_db()
+     store.ingest_table(df.relation, src_schema, skip_time_checks=True)
+     store.map_table_time_config(
+         src_schema.name,
+         dst_schema,
+         wrap_time_allowed=wrap_time_allowed,
+         data_adjustment=_to_chronify_time_based_data_adjustment(time_based_data_adjustment),
+         scratch_dir=scratch_dir_context.scratch_dir,
+     )
+     pandas_df = store.read_table(dst_schema.name)
+     store.drop_table(dst_schema.name)
+     return df.session.createDataFrame(pandas_df)
+
+
+ def map_time_dimension_with_chronify_spark_hive(
+     df: DataFrame,
+     table_name: str,
+     value_column: str,
+     from_time_dim: TimeDimensionBaseConfig,
+     to_time_dim: TimeDimensionBaseConfig,
+     scratch_dir_context: ScratchDirContext,
+     time_based_data_adjustment: TimeBasedDataAdjustmentModel | None = None,
+     wrap_time_allowed: bool = False,
+ ) -> DataFrame:
+     """Create a time-mapped table with chronify and Spark and a Hive Metastore.
+     The source data must already be stored in the metastore.
+     Chronify will store the mapped table in the metastore.
+     """
+     src_schema, dst_schema = _get_mapping_schemas(
+         df, value_column, from_time_dim, to_time_dim, src_name=table_name
+     )
+     store = chronify.Store.create_new_hive_store(dsgrid.runtime_config.thrift_server_url)
+     with store.engine.begin() as conn:
+         # This bypasses checks because the table should already be valid.
+         store.schema_manager.add_schema(conn, src_schema)
+     try:
+         # TODO: https://github.com/NREL/chronify/issues/37
+         store.map_table_time_config(
+             src_schema.name,
+             dst_schema,
+             check_mapped_timestamps=False,
+             scratch_dir=scratch_dir_context.scratch_dir,
+             wrap_time_allowed=wrap_time_allowed,
+             data_adjustment=_to_chronify_time_based_data_adjustment(time_based_data_adjustment),
+         )
+     finally:
+         with store.engine.begin() as conn:
+             store.schema_manager.remove_schema(conn, src_schema.name)
+
+     return df.sparkSession.sql(f"SELECT * FROM {dst_schema.name}")
+
+
+ def map_time_dimension_with_chronify_spark_path(
+     df: DataFrame,
+     filename: Path,
+     value_column: str,
+     from_time_dim: TimeDimensionBaseConfig,
+     to_time_dim: TimeDimensionBaseConfig,
+     scratch_dir_context: ScratchDirContext,
+     wrap_time_allowed: bool = False,
+     time_based_data_adjustment: TimeBasedDataAdjustmentModel | None = None,
+ ) -> DataFrame:
+     """Create a time-mapped table with chronify and Spark using the local filesystem.
+     Chronify will store the mapped table in a Parquet file within scratch_dir_context.
+     """
+     src_schema, dst_schema = _get_mapping_schemas(df, value_column, from_time_dim, to_time_dim)
+     store = chronify.Store.create_new_hive_store(dsgrid.runtime_config.thrift_server_url)
+     store.create_view_from_parquet(filename, src_schema, bypass_checks=True)
+     output_file = scratch_dir_context.get_temp_filename(suffix=".parquet")
+     store.map_table_time_config(
+         src_schema.name,
+         dst_schema,
+         check_mapped_timestamps=False,
+         scratch_dir=scratch_dir_context.scratch_dir,
+         output_file=output_file,
+         wrap_time_allowed=wrap_time_allowed,
+         data_adjustment=_to_chronify_time_based_data_adjustment(time_based_data_adjustment),
+     )
+     return df.sparkSession.read.load(str(output_file))
+
+
+ def _to_chronify_time_based_data_adjustment(
+     adj: TimeBasedDataAdjustmentModel | None,
+ ) -> chronify.TimeBasedDataAdjustment | None:
+     if adj is None:
+         return None
+     if (
+         adj.daylight_saving_adjustment.spring_forward_hour == DaylightSavingSpringForwardType.NONE
+         and adj.daylight_saving_adjustment.fall_back_hour == DaylightSavingFallBackType.NONE
+     ):
+         chronify_dst_adjustment = chronify.time.DaylightSavingAdjustmentType.NONE
+     elif (
+         adj.daylight_saving_adjustment.spring_forward_hour == DaylightSavingSpringForwardType.DROP
+         and adj.daylight_saving_adjustment.fall_back_hour == DaylightSavingFallBackType.DUPLICATE
+     ):
+         chronify_dst_adjustment = (
+             chronify.time.DaylightSavingAdjustmentType.DROP_SPRING_FORWARD_DUPLICATE_FALLBACK
+         )
+     elif (
+         adj.daylight_saving_adjustment.spring_forward_hour == DaylightSavingSpringForwardType.DROP
+         and adj.daylight_saving_adjustment.fall_back_hour == DaylightSavingFallBackType.INTERPOLATE
+     ):
+         chronify_dst_adjustment = (
+             chronify.time.DaylightSavingAdjustmentType.DROP_SPRING_FORWARD_INTERPOLATE_FALLBACK
+         )
+     else:
+         msg = f"dsgrid time_based_data_adjustment = {adj}"
+         raise NotImplementedError(msg)
+
+     return chronify.TimeBasedDataAdjustment(
+         leap_day_adjustment=adj.leap_day_adjustment.value,
+         daylight_saving_adjustment=chronify_dst_adjustment,
+     )
+
+
+ def _get_mapping_schemas(
+     df: DataFrame,
+     value_column: str,
+     from_time_dim: TimeDimensionBaseConfig,
+     to_time_dim: TimeDimensionBaseConfig,
+     src_name: str | None = None,
+ ) -> tuple[TableSchema, TableSchema]:
+     src = src_name or "src_" + make_temp_view_name()
+     time_array_id_columns = [
+         x
+         for x in df.columns
+         if x
+         in set(df.columns).difference(from_time_dim.get_load_data_time_columns()) - {value_column}
+     ]
+     src_schema = chronify.TableSchema(
+         name=src,
+         time_config=from_time_dim.to_chronify(),
+         time_array_id_columns=time_array_id_columns,
+         value_column=value_column,
+     )
+     dst_schema = chronify.TableSchema(
+         name="dst_" + make_temp_view_name(),
+         time_config=to_time_dim.to_chronify(),
+         time_array_id_columns=time_array_id_columns,
+         value_column=value_column,
+     )
+     return src_schema, dst_schema
+
+
+ def ordered_subset_columns(df, subset: set[str]) -> list[str]:
+     """Return a list of columns in the dataframe that are present in subset."""
+     return [x for x in df.columns if x in subset]
+
+
+ def remove_invalid_null_timestamps(df, time_columns, stacked_columns):
+     """Remove rows from the dataframe where the time column is NULL and other rows with the
+     same dimensions contain valid data.
+     """
+     assert len(time_columns) == 1, time_columns
+     time_column = next(iter(time_columns))
+     orig_columns = df.columns
+     stacked = list(stacked_columns)
+     return (
+         join_multiple_columns(
+             df,
+             count_distinct_on_group_by(df, stacked, time_column, "count_time"),
+             stacked,
+         )
+         .filter(f"{handle_column_spaces(time_column)} IS NOT NULL OR count_time = 0")
+         .select(orig_columns)
+     )
+
+
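remove_invalid_null_timestamps keeps a NULL-time row only when its dimension combination has no real timestamps at all. A sketch of the same group-by-count logic in plain PySpark, with hypothetical rows and column names:

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("res", "2018-01-01", 1.0), ("res", None, None), ("com", None, None)],
    "sector string, timestamp string, value double",
)
# countDistinct ignores NULLs, so a combination whose rows are all NULL gets count_time = 0.
counts = df.groupBy("sector").agg(F.countDistinct("timestamp").alias("count_time"))
cleaned = (
    df.join(counts, "sector")
    .filter("timestamp IS NOT NULL OR count_time = 0")
    .select(df.columns)
)
cleaned.show()  # keeps the real "res" row and the all-NULL "com" row; drops the NULL "res" row
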
+ @track_timing(timer_stats_collector)
+ def repartition_if_needed_by_mapping(
+     df: DataFrame,
+     mapping_type: DimensionMappingType,
+     scratch_dir_context: ScratchDirContext,
+     repartition: bool | None = None,
+ ) -> tuple[DataFrame, Path | None]:
+     """Repartition the dataframe if the mapping might cause data skew.
+
+     Parameters
+     ----------
+     df : DataFrame
+         The dataframe to repartition.
+     mapping_type : DimensionMappingType
+     scratch_dir_context : ScratchDirContext
+         The scratch directory context to use for temporary files.
+     repartition : bool
+         If None, repartition based on the mapping type.
+         Otherwise, always repartition if True, or never if False.
+     """
+     if use_duckdb():
+         return df, None
+
+     # We experienced an issue with the IEF buildings dataset where the disaggregation of
+     # region to county caused a major issue where one Spark executor thread got stuck,
+     # seemingly indefinitely. A message like this was repeated continually.
+     # UnsafeExternalSorter: Thread 152 spilling sort data of 4.0 GiB to disk (0 time so far)
+     # It appears to be caused by data skew, though the imbalance didn't seem too severe.
+     # Using a variation of what online sources call a "salting technique" solves the issue.
+     # Apply the technique to mappings that will cause an explosion of rows.
+     # Note that this probably isn't needed in all cases and we may need to adjust in the
+     # future.
+
+     # Note: log messages below are checked in the tests.
+     if repartition or (
+         repartition is None
+         and mapping_type
+         in {
+             DimensionMappingType.ONE_TO_MANY_DISAGGREGATION,
+             # These cases might be problematic in the future.
+             # DimensionMappingType.ONE_TO_MANY_ASSIGNMENT,
+             # DimensionMappingType.ONE_TO_MANY_EXPLICIT_MULTIPLIERS,
+             # DimensionMappingType.MANY_TO_MANY_DISAGGREGATION,
+             # This is usually happening with scenario and hasn't caused a problem.
+             # DimensionMappingType.DUPLICATION,
+         }
+     ):
+         filename = scratch_dir_context.get_temp_filename(suffix=".parquet")
+         # Salting techniques online talk about adding or modifying a column with random values.
+         # We might be able to use one of our value columns. However, there are cases where there
+         # could be many instances of zero or null. So, add a new column with random values.
+         logger.info("Repartition after mapping %s", mapping_type)
+         salted_column = "salted_key"
+         spark = get_spark_session()
+         num_partitions = int(spark.conf.get("spark.sql.shuffle.partitions"))
+         df.withColumn(
+             salted_column, (F.rand() * num_partitions).cast(IntegerType()) + 1
+         ).repartition(salted_column).write.parquet(str(filename))
+         df = read_parquet(filename).drop(salted_column)
+         logger.info("Completed repartition.")
+         return df, filename
+
+     logger.debug("Repartition is not needed for mapping_type %s", mapping_type)
+     return df, None
+
+
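The comments above describe a salting technique for skew. The core of that idea, stripped of dsgrid's persistence step, looks like the following sketch; the skewed "county" column and row count are invented for demonstration.

from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.getOrCreate()
df = spark.range(1_000_000).withColumn("county", F.lit("06037"))  # a heavily skewed key
num_partitions = int(spark.conf.get("spark.sql.shuffle.partitions"))
# Add a random integer key and repartition on it so rows sharing the skewed value are
# spread across partitions before an expensive downstream operation.
salted = df.withColumn("salted_key", (F.rand() * num_partitions).cast(IntegerType()) + 1)
repartitioned = salted.repartition("salted_key").drop("salted_key")
print(repartitioned.rdd.getNumPartitions())
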
+ def unpivot_dataframe(
+     df: DataFrame,
+     value_columns: Iterable[str],
+     variable_column: str,
+     time_columns: list[str],
+ ) -> DataFrame:
+     """Unpivot the dataframe, accounting for time columns."""
+     values = value_columns if isinstance(value_columns, set) else set(value_columns)
+     ids = [x for x in df.columns if x != VALUE_COLUMN and x not in values]
+     df = unpivot(df, value_columns, variable_column, VALUE_COLUMN)
+     cols = set(df.columns).difference(time_columns)
+     new_rows = df.filter(f"{VALUE_COLUMN} IS NULL").select(*cols).distinct()
+     for col in time_columns:
+         new_rows = new_rows.withColumn(col, F.lit(None))
+
+     return (
+         df.filter(f"{VALUE_COLUMN} IS NOT NULL")
+         .union(new_rows.select(*df.columns))
+         .select(*ids, variable_column, VALUE_COLUMN)
+     )
+
+
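Unpivoting turns one wide row with several value columns into one long row per value column, which is what the helper above does before it collapses NULL values. A sketch of the basic wide-to-long step using Spark SQL's stack expression rather than dsgrid's unpivot helper; the end-use column names are hypothetical.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
wide = spark.createDataFrame(
    [("2018-01-01", 1.0, 2.0)], "timestamp string, electricity double, natural_gas double"
)
long = wide.selectExpr(
    "timestamp",
    "stack(2, 'electricity', electricity, 'natural_gas', natural_gas) AS (end_use, value)",
)
long.show()  # two rows: (timestamp, electricity, 1.0) and (timestamp, natural_gas, 2.0)
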
+ def convert_types_if_necessary(df: DataFrame) -> DataFrame:
+     """Convert the types of the dataframe if necessary."""
+     allowed_int_columns = (
+         DimensionType.MODEL_YEAR.value,
+         DimensionType.WEATHER_YEAR.value,
+     )
+     int_types = {IntegerType(), LongType(), ShortType()}
+     existing_columns = set(df.columns)
+     for column in allowed_int_columns:
+         if column in existing_columns and df.schema[column].dataType in int_types:
+             df = df.withColumn(column, F.col(column).cast(StringType()))
+     return df
+
+
+ def filter_out_expected_missing_associations(
+     main_df: DataFrame, missing_df: DataFrame
+ ) -> DataFrame:
+     """Filter out rows that are expected to be missing from the main dataframe."""
+     missing_columns = [DimensionType.from_column(x).value for x in missing_df.columns]
+     spark = get_spark_session()
+     main_view = make_temp_view_name()
+     assoc_view = make_temp_view_name()
+     main_columns = ",".join((f"{main_view}.{x}" for x in main_df.columns))
+
+     main_df.createOrReplaceTempView(main_view)
+     missing_df.createOrReplaceTempView(assoc_view)
+     join_str = " AND ".join((f"{main_view}.{x} = {assoc_view}.{x}" for x in missing_columns))
+     query = f"""
+         SELECT {main_columns}
+         FROM {main_view}
+         ANTI JOIN {assoc_view}
+         ON {join_str}
+     """
+     res = spark.sql(query)
+     return res
+
+
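The anti join above keeps only main-table rows whose dimension combination does not appear in the expected-missing table. The same semantics via the DataFrame API, with hypothetical columns and data:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
main = spark.createDataFrame(
    [("06037", "res", 1.0), ("06037", "com", 2.0)],
    "geography string, sector string, value double",
)
expected_missing = spark.createDataFrame([("06037", "com")], "geography string, sector string")
# left_anti keeps rows of main with no match in expected_missing on the join columns.
kept = main.join(expected_missing, on=["geography", "sector"], how="left_anti")
kept.show()  # only the ("06037", "res") row remains
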
+ def split_expected_missing_rows(
+     df: DataFrame, time_columns: list[str]
+ ) -> tuple[DataFrame, DataFrame | None]:
+     """Split a DataFrame into two if it contains expected missing data."""
+     null_df = df.filter(f"{VALUE_COLUMN} IS NULL")
+     if is_dataframe_empty(null_df):
+         return df, None
+
+     drop_columns = time_columns + [VALUE_COLUMN]
+     missing_associations = null_df.drop(*drop_columns)
+     return df.filter(f"{VALUE_COLUMN} IS NOT NULL"), missing_associations
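split_expected_missing_rows separates real data from NULL-value placeholder rows, whose remaining dimension columns become the expected-missing associations. A toy sketch of that split, with invented rows and column names:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("06037", "res", "2018-01-01", 1.0), ("06037", "com", None, None)],
    "geography string, sector string, timestamp string, value double",
)
data = df.filter("value IS NOT NULL")
missing = df.filter("value IS NULL").drop("timestamp", "value")
data.show()     # the one real data row
missing.show()  # ("06037", "com") recorded as an expected-missing association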