tencent-wedata-feature-engineering-dev 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tencent-wedata-feature-engineering-dev might be problematic.
- tencent_wedata_feature_engineering_dev-0.1.0.dist-info/METADATA +19 -0
- tencent_wedata_feature_engineering_dev-0.1.0.dist-info/RECORD +64 -0
- tencent_wedata_feature_engineering_dev-0.1.0.dist-info/WHEEL +5 -0
- tencent_wedata_feature_engineering_dev-0.1.0.dist-info/top_level.txt +1 -0
- wedata/__init__.py +9 -0
- wedata/feature_store/__init__.py +0 -0
- wedata/feature_store/client.py +462 -0
- wedata/feature_store/cloud_sdk_client/__init__.py +0 -0
- wedata/feature_store/cloud_sdk_client/client.py +86 -0
- wedata/feature_store/cloud_sdk_client/models.py +686 -0
- wedata/feature_store/cloud_sdk_client/utils.py +32 -0
- wedata/feature_store/common/__init__.py +0 -0
- wedata/feature_store/common/protos/__init__.py +0 -0
- wedata/feature_store/common/protos/feature_store_pb2.py +49 -0
- wedata/feature_store/common/store_config/__init__.py +0 -0
- wedata/feature_store/common/store_config/redis.py +48 -0
- wedata/feature_store/constants/__init__.py +0 -0
- wedata/feature_store/constants/constants.py +59 -0
- wedata/feature_store/constants/engine_types.py +34 -0
- wedata/feature_store/entities/__init__.py +0 -0
- wedata/feature_store/entities/column_info.py +138 -0
- wedata/feature_store/entities/environment_variables.py +55 -0
- wedata/feature_store/entities/feature.py +53 -0
- wedata/feature_store/entities/feature_column_info.py +72 -0
- wedata/feature_store/entities/feature_function.py +55 -0
- wedata/feature_store/entities/feature_lookup.py +200 -0
- wedata/feature_store/entities/feature_spec.py +489 -0
- wedata/feature_store/entities/feature_spec_constants.py +25 -0
- wedata/feature_store/entities/feature_table.py +111 -0
- wedata/feature_store/entities/feature_table_info.py +49 -0
- wedata/feature_store/entities/function_info.py +90 -0
- wedata/feature_store/entities/on_demand_column_info.py +57 -0
- wedata/feature_store/entities/source_data_column_info.py +24 -0
- wedata/feature_store/entities/training_set.py +135 -0
- wedata/feature_store/feast_client/__init__.py +0 -0
- wedata/feature_store/feast_client/feast_client.py +482 -0
- wedata/feature_store/feature_table_client/__init__.py +0 -0
- wedata/feature_store/feature_table_client/feature_table_client.py +969 -0
- wedata/feature_store/mlflow_model.py +17 -0
- wedata/feature_store/spark_client/__init__.py +0 -0
- wedata/feature_store/spark_client/spark_client.py +289 -0
- wedata/feature_store/training_set_client/__init__.py +0 -0
- wedata/feature_store/training_set_client/training_set_client.py +572 -0
- wedata/feature_store/utils/__init__.py +0 -0
- wedata/feature_store/utils/common_utils.py +352 -0
- wedata/feature_store/utils/env_utils.py +86 -0
- wedata/feature_store/utils/feature_lookup_utils.py +564 -0
- wedata/feature_store/utils/feature_spec_utils.py +286 -0
- wedata/feature_store/utils/feature_utils.py +73 -0
- wedata/feature_store/utils/on_demand_utils.py +107 -0
- wedata/feature_store/utils/schema_utils.py +117 -0
- wedata/feature_store/utils/signature_utils.py +202 -0
- wedata/feature_store/utils/topological_sort.py +158 -0
- wedata/feature_store/utils/training_set_utils.py +579 -0
- wedata/feature_store/utils/uc_utils.py +296 -0
- wedata/feature_store/utils/validation_utils.py +79 -0
- wedata/tempo/__init__.py +0 -0
- wedata/tempo/interpol.py +448 -0
- wedata/tempo/intervals.py +1331 -0
- wedata/tempo/io.py +61 -0
- wedata/tempo/ml.py +129 -0
- wedata/tempo/resample.py +318 -0
- wedata/tempo/tsdf.py +1720 -0
- wedata/tempo/utils.py +254 -0
wedata/tempo/tsdf.py
ADDED
@@ -0,0 +1,1720 @@
from __future__ import annotations

import logging
import operator
from abc import ABCMeta, abstractmethod
from typing import Any, Callable, List, Optional, Sequence, TypeVar, Union

import numpy as np
import pandas as pd
import pyspark.sql.functions as sfn
from IPython.core.display import HTML  # type: ignore
from IPython.display import display as ipydisplay  # type: ignore
from pyspark.sql import SparkSession
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import StringType, TimestampType
from pyspark.sql.window import Window, WindowSpec
from scipy.fft import fft, fftfreq

import wedata.tempo.interpol as t_interpolation
import wedata.tempo.io as t_io
import wedata.tempo.resample as t_resample
import wedata.tempo.utils as t_utils

logger = logging.getLogger(__name__)


class TSDF:
    """
    This object is the main wrapper over a Spark data frame which allows a user to parallelize time series computations on a Spark data frame by various dimensions. The two dimensions required are partition_cols (list of columns by which to summarize) and ts_col (timestamp column, which can be epoch or TimestampType).
    """

    summarizable_types = ["int", "bigint", "float", "double"]

    def __init__(
        self,
        df: DataFrame,
        ts_col: str = "event_ts",
        partition_cols: Optional[list[str]] = None,
        sequence_col: Optional[str] = None,
    ):
        """
        Constructor
        :param df:
        :param ts_col:
        :param partition_cols:
        :param sequence_col: every tsdf allows for a tie-breaker secondary sort key
        """
        self.ts_col = self.__validated_column(df, ts_col)
        self.partitionCols = (
            []
            if partition_cols is None
            else self.__validated_columns(df, partition_cols.copy())
        )

        self.df = df
        self.sequence_col = "" if sequence_col is None else sequence_col

        # Add customized check for string type for the timestamp.
        # If we see a string, we will proactively create a double
        # version of the string timestamp for sorting purposes and
        # rename to ts_col

        # TODO: we validate the string is of a specific format. Spark will
        # convert a valid formatted timestamp string to timestamp type so
        # this if clause seems unneeded. Perhaps we should check for non-valid
        # Timestamp string matching then do some pattern matching to extract
        # the timestamp.
        if isinstance(df.schema[ts_col].dataType, StringType):  # pragma: no cover
            sample_ts = df.select(ts_col).limit(1).head(1)[0][0]
            self.__validate_ts_string(sample_ts)
            self.df = (
                self.__add_double_ts()
                .drop(self.ts_col)
                .withColumnRenamed("double_ts", self.ts_col)
            )

        """
        Make sure DF is ordered by its respective ts_col and partition columns.
        """

    #
    # Helper functions
    #

    @staticmethod
    def parse_nanos_timestamp(
        df: DataFrame,
        str_ts_col: str,
        ts_fmt: str = "yyyy-MM-dd HH:mm:ss",
        double_ts_col: Optional[str] = None,
        parsed_ts_col: Optional[str] = None,
    ) -> DataFrame:
        """
        Parse a string timestamp column with nanosecond precision into a double timestamp column.

        :param df: DataFrame containing the string timestamp column
        :param str_ts_col: Name of the string timestamp column
        :param ts_fmt: Format of the string timestamp column (default: "yyyy-MM-dd HH:mm:ss")
        :param double_ts_col: Name of the double timestamp column to create, if None
            the source string column will be overwritten
        :param parsed_ts_col: Name of the parsed timestamp column to create, if None
            no parsed timestamp column will be kept

        :return: DataFrame with the double timestamp column
        """

        # add a parsed timestamp column if requested
        src_df = (
            df.withColumn(parsed_ts_col, sfn.to_timestamp(sfn.col(str_ts_col), ts_fmt))
            if parsed_ts_col
            else df
        )

        return (
            src_df.withColumn(
                "nanos",
                sfn.when(
                    sfn.col(str_ts_col).contains("."),
                    sfn.concat(sfn.lit("0."), sfn.split(sfn.col(str_ts_col), r"\.")[1]),
                )
                .otherwise(0)
                .cast("double"),
            )
            .withColumn("long_ts", sfn.unix_timestamp(str_ts_col, ts_fmt))
            .withColumn(
                (double_ts_col or str_ts_col), sfn.col("long_ts") + sfn.col("nanos")
            )
        )
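
A hedged usage sketch (not part of the file): parsing a string timestamp with sub-second precision via the static helper above, assuming a DataFrame raw_df with a string column "event_ts".

parsed_df = TSDF.parse_nanos_timestamp(
    raw_df,
    str_ts_col="event_ts",
    double_ts_col="event_ts_dbl",     # keep the string column, add a sortable double
    parsed_ts_col="event_ts_parsed",  # also keep a TimestampType copy
)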

    def __add_double_ts(self) -> DataFrame:
        """Add a double (epoch) version of the string timestamp out to nanos"""
        return (
            self.df.withColumn(
                "nanos",
                (
                    sfn.when(
                        sfn.col(self.ts_col).contains("."),
                        sfn.concat(
                            sfn.lit("0."),
                            sfn.split(sfn.col(self.ts_col), r"\.")[1],
                        ),
                    ).otherwise(0)
                ).cast("double"),
            )
            .withColumn("long_ts", sfn.col(self.ts_col).cast("timestamp").cast("long"))
            .withColumn("double_ts", sfn.col("long_ts") + sfn.col("nanos"))
            .drop("nanos")
            .drop("long_ts")
        )

    @staticmethod
    def __validate_ts_string(ts_text: str) -> None:
        """Validate the format for the string using Regex matching for ts_string"""
        import re

        ts_pattern = r"^(\d{4}-\d{2}-\d{2}[T| ]\d{2}:\d{2}:\d{2})(\.\d+)?$"
        if re.match(ts_pattern, ts_text) is None:
            raise ValueError(
                "Incorrect data format, should be YYYY-MM-DD HH:MM:SS[.nnnnnnnn]"
            )

    @staticmethod
    def __validated_column(df: DataFrame, colname: str) -> str:
        if not isinstance(colname, str):
            raise TypeError(
                f"Column names must be of type str; found {type(colname)} instead!"
            )
        if colname.lower() not in [col.lower() for col in df.columns]:
            raise ValueError(f"Column {colname} not found in Dataframe")
        return colname

    def __validated_columns(
        self, df: DataFrame, colnames: Optional[Union[str, List[str]]]
    ) -> List[str]:
        # if provided a string, treat it as a single column
        if isinstance(colnames, str):
            colnames = [colnames]
        # otherwise we really should have a list or None
        elif colnames is None:
            colnames = []
        elif not isinstance(colnames, list):
            raise TypeError(
                f"Columns must be of type list, str, or None; found {type(colnames)} instead!"
            )
        # validate each column
        for col in colnames:
            self.__validated_column(df, col)
        return colnames

    def __checkPartitionCols(self, tsdf_right: "TSDF") -> None:
        for left_col, right_col in zip(self.partitionCols, tsdf_right.partitionCols):
            if left_col != right_col:
                raise ValueError(
                    "left and right dataframe partition columns should have same name in same order"
                )

    def __validateTsColMatch(self, right_tsdf: "TSDF") -> None:
        left_ts_datatype = self.df.select(self.ts_col).dtypes[0][1]
        right_ts_datatype = right_tsdf.df.select(self.ts_col).dtypes[0][1]
        if left_ts_datatype != right_ts_datatype:
            raise ValueError(
                "left and right dataframe timestamp index columns should have same type"
            )

    def __addPrefixToColumns(self, col_list: list[str], prefix: str) -> "TSDF":
        """
        Add prefix to all specified columns.
        """
        # no-op if no prefix
        if not prefix:
            return self

        # build a column rename map
        col_map = {col: "_".join([prefix, col]) for col in col_list}
        # TODO - In the future (when Spark 3.4+ is standard) we should implement batch rename using:
        # df = self.df.withColumnsRenamed(col_map)

        # build a list of column expressions to rename columns in a select
        select_exprs = [
            sfn.col(col).alias(col_map[col]) if col in col_map else sfn.col(col)
            for col in self.df.columns
        ]
        # select the renamed columns
        renamed_df = self.df.select(*select_exprs)

        # find the structural columns
        ts_col = col_map.get(self.ts_col, self.ts_col)
        partition_cols = [col_map.get(c, c) for c in self.partitionCols]
        sequence_col = col_map.get(self.sequence_col, self.sequence_col)
        return TSDF(renamed_df, ts_col, partition_cols, sequence_col=sequence_col)

    def __addColumnsFromOtherDF(self, other_cols: Sequence[str]) -> "TSDF":
        """
        Add columns from some other DF as lit(None), as pre-step before union.
        """

        # build a list of column expressions to rename columns in a select
        current_cols = [sfn.col(col) for col in self.df.columns]
        new_cols = [sfn.lit(None).alias(col) for col in other_cols]
        new_df = self.df.select(current_cols + new_cols)

        return TSDF(new_df, self.ts_col, self.partitionCols)

    def __combineTSDF(self, ts_df_right: "TSDF", combined_ts_col: str) -> "TSDF":
        combined_df = self.df.unionByName(ts_df_right.df).withColumn(
            combined_ts_col, sfn.coalesce(self.ts_col, ts_df_right.ts_col)
        )

        return TSDF(combined_df, combined_ts_col, self.partitionCols)

    def __getLastRightRow(
        self,
        left_ts_col: str,
        right_cols: list[str],
        sequence_col: str,
        tsPartitionVal: Optional[int],
        ignoreNulls: bool,
        suppress_null_warning: bool,
    ) -> "TSDF":
        """Get last right value of each right column (inc. right timestamp) for each self.ts_col value

        self.ts_col, which is the combined time-stamp column of both left and right dataframe, is dropped at the end
        since it is no longer used in subsequent methods.
        """
        ptntl_sort_keys = [self.ts_col, "rec_ind"]
        if sequence_col:
            ptntl_sort_keys.append(sequence_col)

        sort_keys = [
            sfn.col(col_name) for col_name in ptntl_sort_keys if col_name != ""
        ]

        window_spec = (
            Window.partitionBy(self.partitionCols)
            .orderBy(sort_keys)
            .rowsBetween(Window.unboundedPreceding, Window.currentRow)
        )

        # generate expressions to find the last value of each right-hand column
        if ignoreNulls is False:
            if tsPartitionVal is not None:
                raise ValueError(
                    "Disabling null skipping with a partition value is not supported yet."
                )
            mod_right_cols = [
                sfn.last(
                    sfn.when(sfn.col("rec_ind") == -1, sfn.struct(col)).otherwise(None),
                    True,
                )
                .over(window_spec)[col]
                .alias(col)
                for col in right_cols
            ]
        elif tsPartitionVal is None:
            mod_right_cols = [
                sfn.last(col, ignoreNulls).over(window_spec).alias(col)
                for col in right_cols
            ]
        else:
            mod_right_cols = [
                sfn.last(col, ignoreNulls).over(window_spec).alias(col)
                for col in right_cols
            ]
            # non-null count columns, these will be dropped below
            mod_right_cols += [
                sfn.count(col).over(window_spec).alias("non_null_ct" + col)
                for col in right_cols
            ]

        # select the left-hand side columns, and the modified right-hand side columns
        non_right_cols = list(set(self.df.columns) - set(right_cols))
        df = self.df.select(non_right_cols + mod_right_cols)
        # drop the null left-hand side rows
        df = (df.filter(sfn.col(left_ts_col).isNotNull()).drop(self.ts_col)).drop(
            "rec_ind"
        )

        # remove the null_ct stats used to record missing values in partitioned as of join
        if tsPartitionVal is not None:
            for column in df.columns:
                if column.startswith("non_null"):
                    # Avoid collect() calls when explicitly ignoring the warnings about null values due to lookback
                    # window. if setting suppress_null_warning to True and warning logger is enabled for other part
                    # of the code, it would make sense to not log warning in this function while allowing other part
                    # of the code to continue to log warning. So it makes more sense for and than or on this line
                    if not suppress_null_warning and logger.isEnabledFor(
                        logging.WARNING
                    ):
                        any_blank_vals = df.agg({column: "min"}).head(1)[0][0] == 0
                        newCol = column.replace("non_null_ct", "")
                        if any_blank_vals:
                            logger.warning(
                                "Column "
                                + newCol
                                + " had no values within the lookback window. Consider using a larger window to avoid missing values. If this is the first record in the data frame, this warning can be ignored."
                            )
                    df = df.drop(column)

        return TSDF(df, left_ts_col, self.partitionCols)

    def __getTimePartitions(self, tsPartitionVal: int, fraction: float = 0.1) -> "TSDF":
        """
        Create time-partitions for our data-set. We put our time-stamps into brackets of <tsPartitionVal>. Timestamps
        are rounded down to the nearest <tsPartitionVal> seconds.

        We cast our timestamp column to double instead of using f.unix_timestamp, since it provides more precision.

        Additionally, we make these partitions overlapping by adding a remainder df. This way when calculating the
        last right timestamp we will not end up with nulls for the first left timestamp in each partition.

        TODO: change ts_partition to accommodate for higher precision than seconds.
        """
        partition_df = (
            self.df.withColumn(
                "ts_col_double", sfn.col(self.ts_col).cast("double")
            )  # double is preferred over unix_timestamp
            .withColumn(
                "ts_partition",
                sfn.lit(tsPartitionVal)
                * (sfn.col("ts_col_double") / sfn.lit(tsPartitionVal)).cast("integer"),
            )
            .withColumn(
                "partition_remainder",
                (sfn.col("ts_col_double") - sfn.col("ts_partition"))
                / sfn.lit(tsPartitionVal),
            )
            .withColumn("is_original", sfn.lit(1))
        ).cache()  # cache it because it's used twice.

        # add [1 - fraction] of previous time partition to the next partition.
        remainder_df = (
            partition_df.filter(sfn.col("partition_remainder") >= sfn.lit(1 - fraction))
            .withColumn(
                "ts_partition", sfn.col("ts_partition") + sfn.lit(tsPartitionVal)
            )
            .withColumn("is_original", sfn.lit(0))
        )

        df = partition_df.union(remainder_df).drop(
            "partition_remainder", "ts_col_double"
        )
        return TSDF(df, self.ts_col, self.partitionCols + ["ts_partition"])

    #
    # Slicing & Selection
    #

    def select(self, *cols: Union[str, List[str]]) -> "TSDF":
        """
        pyspark.sql.DataFrame.select() method's equivalent for TSDF objects

        :param cols: str or list of strs column names (string). If one of the column names is '*', that
            column is expanded to include all columns in the current :class:`TSDF`.

        ## Examples
        .. code-block:: python
            tsdf.select('*').collect()
            [Row(age=2, name='Alice'), Row(age=5, name='Bob')]
            tsdf.select('name', 'age').collect()
            [Row(name='Alice', age=2), Row(name='Bob', age=5)]

        """

        # The columns which will be a mandatory requirement while selecting from TSDFs
        seq_col_stub = [] if bool(self.sequence_col) is False else [self.sequence_col]
        mandatory_cols = [self.ts_col] + self.partitionCols + seq_col_stub
        if set(mandatory_cols).issubset(set(cols)):
            return TSDF(
                self.df.select(*cols),
                self.ts_col,
                self.partitionCols,
                self.sequence_col,
            )
        else:
            raise Exception(
                "In TSDF's select statement original ts_col, partitionCols and seq_col_stub(optional) must be present"
            )

    def __slice(self, op: str, target_ts: Union[str, int]) -> "TSDF":
        """
        Private method to slice TSDF by time

        :param op: string symbol of the operation to perform
        :type op: str
        :param target_ts: timestamp on which to filter

        :return: a TSDF object containing only those records within the time slice specified
        """
        # quote our timestamp if its a string
        target_expr = f"'{target_ts}'" if isinstance(target_ts, str) else target_ts
        slice_expr = sfn.expr(f"{self.ts_col} {op} {target_expr}")
        sliced_df = self.df.where(slice_expr)
        return TSDF(
            sliced_df,
            ts_col=self.ts_col,
            partition_cols=self.partitionCols,
            sequence_col=self.sequence_col,
        )

    def at(self, ts: Union[str, int]) -> "TSDF":
        """
        Select only records at a given time

        :param ts: timestamp of the records to select

        :return: a :class:`~tsdf.TSDF` object containing just the records at the given time
        """
        return self.__slice("==", ts)

    def before(self, ts: Union[str, int]) -> "TSDF":
        """
        Select only records before a given time

        :param ts: timestamp on which to filter records

        :return: a :class:`~tsdf.TSDF` object containing just the records before the given time
        """
        return self.__slice("<", ts)

    def atOrBefore(self, ts: Union[str, int]) -> "TSDF":
        """
        Select only records at or before a given time

        :param ts: timestamp on which to filter records

        :return: a :class:`~tsdf.TSDF` object containing just the records at or before the given time
        """
        return self.__slice("<=", ts)

    def after(self, ts: Union[str, int]) -> "TSDF":
        """
        Select only records after a given time

        :param ts: timestamp on which to filter records

        :return: a :class:`~tsdf.TSDF` object containing just the records after the given time
        """
        return self.__slice(">", ts)

    def atOrAfter(self, ts: Union[str, int]) -> "TSDF":
        """
        Select only records at or after a given time

        :param ts: timestamp on which to filter records

        :return: a :class:`~tsdf.TSDF` object containing just the records at or after the given time
        """
        return self.__slice(">=", ts)

    def between(
        self, start_ts: Union[str, int], end_ts: Union[str, int], inclusive: bool = True
    ) -> "TSDF":
        """
        Select only records in a given range

        :param start_ts: starting time of the range to select
        :param end_ts: ending time of the range to select
        :param inclusive: whether the range is inclusive of the endpoints or not, defaults to True
        :type inclusive: bool

        :return: a :class:`~tsdf.TSDF` object containing just the records within the range specified
        """
        if inclusive:
            return self.atOrAfter(start_ts).atOrBefore(end_ts)
        return self.after(start_ts).before(end_ts)
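
Illustrative sketch only (not part of the file), continuing the hypothetical trades_tsdf from the earlier example: the slicing helpers above compose and each returns a new TSDF.

morning = trades_tsdf.between("2024-01-02 09:30:00", "2024-01-02 12:00:00")
up_to_open = trades_tsdf.atOrBefore("2024-01-02 09:30:00")
after_noon = trades_tsdf.after("2024-01-02 12:00:00")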

    def __top_rows_per_series(self, win: WindowSpec, n: int) -> "TSDF":
        """
        Private method to select just the top n rows per series (as defined by a window ordering)

        :param win: the window on which we order the rows in each series
        :param n: the number of rows to return

        :return: a :class:`~tsdf.TSDF` object containing just the top n rows in each series
        """
        row_num_col = "__row_num"
        prev_records_df = (
            self.df.withColumn(row_num_col, sfn.row_number().over(win))
            .where(sfn.col(row_num_col) <= sfn.lit(n))
            .drop(row_num_col)
        )
        return TSDF(
            prev_records_df,
            ts_col=self.ts_col,
            partition_cols=self.partitionCols,
            sequence_col=self.sequence_col,
        )

    def earliest(self, n: int = 1) -> "TSDF":
        """
        Select the earliest n records for each series

        :param n: number of records to select (default is 1)

        :return: a :class:`~tsdf.TSDF` object containing the earliest n records for each series
        """
        prev_window = self.__baseWindow(reverse=False)
        return self.__top_rows_per_series(prev_window, n)

    def latest(self, n: int = 1) -> "TSDF":
        """
        Select the latest n records for each series

        :param n: number of records to select (default is 1)

        :return: a :class:`~tsdf.TSDF` object containing the latest n records for each series
        """
        next_window = self.__baseWindow(reverse=True)
        return self.__top_rows_per_series(next_window, n)

    def priorTo(self, ts: Union[str, int], n: int = 1) -> "TSDF":
        """
        Select the n most recent records prior to a given time
        You can think of this like an 'asOf' select - it selects the records as of a particular time

        :param ts: timestamp on which to filter records
        :param n: number of records to select (default is 1)

        :return: a :class:`~tsdf.TSDF` object containing the n records prior to the given time
        """
        return self.atOrBefore(ts).latest(n)

    def subsequentTo(self, ts: Union[str, int], n: int = 1) -> "TSDF":
        """
        Select the n records subsequent to a given time

        :param ts: timestamp on which to filter records
        :param n: number of records to select (default is 1)

        :return: a :class:`~tsdf.TSDF` object containing the n records subsequent to the given time
        """
        return self.atOrAfter(ts).earliest(n)

    #
    # Display functions
    #

    def show(
        self, n: int = 20, k: int = 5, truncate: bool = True, vertical: bool = False
    ) -> None:
        """
        pyspark.sql.DataFrame.show() method's equivalent for TSDF objects

        :param n: Number of rows to show. (default: 20)
        :param truncate: If set to True, truncate strings longer than 20 chars by default.
            If set to a number greater than one, truncates long strings to length truncate
            and align cells right.
        :param vertical: If set to True, print output rows vertically (one line per column value).

        ## Example to show usage:
        .. code-block:: python
            from pyspark.sql.functions import *

            phone_accel_df = spark.read.format("csv").option("header", "true").load("dbfs:/home/tempo/Phones_accelerometer") \n
                .withColumn("event_ts", (col("Arrival_Time").cast("double")/1000).cast("timestamp")) \n
                .withColumn("x", col("x").cast("double")) \n
                .withColumn("y", col("y").cast("double")) \n
                .withColumn("z", col("z").cast("double")) \n
                .withColumn("event_ts_dbl", col("event_ts").cast("double"))

            from tempo import *

            phone_accel_tsdf = TSDF(phone_accel_df, ts_col="event_ts", partition_cols = ["User"])

            # Call show method here
            phone_accel_tsdf.show()
        """

        # validate k <= n
        if k > n:
            raise ValueError(f"Parameter k {k} cannot be greater than parameter n {n}")

        if not t_utils.IS_WEDATA and t_utils.ENV_CAN_RENDER_HTML:
            # In Jupyter notebooks, for wide dataframes the below line will enable
            # rendering the output in a scrollable format.
            ipydisplay(
                HTML("<style>pre { white-space: pre !important; }</style>")
            )  # pragma: no cover
        t_utils.get_display_df(self, k).show(n, truncate, vertical)

    def describe(self) -> DataFrame:
        """
        Describe a TSDF object using a global summary across all time series (anywhere from 10 to millions) as well as the standard Spark data frame stats. Missing vals
        Summary
        global - unique time series based on partition columns, min/max times, granularity - lowest precision in the time series timestamp column
        count / mean / stddev / min / max - standard Spark data frame describe() output
        missing_vals_pct - percentage (from 0 to 100) of missing values.
        """
        # extract the double version of the timestamp column to summarize
        double_ts_col = self.ts_col + "_dbl"

        this_df = self.df.withColumn(double_ts_col, sfn.col(self.ts_col).cast("double"))

        # summary missing value percentages
        missing_vals = this_df.select(
            [
                (
                    100
                    * sfn.count(sfn.when(sfn.col(c[0]).isNull(), c[0]))
                    / sfn.count(sfn.lit(1))
                ).alias(c[0])
                for c in this_df.dtypes
                if c[1] != "timestamp"
            ]
        ).select(sfn.lit("missing_vals_pct").alias("summary"), "*")

        # describe stats
        desc_stats = this_df.describe().union(missing_vals)
        unique_ts = this_df.select(*self.partitionCols).distinct().count()

        max_ts = this_df.select(sfn.max(sfn.col(self.ts_col)).alias("max_ts")).head(1)[
            0
        ][0]
        min_ts = this_df.select(sfn.min(sfn.col(self.ts_col)).alias("max_ts")).head(1)[
            0
        ][0]
        gran = this_df.selectExpr(
            """min(case when {0} - cast({0} as integer) > 0 then '1-millis'
            when {0} % 60 != 0 then '2-seconds'
            when {0} % 3600 != 0 then '3-minutes'
            when {0} % 86400 != 0 then '4-hours'
            else '5-days' end) granularity""".format(
                double_ts_col
            )
        ).head(1)[0][0][2:]

        non_summary_cols = [c for c in desc_stats.columns if c != "summary"]

        desc_stats = desc_stats.select(
            sfn.col("summary"),
            sfn.lit(" ").alias("unique_ts_count"),
            sfn.lit(" ").alias("min_ts"),
            sfn.lit(" ").alias("max_ts"),
            sfn.lit(" ").alias("granularity"),
            *non_summary_cols,
        )

        # add in single record with global summary attributes and the previously computed missing value and Spark data frame describe stats
        global_smry_rec = desc_stats.limit(1).select(
            sfn.lit("global").alias("summary"),
            sfn.lit(unique_ts).alias("unique_ts_count"),
            sfn.lit(min_ts).alias("min_ts"),
            sfn.lit(max_ts).alias("max_ts"),
            sfn.lit(gran).alias("granularity"),
            *[sfn.lit(" ").alias(c) for c in non_summary_cols],
        )

        full_smry = global_smry_rec.union(desc_stats)
        full_smry = full_smry.withColumnRenamed(
            "unique_ts_count", "unique_time_series_count"
        )

        try:  # pragma: no cover
            dbutils.fs.ls("/")  # type: ignore
            return full_smry
        # TODO: Can we raise something other than generic Exception?
        # perhaps refactor to check for IS_WEDATA
        except Exception:
            return full_smry

    def __getSparkPlan(self, df: DataFrame, spark: SparkSession) -> str:
        """
        Internal helper function to obtain the Spark plan for the input data frame

        Parameters
        :param df - input Spark data frame - the AS OF join has 2 data frames; this will be called for each
        :param spark - Spark session which is used to query the view obtained from the Spark data frame
        """

        df.createOrReplaceTempView("view")
        plan = spark.sql("explain cost select * from view").head(1)[0][0]

        return plan

    def __getBytesFromPlan(self, df: DataFrame, spark: SparkSession) -> float:
        """
        Internal helper function to obtain how many bytes in memory the Spark data
        frame is likely to take up. This is an upper bound and is obtained from the
        plan details in Spark

        Parameters
        :param df - input Spark data frame - the AS OF join has 2 data frames; this will be called for each
        :param spark - Spark session which is used to query the view obtained from the Spark data frame
        """

        plan = self.__getSparkPlan(df, spark)

        import re

        search_result = re.search(r"sizeInBytes=.*(['\)])", plan, re.MULTILINE)
        if search_result is not None:
            result = search_result.group(0).replace(")", "")
        else:
            raise ValueError("Unable to obtain sizeInBytes from Spark plan")

        size = result.split("=")[1].split(" ")[0]
        units = result.split("=")[1].split(" ")[1]

        # convert the reported size into bytes for the threshold check
        if units == "GiB":
            plan_bytes = float(size) * 1024 * 1024 * 1024
        elif units == "MiB":
            plan_bytes = float(size) * 1024 * 1024
        elif units == "KiB":
            plan_bytes = float(size) * 1024
        else:
            plan_bytes = float(size)

        return plan_bytes

    def asofJoin(
        self,
        right_tsdf: "TSDF",
        left_prefix: Optional[str] = None,
        right_prefix: str = "right",
        tsPartitionVal: Optional[int] = None,
        fraction: float = 0.5,
        skipNulls: bool = True,
        sql_join_opt: bool = False,
        suppress_null_warning: bool = False,
        tolerance: Optional[int] = None,
    ) -> "TSDF":
        """
        Performs an as-of join between two time-series. If a tsPartitionVal is
        specified, it will do this partitioned by time brackets, which can help alleviate skew.

        NOTE: partition cols have to be the same for both Dataframes. We are
        collecting stats when the WARNING level is enabled also.

        Parameters
        :param right_tsdf - right-hand data frame containing columns to merge in
        :param left_prefix - optional prefix for base data frame
        :param right_prefix - optional prefix for right-hand data frame
        :param tsPartitionVal - value to break up each partition into time brackets
        :param fraction - overlap fraction
        :param skipNulls - whether to skip nulls when joining in values
        :param sql_join_opt - if set to True, will use standard Spark SQL join if it is estimated to be efficient
        :param suppress_null_warning - when tsPartitionVal is specified, will collect min of each column and raise warnings about null values, set to True to avoid
        :param tolerance - only join values within this tolerance range (inclusive), expressed in number of seconds as a double
        """

        # first block of logic checks whether a standard range join will suffice
        left_df = self.df
        right_df = right_tsdf.df

        # test if the broadcast join will be efficient
        if sql_join_opt:
            spark = SparkSession.builder.getOrCreate()
            left_bytes = self.__getBytesFromPlan(left_df, spark)
            right_bytes = self.__getBytesFromPlan(right_df, spark)

            # choose 30MB as the cutoff for the broadcast
            bytes_threshold = 30 * 1024 * 1024
            if (left_bytes < bytes_threshold) or (right_bytes < bytes_threshold):
                spark.conf.set("spark.wedata.optimizer.rangeJoin.binSize", 60)
                partition_cols = right_tsdf.partitionCols
                left_cols = list(set(left_df.columns) - set(self.partitionCols))
                right_cols = list(set(right_df.columns) - set(right_tsdf.partitionCols))

                left_prefix = left_prefix + "_" if left_prefix else ""
                right_prefix = right_prefix + "_" if right_prefix else ""

                w = Window.partitionBy(*partition_cols).orderBy(
                    right_prefix + right_tsdf.ts_col
                )

                new_left_ts_col = left_prefix + self.ts_col
                new_left_cols = [
                    sfn.col(c).alias(left_prefix + c) for c in left_cols
                ] + partition_cols
                new_right_cols = [
                    sfn.col(c).alias(right_prefix + c) for c in right_cols
                ] + partition_cols
                quotes_df_w_lag = right_df.select(*new_right_cols).withColumn(
                    "lead_" + right_tsdf.ts_col,
                    sfn.lead(right_prefix + right_tsdf.ts_col).over(w),
                )
                left_df = left_df.select(*new_left_cols)
                res = (
                    left_df.join(quotes_df_w_lag, partition_cols)
                    .where(
                        left_df[new_left_ts_col].between(
                            sfn.col(right_prefix + right_tsdf.ts_col),
                            sfn.coalesce(
                                sfn.col("lead_" + right_tsdf.ts_col),
                                sfn.lit("2099-01-01").cast("timestamp"),
                            ),
                        )
                    )
                    .drop("lead_" + right_tsdf.ts_col)
                )
                return TSDF(
                    res, partition_cols=self.partitionCols, ts_col=new_left_ts_col
                )

        # end of block checking to see if standard Spark SQL join will work

        if tsPartitionVal is not None:
            logger.warning(
                "You are using the skew version of the AS OF join. This may result in null values if there are any "
                "values outside of the maximum lookback. For maximum efficiency, choose smaller values of maximum "
                "lookback, trading off performance and potential blank AS OF values for sparse keys"
            )

        # Check whether partition columns have same name in both dataframes
        self.__checkPartitionCols(right_tsdf)

        # prefix non-partition columns, to avoid duplicated columns.
        left_df = self.df
        right_df = right_tsdf.df

        # validate timestamp datatypes match
        self.__validateTsColMatch(right_tsdf)

        orig_left_col_diff = list(
            set(left_df.columns).difference(set(self.partitionCols))
        )
        orig_right_col_diff = list(
            set(right_df.columns).difference(set(self.partitionCols))
        )

        left_tsdf = (
            (self.__addPrefixToColumns([self.ts_col] + orig_left_col_diff, left_prefix))
            if left_prefix is not None
            else self
        )
        right_tsdf = right_tsdf.__addPrefixToColumns(
            [right_tsdf.ts_col] + orig_right_col_diff, right_prefix
        )

        left_columns = list(
            set(left_tsdf.df.columns).difference(set(self.partitionCols))
        )
        right_columns = list(
            set(right_tsdf.df.columns).difference(set(self.partitionCols))
        )

        # Union both dataframes, and create a combined TS column
        combined_ts_col = "combined_ts"
        combined_df = left_tsdf.__addColumnsFromOtherDF(right_columns).__combineTSDF(
            right_tsdf.__addColumnsFromOtherDF(left_columns), combined_ts_col
        )
        combined_df.df = combined_df.df.withColumn(
            "rec_ind",
            sfn.when(sfn.col(left_tsdf.ts_col).isNotNull(), 1).otherwise(-1),
        )

        # perform asof join.
        if tsPartitionVal is None:
            asofDF = combined_df.__getLastRightRow(
                left_tsdf.ts_col,
                right_columns,
                right_tsdf.sequence_col,
                tsPartitionVal,
                skipNulls,
                suppress_null_warning,
            )
        else:
            tsPartitionDF = combined_df.__getTimePartitions(
                tsPartitionVal, fraction=fraction
            )
            asofDF = tsPartitionDF.__getLastRightRow(
                left_tsdf.ts_col,
                right_columns,
                right_tsdf.sequence_col,
                tsPartitionVal,
                skipNulls,
                suppress_null_warning,
            )

            # Get rid of overlapped data and the extra columns generated from timePartitions
            df = asofDF.df.filter(sfn.col("is_original") == 1).drop(
                "ts_partition", "is_original"
            )

            asofDF = TSDF(df, asofDF.ts_col, combined_df.partitionCols)

        if tolerance is not None:
            df = asofDF.df
            left_ts_col = left_tsdf.ts_col
            right_ts_col = right_tsdf.ts_col
            tolerance_condition = (
                df[left_ts_col].cast("double") - df[right_ts_col].cast("double")
                > tolerance
            )

            for right_col in right_columns:
                # First set right non-timestamp columns to null for rows outside of tolerance band
                if right_col != right_ts_col:
                    df = df.withColumn(
                        right_col,
                        sfn.when(tolerance_condition, sfn.lit(None)).otherwise(
                            df[right_col]
                        ),
                    )

            # Finally, set right timestamp column to null for rows outside of tolerance band
            df = df.withColumn(
                right_ts_col,
                sfn.when(tolerance_condition, sfn.lit(None)).otherwise(
                    df[right_ts_col]
                ),
            )
            asofDF.df = df

        return asofDF

    def __baseWindow(
        self, sort_col: Optional[str] = None, reverse: bool = False
    ) -> WindowSpec:
        # figure out our sorting columns
        primary_sort_col = self.ts_col if not sort_col else sort_col
        sort_cols = (
            [primary_sort_col, self.sequence_col]
            if self.sequence_col
            else [primary_sort_col]
        )

        # are we ordering forwards (default) or reversed?
        col_fn = sfn.col
        if reverse:
            col_fn = lambda colname: sfn.col(colname).desc()  # noqa E731

        # our window will be sorted on our sort_cols in the appropriate direction
        w = Window().orderBy([col_fn(col) for col in sort_cols])
        # and partitioned by any series IDs
        if self.partitionCols:
            w = w.partitionBy([sfn.col(elem) for elem in self.partitionCols])
        return w

    def __rangeBetweenWindow(
        self,
        range_from: int,
        range_to: int,
        sort_col: Optional[str] = None,
        reverse: bool = False,
    ) -> WindowSpec:
        return self.__baseWindow(sort_col=sort_col, reverse=reverse).rangeBetween(
            range_from, range_to
        )

    def __rowsBetweenWindow(
        self,
        rows_from: int,
        rows_to: int,
        reverse: bool = False,
    ) -> WindowSpec:
        return self.__baseWindow(reverse=reverse).rowsBetween(rows_from, rows_to)

    def withPartitionCols(self, partitionCols: list[str]) -> "TSDF":
        """
        Sets certain columns of the TSDF as partition columns. Partition columns are those that differentiate distinct timeseries
        from each other.
        :param partitionCols: a list of columns used to partition distinct timeseries
        :return: a TSDF object with the given partition columns
        """
        return TSDF(self.df, self.ts_col, partitionCols)

    def vwap(
        self,
        frequency: str = "m",
        volume_col: str = "volume",
        price_col: str = "price",
    ) -> "TSDF":
        # set pre_vwap as self or enrich with the frequency
        pre_vwap = self.df
        if frequency == "m":
            pre_vwap = self.df.withColumn(
                "time_group",
                sfn.concat(
                    sfn.lpad(sfn.hour(sfn.col(self.ts_col)), 2, "0"),
                    sfn.lit(":"),
                    sfn.lpad(sfn.minute(sfn.col(self.ts_col)), 2, "0"),
                ),
            )
        elif frequency == "H":
            pre_vwap = self.df.withColumn(
                "time_group",
                sfn.concat(sfn.lpad(sfn.hour(sfn.col(self.ts_col)), 2, "0")),
            )
        elif frequency == "D":
            pre_vwap = self.df.withColumn(
                "time_group",
                sfn.concat(sfn.lpad(sfn.day(sfn.col(self.ts_col)), 2, "0")),
            )

        group_cols = ["time_group"]
        if self.partitionCols:
            group_cols.extend(self.partitionCols)
        vwapped = (
            pre_vwap.withColumn("dllr_value", sfn.col(price_col) * sfn.col(volume_col))
            .groupby(group_cols)
            .agg(
                sfn.sum("dllr_value").alias("dllr_value"),
                sfn.sum(volume_col).alias(volume_col),
                sfn.max(price_col).alias("_".join(["max", price_col])),
            )
            .withColumn("vwap", sfn.col("dllr_value") / sfn.col(volume_col))
        )

        return TSDF(vwapped, self.ts_col, self.partitionCols)

    def EMA(self, colName: str, window: int = 30, exp_factor: float = 0.2) -> "TSDF":
        """
        Constructs an approximate EMA in the fashion of:
        EMA = e * lag(col,0) + e * (1 - e) * lag(col, 1) + e * (1 - e)^2 * lag(col, 2) etc, up until window
        TODO: replace case when statement with coalesce
        TODO: add in time partitions functionality (what is the overlap fraction?)
        """

        emaColName = "_".join(["EMA", colName])
        df = self.df.withColumn(emaColName, sfn.lit(0)).orderBy(self.ts_col)
        w = self.__baseWindow()
        # Generate all the lag columns:
        for i in range(window):
            lagColName = "_".join(["lag", colName, str(i)])
            weight = exp_factor * (1 - exp_factor) ** i
            df = df.withColumn(
                lagColName, weight * sfn.lag(sfn.col(colName), i).over(w)
            )
            df = df.withColumn(
                emaColName,
                sfn.col(emaColName)
                + sfn.when(sfn.col(lagColName).isNull(), sfn.lit(0)).otherwise(
                    sfn.col(lagColName)
                ),
            ).drop(lagColName)
            # Nulls are currently removed

        return TSDF(df, self.ts_col, self.partitionCols)

    def withLookbackFeatures(
        self,
        featureCols: List[str],
        lookbackWindowSize: int,
        exactSize: bool = True,
        featureColName: str = "features",
    ) -> Union[DataFrame | "TSDF"]:
        """
        Creates a 2-D feature tensor suitable for training an ML model to predict current values from the history of
        some set of features. This function creates a new column containing, for each observation, a 2-D array of the values
        of some number of other columns over a trailing "lookback" window from the previous observation up to some maximum
        number of past observations.

        :param featureCols: the names of one or more feature columns to be aggregated into the feature column
        :param lookbackWindowSize: The size of lookback window (in terms of past observations). Must be an integer >= 1
        :param exactSize: If True (the default), then the resulting DataFrame will only include observations where the
            generated feature column contains arrays of length lookbackWindowSize. This implies that it will truncate
            observations that occurred less than lookbackWindowSize from the start of the timeseries. If False, no truncation
            occurs, and the column may contain arrays less than lookbackWindowSize in length.
        :param featureColName: The name of the feature column to be generated. Defaults to "features"
        :return: a DataFrame with a feature column named featureColName containing the lookback feature tensor
        """
        # first, join all featureCols into a single array column
        tempArrayColName = "__TempArrayCol"
        feat_array_tsdf = self.df.withColumn(tempArrayColName, sfn.array(featureCols))

        # construct a lookback array
        lookback_win = self.__rowsBetweenWindow(-lookbackWindowSize, -1)
        lookback_tsdf = feat_array_tsdf.withColumn(
            featureColName,
            sfn.collect_list(sfn.col(tempArrayColName)).over(lookback_win),
        ).drop(tempArrayColName)

        # make sure only windows of exact size are allowed
        if exactSize:
            return lookback_tsdf.where(sfn.size(featureColName) == lookbackWindowSize)

        return TSDF(lookback_tsdf, self.ts_col, self.partitionCols)

    def withRangeStats(
        self,
        type: str = "range",
        colsToSummarize: Optional[List[Column]] = None,
        rangeBackWindowSecs: int = 1000,
    ) -> "TSDF":
        """
        Create a wider set of stats based on all numeric columns by default
        Users can choose which columns they want to summarize also. These stats are:
        mean/count/min/max/sum/std deviation/zscore
        :param type - this is created in case we want to extend these stats to lookback over a fixed number of rows instead of ranging over column values
        :param colsToSummarize - list of user-supplied columns to compute stats for. All numeric columns are used if no list is provided
        :param rangeBackWindowSecs - lookback this many seconds in time to summarize all stats. Note this will look back from the floor of the base event timestamp (as opposed to the exact time since we cast to long)
        Assumptions:

        1. The features are summarized over a rolling window that ranges back
        2. The range back window can be specified by the user
        3. Sequence numbers are not yet supported for the sort
        4. There is a cast to long from timestamp so microseconds or more likely breaks down - this could be more easily handled with a string timestamp or sorting the timestamp itself. If using a 'rows preceding' window, this wouldn't be a problem
        """

        # identify columns to summarize if not provided
        # these should include all numeric columns that
        # are not the timestamp column and not any of the partition columns
        if colsToSummarize is None:
            # columns we should never summarize
            prohibited_cols = [self.ts_col.lower()]
            if self.partitionCols:
                prohibited_cols.extend([pc.lower() for pc in self.partitionCols])
            # filter columns to find summarizable columns
            colsToSummarize = [
                datatype[0]
                for datatype in self.df.dtypes
                if (
                    (datatype[1] in self.summarizable_types)
                    and (datatype[0].lower() not in prohibited_cols)
                )
            ]

        # build window
        if isinstance(self.df.schema[self.ts_col].dataType, TimestampType):
            self.df = self.__add_double_ts()
            prohibited_cols.extend(["double_ts"])
            w = self.__rangeBetweenWindow(
                -1 * rangeBackWindowSecs, 0, sort_col="double_ts"
            )
        else:
            w = self.__rangeBetweenWindow(-1 * rangeBackWindowSecs, 0)

        # compute column summaries
        selectedCols = self.df.columns
        derivedCols = []
        for metric in colsToSummarize:
            selectedCols.append(sfn.mean(metric).over(w).alias("mean_" + metric))
            selectedCols.append(sfn.count(metric).over(w).alias("count_" + metric))
            selectedCols.append(sfn.min(metric).over(w).alias("min_" + metric))
            selectedCols.append(sfn.max(metric).over(w).alias("max_" + metric))
            selectedCols.append(sfn.sum(metric).over(w).alias("sum_" + metric))
            selectedCols.append(sfn.stddev(metric).over(w).alias("stddev_" + metric))
            derivedCols.append(
                (
                    (sfn.col(metric) - sfn.col("mean_" + metric))
                    / sfn.col("stddev_" + metric)
                ).alias("zscore_" + metric)
            )
        selected_df = self.df.select(*selectedCols)
        summary_df = selected_df.select(*selected_df.columns, *derivedCols).drop(
            "double_ts"
        )

        return TSDF(summary_df, self.ts_col, self.partitionCols)
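
Illustrative sketch only (not part of the file): rolling range stats over the prior 10 minutes of each series, using the default column selection.

stats_tsdf = trades_tsdf.withRangeStats(rangeBackWindowSecs=600)
stats_tsdf.df.select("symbol", "event_ts", "mean_price", "zscore_price").show()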
|
|
1183
|
+
|
|
1184
|
+
    def withGroupedStats(
        self,
        metricCols: Optional[List[str]] = None,
        freq: Optional[str] = None,
    ) -> "TSDF":
        """
        Create a wider set of stats based on all numeric columns by default.
        Users can also choose which columns they want to summarize. These stats are:
        mean/count/min/max/sum/std deviation

        :param metricCols: list of user-supplied columns to compute stats for. All numeric columns are used if no list is provided
        :param freq: frequency (a string of the form '1 min', '30 seconds', etc.) interpreted as the window used to aggregate
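
        Example (an illustrative sketch; the DataFrame ``df`` and the columns ``event_ts``, ``symbol`` and ``price`` are assumptions for illustration):

            >>> input_tsdf = TSDF(df, ts_col="event_ts", partition_cols=["symbol"])            # doctest: +SKIP
            >>> grouped = input_tsdf.withGroupedStats(metricCols=["price"], freq="1 min")      # doctest: +SKIP
            >>> grouped.df.show()                                                              # doctest: +SKIP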
"""
|
|
1196
|
+
|
|
1197
|
+
# identify columns to summarize if not provided
|
|
1198
|
+
# these should include all numeric columns that
|
|
1199
|
+
# are not the timestamp column and not any of the partition columns
|
|
1200
|
+
if metricCols is None:
|
|
1201
|
+
# columns we should never summarize
|
|
1202
|
+
prohibited_cols = [self.ts_col.lower()]
|
|
1203
|
+
if self.partitionCols:
|
|
1204
|
+
prohibited_cols.extend([pc.lower() for pc in self.partitionCols])
|
|
1205
|
+
# filter columns to find summarizable columns
|
|
1206
|
+
metricCols = [
|
|
1207
|
+
datatype[0]
|
|
1208
|
+
for datatype in self.df.dtypes
|
|
1209
|
+
if (
|
|
1210
|
+
(datatype[1] in self.summarizable_types)
|
|
1211
|
+
and (datatype[0].lower() not in prohibited_cols)
|
|
1212
|
+
)
|
|
1213
|
+
]
|
|
1214
|
+
|
|
1215
|
+
# build window
|
|
1216
|
+
parsed_freq = t_resample.checkAllowableFreq(freq)
|
|
1217
|
+
period, unit = parsed_freq[0], parsed_freq[1]
|
|
1218
|
+
agg_window = sfn.window(
|
|
1219
|
+
sfn.col(self.ts_col),
|
|
1220
|
+
"{} {}".format(
|
|
1221
|
+
period, t_resample.freq_dict[unit] # type: ignore[literal-required]
|
|
1222
|
+
),
|
|
1223
|
+
)
|
|
1224
|
+
|
|
1225
|
+
# compute column summaries
|
|
1226
|
+
selectedCols = []
|
|
1227
|
+
for metric in metricCols:
|
|
1228
|
+
selectedCols.extend(
|
|
1229
|
+
[
|
|
1230
|
+
sfn.mean(sfn.col(metric)).alias("mean_" + metric),
|
|
1231
|
+
sfn.count(sfn.col(metric)).alias("count_" + metric),
|
|
1232
|
+
sfn.min(sfn.col(metric)).alias("min_" + metric),
|
|
1233
|
+
sfn.max(sfn.col(metric)).alias("max_" + metric),
|
|
1234
|
+
sfn.sum(sfn.col(metric)).alias("sum_" + metric),
|
|
1235
|
+
sfn.stddev(sfn.col(metric)).alias("stddev_" + metric),
|
|
1236
|
+
]
|
|
1237
|
+
)
|
|
1238
|
+
|
|
1239
|
+
selected_df = self.df.groupBy(self.partitionCols + [agg_window]).agg(
|
|
1240
|
+
*selectedCols
|
|
1241
|
+
)
|
|
1242
|
+
summary_df = (
|
|
1243
|
+
selected_df.select(*selected_df.columns)
|
|
1244
|
+
.withColumn(self.ts_col, sfn.col("window").start)
|
|
1245
|
+
.drop("window")
|
|
1246
|
+
)
|
|
1247
|
+
|
|
1248
|
+
return TSDF(summary_df, self.ts_col, self.partitionCols)
|
|
1249
|
+
|
|
1250
|
+
    def write(
        self,
        spark: SparkSession,
        tabName: str,
        optimizationCols: Optional[List[str]] = None,
    ) -> None:
        t_io.write(self, spark, tabName, optimizationCols)

    def resample(
        self,
        freq: str,
        func: Union[Callable, str],
        metricCols: Optional[List[str]] = None,
        prefix: Optional[str] = None,
        fill: Optional[bool] = None,
        perform_checks: bool = True,
    ) -> "TSDF":
        """
        Function to resample based on frequency and aggregate function, similar to pandas.

        :param freq: frequency for the resample - valid inputs are "hr", "min", "sec" corresponding to hour, minute, or second
        :param func: function used to aggregate input
        :param metricCols: supply a smaller list of numeric columns if the entire set of numeric columns should not be returned by the resample
        :param prefix: supply a prefix for the newly sampled columns
        :param fill: Boolean - set to True if the desired output should contain filled-in gaps (currently filled with 0s)
        :param perform_checks: calculate time horizon and warnings if True (default is True)
        :return: TSDF object with sampled data using the aggregate function
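
        Example (an illustrative sketch; the DataFrame ``df`` and the columns ``event_ts``, ``symbol`` and ``price`` are assumptions for illustration):

            >>> input_tsdf = TSDF(df, ts_col="event_ts", partition_cols=["symbol"])                              # doctest: +SKIP
            >>> resampled = input_tsdf.resample(freq="min", func="max", metricCols=["price"], prefix="max")      # doctest: +SKIP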
"""
|
|
1277
|
+
t_resample.validateFuncExists(func)
|
|
1278
|
+
|
|
1279
|
+
# Throw warning for user to validate that the expected number of output rows is valid.
|
|
1280
|
+
if fill is True and perform_checks is True:
|
|
1281
|
+
t_utils.calculate_time_horizon(
|
|
1282
|
+
self.df, self.ts_col, freq, self.partitionCols
|
|
1283
|
+
)
|
|
1284
|
+
|
|
1285
|
+
enriched_df: DataFrame = t_resample.aggregate(
|
|
1286
|
+
self, freq, func, metricCols, prefix, fill
|
|
1287
|
+
)
|
|
1288
|
+
return _ResampledTSDF(
|
|
1289
|
+
enriched_df,
|
|
1290
|
+
ts_col=self.ts_col,
|
|
1291
|
+
partition_cols=self.partitionCols,
|
|
1292
|
+
freq=freq,
|
|
1293
|
+
func=func,
|
|
1294
|
+
)
|
|
1295
|
+
|
|
1296
|
+
def interpolate(
|
|
1297
|
+
self,
|
|
1298
|
+
method: str,
|
|
1299
|
+
freq: Optional[str] = None,
|
|
1300
|
+
func: Optional[Union[Callable | str]] = None,
|
|
1301
|
+
target_cols: Optional[List[str]] = None,
|
|
1302
|
+
ts_col: Optional[str] = None,
|
|
1303
|
+
partition_cols: Optional[List[str]] = None,
|
|
1304
|
+
show_interpolated: bool = False,
|
|
1305
|
+
perform_checks: bool = True,
|
|
1306
|
+
) -> "TSDF":
|
|
1307
|
+
"""
|
|
1308
|
+
Function to interpolate based on frequency, aggregation, and fill similar to pandas. Data will first be aggregated using resample, then missing values
|
|
1309
|
+
will be filled based on the fill calculation.
|
|
1310
|
+
|
|
1311
|
+
:param freq: frequency for upsample - valid inputs are "hr", "min", "sec" corresponding to hour, minute, or second
|
|
1312
|
+
:param func: function used to aggregate input
|
|
1313
|
+
:param method: function used to fill missing values e.g. linear, null, zero, bfill, ffill
|
|
1314
|
+
:param target_cols [optional]: columns that should be interpolated, by default interpolates all numeric columns
|
|
1315
|
+
:param ts_col [optional]: specify other ts_col, by default this uses the ts_col within the TSDF object
|
|
1316
|
+
:param partition_cols [optional]: specify other partition_cols, by default this uses the partition_cols within the TSDF object
|
|
1317
|
+
:param show_interpolated [optional]: if true will include an additional column to show which rows have been fully interpolated.
|
|
1318
|
+
:param perform_checks: calculate time horizon and warnings if True (default is True)
|
|
1319
|
+
:return: new TSDF object containing interpolated data
|
|
1320
|
+
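
        Example (an illustrative sketch; the DataFrame ``df`` and the columns ``event_ts``, ``symbol`` and ``price`` are assumptions for illustration):

            >>> input_tsdf = TSDF(df, ts_col="event_ts", partition_cols=["symbol"])                                       # doctest: +SKIP
            >>> filled = input_tsdf.interpolate(freq="min", func="max", method="ffill", target_cols=["price"])            # doctest: +SKIP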
"""
|
|
1321
|
+
|
|
1322
|
+
# Set defaults for target columns, timestamp column and partition columns when not provided
|
|
1323
|
+
if freq is None:
|
|
1324
|
+
raise ValueError("freq must be provided")
|
|
1325
|
+
if func is None:
|
|
1326
|
+
raise ValueError("func must be provided")
|
|
1327
|
+
if ts_col is None:
|
|
1328
|
+
ts_col = self.ts_col
|
|
1329
|
+
if partition_cols is None:
|
|
1330
|
+
partition_cols = self.partitionCols
|
|
1331
|
+
if target_cols is None:
|
|
1332
|
+
prohibited_cols: List[str] = partition_cols + [ts_col]
|
|
1333
|
+
target_cols = [col for col in self.df.columns if col not in prohibited_cols]
|
|
1334
|
+
|
|
1335
|
+
interpolate_service = t_interpolation.Interpolation(is_resampled=False)
|
|
1336
|
+
tsdf_input = TSDF(self.df, ts_col=ts_col, partition_cols=partition_cols)
|
|
1337
|
+
interpolated_df: DataFrame = interpolate_service.interpolate(
|
|
1338
|
+
tsdf_input,
|
|
1339
|
+
ts_col,
|
|
1340
|
+
partition_cols,
|
|
1341
|
+
target_cols,
|
|
1342
|
+
freq,
|
|
1343
|
+
func,
|
|
1344
|
+
method,
|
|
1345
|
+
show_interpolated,
|
|
1346
|
+
perform_checks,
|
|
1347
|
+
)
|
|
1348
|
+
|
|
1349
|
+
return TSDF(interpolated_df, ts_col=ts_col, partition_cols=partition_cols)
|
|
1350
|
+
|
|
1351
|
+
def calc_bars(
|
|
1352
|
+
tsdf,
|
|
1353
|
+
freq: str,
|
|
1354
|
+
metricCols: Optional[List[str]] = None,
|
|
1355
|
+
fill: Optional[bool] = None,
|
|
1356
|
+
) -> "TSDF":
|
|
1357
|
+
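        """
        Builds per-period "bars" from the metric columns: the method body below resamples with the
        "floor", "min", "max" and "ceil" aggregation functions (prefixed "open", "low", "high" and
        "close" respectively) and joins the four results on the partition columns and the timestamp.

        Example (an illustrative sketch; the DataFrame ``df`` and the columns ``event_ts``, ``symbol`` and ``price`` are assumptions for illustration):

            >>> input_tsdf = TSDF(df, ts_col="event_ts", partition_cols=["symbol"])     # doctest: +SKIP
            >>> bars = input_tsdf.calc_bars(freq="min", metricCols=["price"])           # doctest: +SKIP
        """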
        resample_open = tsdf.resample(
            freq=freq, func="floor", metricCols=metricCols, prefix="open", fill=fill
        )
        resample_low = tsdf.resample(
            freq=freq, func="min", metricCols=metricCols, prefix="low", fill=fill
        )
        resample_high = tsdf.resample(
            freq=freq, func="max", metricCols=metricCols, prefix="high", fill=fill
        )
        resample_close = tsdf.resample(
            freq=freq, func="ceil", metricCols=metricCols, prefix="close", fill=fill
        )

        join_cols = resample_open.partitionCols + [resample_open.ts_col]
        bars = (
            resample_open.df.join(resample_high.df, join_cols)
            .join(resample_low.df, join_cols)
            .join(resample_close.df, join_cols)
        )
        non_part_cols = (
            set(bars.columns)
            - set(resample_open.partitionCols)
            - {resample_open.ts_col}
        )
        sel_and_sort = (
            resample_open.partitionCols + [resample_open.ts_col] + sorted(non_part_cols)
        )
        bars = bars.select(sel_and_sort)

        return TSDF(bars, resample_open.ts_col, resample_open.partitionCols)
    def fourier_transform(
        self, timestep: Union[int, float, complex], valueCol: str
    ) -> "TSDF":
        """
        Function to Fourier transform the time series into its frequency-domain representation.

        :param timestep: timestep value to be used for getting the frequency scale
        :param valueCol: name of the time-domain data column which will be transformed
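
        Example (an illustrative sketch; the DataFrame ``df`` and the columns ``event_ts``, ``symbol`` and ``price`` are assumptions for illustration):

            >>> input_tsdf = TSDF(df, ts_col="event_ts", partition_cols=["symbol"])                 # doctest: +SKIP
            >>> freq_domain = input_tsdf.fourier_transform(timestep=1.0, valueCol="price")          # doctest: +SKIP
            >>> # adds "freq", "ft_real" and "ft_imag" columns alongside the original ones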
"""
|
|
1394
|
+
|
|
1395
|
+
def tempo_fourier_util(
|
|
1396
|
+
pdf: pd.DataFrame,
|
|
1397
|
+
) -> pd.DataFrame:
|
|
1398
|
+
"""
|
|
1399
|
+
This method is a vanilla python logic implementing fourier transform on a numpy array using the scipy module.
|
|
1400
|
+
This method is meant to be called from Tempo TSDF as a pandas function API on Spark
|
|
1401
|
+
"""
|
|
1402
|
+
select_cols = list(pdf.columns)
|
|
1403
|
+
pdf.sort_values(by=["tpoints"], inplace=True, ascending=True)
|
|
1404
|
+
y = np.array(pdf["tdval"])
|
|
1405
|
+
tran = fft(y)
|
|
1406
|
+
r = tran.real
|
|
1407
|
+
i = tran.imag
|
|
1408
|
+
pdf["ft_real"] = r
|
|
1409
|
+
pdf["ft_imag"] = i
|
|
1410
|
+
N = tran.shape
|
|
1411
|
+
# fftfreq expects a float for the spacing parameter
|
|
1412
|
+
if isinstance(timestep, complex):
|
|
1413
|
+
spacing = abs(timestep) # Use magnitude for complex numbers
|
|
1414
|
+
else:
|
|
1415
|
+
spacing = float(timestep)
|
|
1416
|
+
xf = fftfreq(N[0], spacing)
|
|
1417
|
+
pdf["freq"] = xf
|
|
1418
|
+
return pdf[select_cols + ["freq", "ft_real", "ft_imag"]]
|
|
1419
|
+
|
|
1420
|
+
valueCol = self.__validated_column(self.df, valueCol)
|
|
1421
|
+
data = self.df
|
|
1422
|
+
if self.sequence_col:
|
|
1423
|
+
if self.partitionCols == []:
|
|
1424
|
+
data = data.withColumn("dummy_group", sfn.lit("dummy_val"))
|
|
1425
|
+
data = (
|
|
1426
|
+
data.select(
|
|
1427
|
+
sfn.col("dummy_group"),
|
|
1428
|
+
self.ts_col,
|
|
1429
|
+
self.sequence_col,
|
|
1430
|
+
sfn.col(valueCol),
|
|
1431
|
+
)
|
|
1432
|
+
.withColumn("tdval", sfn.col(valueCol))
|
|
1433
|
+
.withColumn("tpoints", sfn.col(self.ts_col))
|
|
1434
|
+
)
|
|
1435
|
+
return_schema = ",".join(
|
|
1436
|
+
[f"{i[0]} {i[1]}" for i in data.dtypes]
|
|
1437
|
+
+ ["freq double", "ft_real double", "ft_imag double"]
|
|
1438
|
+
)
|
|
1439
|
+
result = data.groupBy("dummy_group").applyInPandas(
|
|
1440
|
+
tempo_fourier_util, return_schema
|
|
1441
|
+
)
|
|
1442
|
+
result = result.drop("dummy_group", "tdval", "tpoints")
|
|
1443
|
+
else:
|
|
1444
|
+
group_cols = self.partitionCols
|
|
1445
|
+
data = (
|
|
1446
|
+
data.select(
|
|
1447
|
+
*group_cols,
|
|
1448
|
+
self.ts_col,
|
|
1449
|
+
self.sequence_col,
|
|
1450
|
+
sfn.col(valueCol),
|
|
1451
|
+
)
|
|
1452
|
+
.withColumn("tdval", sfn.col(valueCol))
|
|
1453
|
+
.withColumn("tpoints", sfn.col(self.ts_col))
|
|
1454
|
+
)
|
|
1455
|
+
return_schema = ",".join(
|
|
1456
|
+
[f"{i[0]} {i[1]}" for i in data.dtypes]
|
|
1457
|
+
+ ["freq double", "ft_real double", "ft_imag double"]
|
|
1458
|
+
)
|
|
1459
|
+
result = data.groupBy(*group_cols).applyInPandas(
|
|
1460
|
+
tempo_fourier_util, return_schema
|
|
1461
|
+
)
|
|
1462
|
+
result = result.drop("tdval", "tpoints")
|
|
1463
|
+
else:
|
|
1464
|
+
if self.partitionCols == []:
|
|
1465
|
+
data = data.withColumn("dummy_group", sfn.lit("dummy_val"))
|
|
1466
|
+
data = (
|
|
1467
|
+
data.select(sfn.col("dummy_group"), self.ts_col, sfn.col(valueCol))
|
|
1468
|
+
.withColumn("tdval", sfn.col(valueCol))
|
|
1469
|
+
.withColumn("tpoints", sfn.col(self.ts_col))
|
|
1470
|
+
)
|
|
1471
|
+
return_schema = ",".join(
|
|
1472
|
+
[f"{i[0]} {i[1]}" for i in data.dtypes]
|
|
1473
|
+
+ ["freq double", "ft_real double", "ft_imag double"]
|
|
1474
|
+
)
|
|
1475
|
+
result = data.groupBy("dummy_group").applyInPandas(
|
|
1476
|
+
tempo_fourier_util, return_schema
|
|
1477
|
+
)
|
|
1478
|
+
result = result.drop("dummy_group", "tdval", "tpoints")
|
|
1479
|
+
else:
|
|
1480
|
+
group_cols = self.partitionCols
|
|
1481
|
+
data = (
|
|
1482
|
+
data.select(*group_cols, self.ts_col, sfn.col(valueCol))
|
|
1483
|
+
.withColumn("tdval", sfn.col(valueCol))
|
|
1484
|
+
.withColumn("tpoints", sfn.col(self.ts_col))
|
|
1485
|
+
)
|
|
1486
|
+
return_schema = ",".join(
|
|
1487
|
+
[f"{i[0]} {i[1]}" for i in data.dtypes]
|
|
1488
|
+
+ ["freq double", "ft_real double", "ft_imag double"]
|
|
1489
|
+
)
|
|
1490
|
+
result = data.groupBy(*group_cols).applyInPandas(
|
|
1491
|
+
tempo_fourier_util, return_schema
|
|
1492
|
+
)
|
|
1493
|
+
result = result.drop("tdval", "tpoints")
|
|
1494
|
+
|
|
1495
|
+
return TSDF(result, self.ts_col, self.partitionCols, self.sequence_col)
|
|
1496
|
+
|
|
1497
|
+
    def extractStateIntervals(
        self,
        *metric_cols: str,
        state_definition: Union[str, Callable[[Column, Column], Column]] = "=",
    ) -> DataFrame:
        """
        Extracts intervals from a :class:`~tsdf.TSDF` based on some notion of "state", as defined by the
        `state_definition` parameter. The state definition consists of a comparison operation between the current and
        previous values of a metric. If the comparison operation evaluates to true across all metric columns,
        then we consider both points to be in the same "state". Changes of state occur when the comparison operator
        returns false for any given metric column. So, the default state definition ('=') extracts intervals of
        time wherein the metrics all remained constant. A state definition of '>=' would extract intervals wherein
        the metrics were all monotonically increasing.

        :param metric_cols: the set of metric columns to evaluate for state changes
        :param state_definition: the comparison function used to evaluate individual metrics for state changes.

            Either a string, giving a standard PySpark column comparison operation, or a binary function with the
            signature `(x1: Column, x2: Column) -> Column`, where the returned column expression evaluates to a
            :class:`~pyspark.sql.types.BooleanType`

        :return: a :class:`~pyspark.sql.DataFrame` object containing the resulting intervals
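
        Example (an illustrative sketch; the DataFrame ``df`` and the columns ``event_ts``, ``symbol`` and ``price`` are assumptions for illustration):

            >>> input_tsdf = TSDF(df, ts_col="event_ts", partition_cols=["symbol"])                                  # doctest: +SKIP
            >>> constant_price_intervals = input_tsdf.extractStateIntervals("price", state_definition="=")           # doctest: +SKIP
            >>> # returns a DataFrame with the partition columns plus "start_ts" and "end_ts"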
"""
|
|
1520
|
+
|
|
1521
|
+
# https://spark.apache.org/docs/latest/sql-ref-null-semantics.html#comparison-operators-
|
|
1522
|
+
def null_safe_equals(col1: Column, col2: Column) -> Column:
|
|
1523
|
+
return (
|
|
1524
|
+
sfn.when(col1.isNull() & col2.isNull(), True)
|
|
1525
|
+
.when(col1.isNull() | col2.isNull(), False)
|
|
1526
|
+
.otherwise(operator.eq(col1, col2))
|
|
1527
|
+
)
|
|
1528
|
+
|
|
1529
|
+
operator_dict = {
|
|
1530
|
+
# https://spark.apache.org/docs/latest/api/sql/#_2
|
|
1531
|
+
"!=": operator.ne,
|
|
1532
|
+
# https://spark.apache.org/docs/latest/api/sql/#_11
|
|
1533
|
+
"<>": operator.ne,
|
|
1534
|
+
# https://spark.apache.org/docs/latest/api/sql/#_8
|
|
1535
|
+
"<": operator.lt,
|
|
1536
|
+
# https://spark.apache.org/docs/latest/api/sql/#_9
|
|
1537
|
+
"<=": operator.le,
|
|
1538
|
+
# https://spark.apache.org/docs/latest/api/sql/#_10
|
|
1539
|
+
"<=>": null_safe_equals,
|
|
1540
|
+
# https://spark.apache.org/docs/latest/api/sql/#_12
|
|
1541
|
+
"=": operator.eq,
|
|
1542
|
+
# https://spark.apache.org/docs/latest/api/sql/#_13
|
|
1543
|
+
"==": operator.eq,
|
|
1544
|
+
# https://spark.apache.org/docs/latest/api/sql/#_14
|
|
1545
|
+
">": operator.gt,
|
|
1546
|
+
# https://spark.apache.org/docs/latest/api/sql/#_15
|
|
1547
|
+
">=": operator.ge,
|
|
1548
|
+
}
|
|
1549
|
+
|
|
1550
|
+
# Validate state definition and construct state comparison function
|
|
1551
|
+
if type(state_definition) is str:
|
|
1552
|
+
if state_definition not in operator_dict.keys():
|
|
1553
|
+
raise ValueError(
|
|
1554
|
+
f"Invalid comparison operator for `state_definition` argument: {state_definition}."
|
|
1555
|
+
)
|
|
1556
|
+
|
|
1557
|
+
def state_comparison_fn(a: CT, b: CT) -> Callable[[Column, Column], Column]:
|
|
1558
|
+
return operator_dict[state_definition](a, b)
|
|
1559
|
+
|
|
1560
|
+
elif callable(state_definition):
|
|
1561
|
+
state_comparison_fn = state_definition # type: ignore
|
|
1562
|
+
|
|
1563
|
+
else:
|
|
1564
|
+
raise TypeError(
|
|
1565
|
+
f"The `state_definition` argument can be of type `str` or `callable`, "
|
|
1566
|
+
f"but received value of type {type(state_definition)}"
|
|
1567
|
+
)
|
|
1568
|
+
|
|
1569
|
+
w = self.__baseWindow()
|
|
1570
|
+
|
|
1571
|
+
data = self.df
|
|
1572
|
+
|
|
1573
|
+
# Get previous timestamp to identify start time of the interval
|
|
1574
|
+
data = data.withColumn(
|
|
1575
|
+
"previous_ts",
|
|
1576
|
+
sfn.lag(sfn.col(self.ts_col), offset=1).over(w),
|
|
1577
|
+
)
|
|
1578
|
+
|
|
1579
|
+
# Determine state intervals using user-provided the state comparison function
|
|
1580
|
+
# The comparison occurs on the current and previous record per metric column
|
|
1581
|
+
temp_metric_compare_cols = []
|
|
1582
|
+
for mc in metric_cols:
|
|
1583
|
+
temp_metric_compare_col = f"__{mc}_compare"
|
|
1584
|
+
data = data.withColumn(
|
|
1585
|
+
temp_metric_compare_col,
|
|
1586
|
+
state_comparison_fn(sfn.col(mc), sfn.lag(sfn.col(mc), 1).over(w)),
|
|
1587
|
+
)
|
|
1588
|
+
temp_metric_compare_cols.append(temp_metric_compare_col)
|
|
1589
|
+
|
|
1590
|
+
# Remove first record which will have no state change
|
|
1591
|
+
# and produces `null` for all state comparisons
|
|
1592
|
+
data = data.filter(sfn.col("previous_ts").isNotNull())
|
|
1593
|
+
|
|
1594
|
+
# Each state comparison should return True if state remained constant
|
|
1595
|
+
data = data.withColumn(
|
|
1596
|
+
"state_change",
|
|
1597
|
+
sfn.array_contains(sfn.array(*temp_metric_compare_cols), False),
|
|
1598
|
+
)
|
|
1599
|
+
|
|
1600
|
+
# Count the distinct state changes to get the unique intervals
|
|
1601
|
+
data = data.withColumn(
|
|
1602
|
+
"state_incrementer",
|
|
1603
|
+
sfn.sum(sfn.col("state_change").cast("int")).over(w),
|
|
1604
|
+
).filter(~sfn.col("state_change"))
|
|
1605
|
+
|
|
1606
|
+
# Find the start and end timestamp of the interval
|
|
1607
|
+
result = (
|
|
1608
|
+
data.groupBy(*self.partitionCols, "state_incrementer")
|
|
1609
|
+
.agg(
|
|
1610
|
+
sfn.min("previous_ts").alias("start_ts"),
|
|
1611
|
+
sfn.max(self.ts_col).alias("end_ts"),
|
|
1612
|
+
)
|
|
1613
|
+
.drop("state_incrementer")
|
|
1614
|
+
)
|
|
1615
|
+
|
|
1616
|
+
return result
|
|
1617
|
+
|
|
1618
|
+
|
|
1619
|
+
class _ResampledTSDF(TSDF):
    def __init__(
        self,
        df: DataFrame,
        freq: str,
        func: Union[Callable, str],
        ts_col: str = "event_ts",
        partition_cols: Optional[List[str]] = None,
        sequence_col: Optional[str] = None,
    ):
        super(_ResampledTSDF, self).__init__(df, ts_col, partition_cols, sequence_col)
        self.__freq = freq
        self.__func = func

    def interpolate(
        self,
        method: str,
        freq: Optional[str] = None,
        func: Optional[Union[Callable, str]] = None,
        target_cols: Optional[List[str]] = None,
        ts_col: Optional[str] = None,
        partition_cols: Optional[List[str]] = None,
        show_interpolated: bool = False,
        perform_checks: bool = True,
    ) -> "TSDF":
        """
        Function to interpolate based on frequency, aggregation, and fill, similar to pandas. This method requires an already resampled data set in order to be used.

        :param method: function used to fill missing values e.g. linear, null, zero, bfill, ffill
        :param target_cols [optional]: columns that should be interpolated; by default all numeric columns are interpolated
        :param show_interpolated [optional]: if true, will include an additional column showing which rows have been fully interpolated
        :param perform_checks: calculate time horizon and warnings if True (default is True)
        :return: new TSDF object containing interpolated data
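
        Example (an illustrative sketch; ``input_tsdf`` and its ``price`` column are assumptions for illustration):

            >>> resampled = input_tsdf.resample(freq="min", func="max", metricCols=["price"])   # doctest: +SKIP
            >>> filled = resampled.interpolate(method="ffill")                                  # doctest: +SKIP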
"""
|
|
1653
|
+
|
|
1654
|
+
if freq is None:
|
|
1655
|
+
freq = self.__freq
|
|
1656
|
+
|
|
1657
|
+
if func is None:
|
|
1658
|
+
func = self.__func
|
|
1659
|
+
|
|
1660
|
+
if ts_col is None:
|
|
1661
|
+
ts_col = self.ts_col
|
|
1662
|
+
|
|
1663
|
+
if partition_cols is None:
|
|
1664
|
+
partition_cols = self.partitionCols
|
|
1665
|
+
|
|
1666
|
+
# Set defaults for target columns, timestamp column and partition columns when not provided
|
|
1667
|
+
if target_cols is None:
|
|
1668
|
+
prohibited_cols: List[str] = self.partitionCols + [self.ts_col]
|
|
1669
|
+
target_cols = [col for col in self.df.columns if col not in prohibited_cols]
|
|
1670
|
+
|
|
1671
|
+
interpolate_service = t_interpolation.Interpolation(is_resampled=True)
|
|
1672
|
+
tsdf_input = TSDF(
|
|
1673
|
+
self.df, ts_col=self.ts_col, partition_cols=self.partitionCols
|
|
1674
|
+
)
|
|
1675
|
+
interpolated_df = interpolate_service.interpolate(
|
|
1676
|
+
tsdf=tsdf_input,
|
|
1677
|
+
ts_col=self.ts_col,
|
|
1678
|
+
partition_cols=self.partitionCols,
|
|
1679
|
+
target_cols=target_cols,
|
|
1680
|
+
freq=freq,
|
|
1681
|
+
func=func,
|
|
1682
|
+
method=method,
|
|
1683
|
+
show_interpolated=show_interpolated,
|
|
1684
|
+
perform_checks=perform_checks,
|
|
1685
|
+
)
|
|
1686
|
+
|
|
1687
|
+
return TSDF(
|
|
1688
|
+
interpolated_df, ts_col=self.ts_col, partition_cols=self.partitionCols
|
|
1689
|
+
)
|
|
1690
|
+
|
|
1691
|
+
|
|
1692
|
+
class Comparable(metaclass=ABCMeta):
|
|
1693
|
+
"""For typing functions generated by operator_dict"""
|
|
1694
|
+
|
|
1695
|
+
@abstractmethod
|
|
1696
|
+
def __ne__(self, other: Any) -> bool:
|
|
1697
|
+
pass
|
|
1698
|
+
|
|
1699
|
+
@abstractmethod
|
|
1700
|
+
def __lt__(self, other: Any) -> bool:
|
|
1701
|
+
pass
|
|
1702
|
+
|
|
1703
|
+
@abstractmethod
|
|
1704
|
+
def __le__(self, other: Any) -> bool:
|
|
1705
|
+
pass
|
|
1706
|
+
|
|
1707
|
+
@abstractmethod
|
|
1708
|
+
def __eq__(self, other: Any) -> bool:
|
|
1709
|
+
pass
|
|
1710
|
+
|
|
1711
|
+
@abstractmethod
|
|
1712
|
+
def __gt__(self, other: Any) -> bool:
|
|
1713
|
+
pass
|
|
1714
|
+
|
|
1715
|
+
@abstractmethod
|
|
1716
|
+
def __ge__(self, other: Any) -> bool:
|
|
1717
|
+
pass
|
|
1718
|
+
|
|
1719
|
+
|
|
1720
|
+
CT = TypeVar("CT", bound=Comparable)
|