spotforecast2 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spotforecast2/.DS_Store +0 -0
- spotforecast2/__init__.py +2 -0
- spotforecast2/data/__init__.py +0 -0
- spotforecast2/data/data.py +130 -0
- spotforecast2/data/fetch_data.py +209 -0
- spotforecast2/exceptions.py +681 -0
- spotforecast2/forecaster/.DS_Store +0 -0
- spotforecast2/forecaster/__init__.py +7 -0
- spotforecast2/forecaster/base.py +448 -0
- spotforecast2/forecaster/metrics.py +527 -0
- spotforecast2/forecaster/recursive/__init__.py +4 -0
- spotforecast2/forecaster/recursive/_forecaster_equivalent_date.py +1075 -0
- spotforecast2/forecaster/recursive/_forecaster_recursive.py +939 -0
- spotforecast2/forecaster/recursive/_warnings.py +15 -0
- spotforecast2/forecaster/utils.py +954 -0
- spotforecast2/model_selection/__init__.py +5 -0
- spotforecast2/model_selection/bayesian_search.py +453 -0
- spotforecast2/model_selection/grid_search.py +314 -0
- spotforecast2/model_selection/random_search.py +151 -0
- spotforecast2/model_selection/split_base.py +357 -0
- spotforecast2/model_selection/split_one_step.py +245 -0
- spotforecast2/model_selection/split_ts_cv.py +634 -0
- spotforecast2/model_selection/utils_common.py +718 -0
- spotforecast2/model_selection/utils_metrics.py +103 -0
- spotforecast2/model_selection/validation.py +685 -0
- spotforecast2/preprocessing/__init__.py +30 -0
- spotforecast2/preprocessing/_binner.py +378 -0
- spotforecast2/preprocessing/_common.py +123 -0
- spotforecast2/preprocessing/_differentiator.py +123 -0
- spotforecast2/preprocessing/_rolling.py +136 -0
- spotforecast2/preprocessing/curate_data.py +254 -0
- spotforecast2/preprocessing/imputation.py +92 -0
- spotforecast2/preprocessing/outlier.py +114 -0
- spotforecast2/preprocessing/split.py +139 -0
- spotforecast2/py.typed +0 -0
- spotforecast2/utils/__init__.py +43 -0
- spotforecast2/utils/convert_to_utc.py +44 -0
- spotforecast2/utils/data_transform.py +208 -0
- spotforecast2/utils/forecaster_config.py +344 -0
- spotforecast2/utils/generate_holiday.py +70 -0
- spotforecast2/utils/validation.py +569 -0
- spotforecast2/weather/__init__.py +0 -0
- spotforecast2/weather/weather_client.py +288 -0
- spotforecast2-0.0.1.dist-info/METADATA +47 -0
- spotforecast2-0.0.1.dist-info/RECORD +46 -0
- spotforecast2-0.0.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def split_abs_train_val_test(
    data: pd.DataFrame,
    end_train: pd.Timestamp,
    end_validation: pd.Timestamp,
    verbose: bool = False,
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Splits a time series DataFrame into training, validation, and test sets based on absolute timestamps.

    The training set contains all rows up to and including ``end_train``, the
    validation set the rows after ``end_train`` up to and including
    ``end_validation``, and the test set everything after ``end_validation``.
    The three sets are pairwise disjoint.

    Args:
        data (pd.DataFrame): The time series data with a DateTimeIndex.
        end_train (pd.Timestamp): The end date (inclusive) for the training set.
        end_validation (pd.Timestamp): The end date (inclusive) for the validation set.
        verbose (bool): Whether to print split information.

    Returns:
        tuple: A tuple containing:
            - data_train (pd.DataFrame): The training set.
            - data_val (pd.DataFrame): The validation set.
            - data_test (pd.DataFrame): The test set.

    Examples:
        >>> from spotforecast2.data.fetch_data import fetch_data
        >>> from spotforecast2.preprocessing.split import split_abs_train_val_test
        >>> data = fetch_data()
        >>> end_train = pd.Timestamp('2020-12-31 23:00:00')
        >>> end_validation = pd.Timestamp('2021-06-30 23:00:00')
        >>> data_train, data_val, data_test = split_abs_train_val_test(
        ...     data,
        ...     end_train=end_train,
        ...     end_validation=end_validation,
        ...     verbose=True
        ... )
    """
    data = data.copy()
    start_date = data.index.min()
    end_date = data.index.max()
    if verbose:
        print(f"Start date: {start_date}")
        print(f"End date: {end_date}")
    # Label-based `.loc[a:b]` slicing is inclusive on BOTH ends, so slicing
    # train with `[:end_train]` and val with `[end_train:end_validation]`
    # would duplicate the boundary rows in two splits (data leakage).
    # Boolean masks keep the splits disjoint.
    data_train = data.loc[data.index <= end_train].copy()
    data_val = data.loc[
        (data.index > end_train) & (data.index <= end_validation)
    ].copy()
    data_test = data.loc[data.index > end_validation].copy()

    if verbose:
        print(
            f"Train: {data_train.index.min()} --- {data_train.index.max()} (n={len(data_train)})"
        )
        print(
            f"Val: {data_val.index.min()} --- {data_val.index.max()} (n={len(data_val)})"
        )
        print(
            f"Test: {data_test.index.min()} --- {data_test.index.max()} (n={len(data_test)})"
        )

    return data_train, data_val, data_test
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def split_rel_train_val_test(
    data: pd.DataFrame,
    perc_train: float,
    perc_val: float,
    verbose: bool = False,
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Splits a time series DataFrame into training, validation, and test sets by percentages.

    The test percentage is computed as 1 - perc_train - perc_val.
    Sizes are rounded to ensure the splits sum to the full dataset size.

    Args:
        data (pd.DataFrame): The time series data with a DateTimeIndex.
        perc_train (float): Fraction of data used for training.
        perc_val (float): Fraction of data used for validation.
        verbose (bool): Whether to print additional information.

    Returns:
        tuple: A tuple containing:
            - data_train (pd.DataFrame): The training set.
            - data_val (pd.DataFrame): The validation set.
            - data_test (pd.DataFrame): The test set.

    Raises:
        ValueError: If data is empty, if a percentage is outside [0, 1], or
            if perc_train + perc_val exceeds 1.

    Examples:
        >>> from spotforecast2.data.fetch_data import fetch_data
        >>> from spotforecast2.preprocessing.split import split_rel_train_val_test
        >>> data = fetch_data()
        >>> data_train, data_val, data_test = split_rel_train_val_test(
        ...     data,
        ...     perc_train=0.7,
        ...     perc_val=0.2,
        ...     verbose=True
        ... )
    """
    data = data.copy()
    if data.shape[0] == 0:
        raise ValueError("Input data is empty.")
    if not (0 <= perc_train <= 1) or not (0 <= perc_val <= 1):
        raise ValueError("perc_train and perc_val must be between 0 and 1 (inclusive).")

    perc_test = 1 - perc_train - perc_val
    split_info = (
        f"Splitting data into train/val/test with percentages: "
        f"{perc_train:.4%} / {perc_val:.4%} / {perc_test:.4%}"
    )
    if verbose:
        print(split_info)
    # round() guards against float noise (e.g. 1 - 0.7 - 0.3 == -1.1e-16)
    # being mistaken for a genuinely negative test share.
    if round(perc_test, 10) < 0.0:
        if not verbose:
            # Surface the offending percentages before raising; avoid the
            # duplicate print when verbose already emitted them above.
            print(split_info)
        raise ValueError(
            "perc_train and perc_val must sum to 1 or less to leave room for a test set."
        )

    n_total = len(data)
    n_train = int(round(n_total * perc_train))
    n_val = int(round(n_total * perc_val))
    # Test set absorbs the rounding remainder so the splits always sum to n_total.
    n_test = n_total - n_train - n_val

    # Clamp against rounding overshoot (e.g. both fractions rounding up).
    if n_test < 0:
        n_test = 0
        n_val = n_total - n_train
    if n_val < 0:
        n_val = 0
        n_train = n_total

    end_train_idx = n_train
    end_val_idx = n_train + n_val

    data_train = data.iloc[:end_train_idx, :].copy()
    data_val = data.iloc[end_train_idx:end_val_idx, :].copy()
    data_test = data.iloc[end_val_idx:, :].copy()

    if verbose:
        print(f"Train size: {len(data_train)} ({len(data_train) / n_total:.2%})")
        print(f"Val size: {len(data_val)} ({len(data_val) / n_total:.2%})")
        print(f"Test size: {len(data_test)} ({len(data_test) / n_total:.2%})")

    return data_train, data_val, data_test
|
spotforecast2/py.typed
ADDED
|
File without changes
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Utility functions for spotforecast."""
|
|
2
|
+
|
|
3
|
+
from spotforecast2.utils.validation import (
|
|
4
|
+
check_y,
|
|
5
|
+
check_exog,
|
|
6
|
+
get_exog_dtypes,
|
|
7
|
+
check_interval,
|
|
8
|
+
MissingValuesWarning,
|
|
9
|
+
DataTypeWarning,
|
|
10
|
+
check_exog_dtypes,
|
|
11
|
+
check_predict_input,
|
|
12
|
+
)
|
|
13
|
+
from spotforecast2.utils.data_transform import (
|
|
14
|
+
input_to_frame,
|
|
15
|
+
expand_index,
|
|
16
|
+
transform_dataframe,
|
|
17
|
+
)
|
|
18
|
+
from spotforecast2.utils.forecaster_config import (
|
|
19
|
+
initialize_lags,
|
|
20
|
+
initialize_weights,
|
|
21
|
+
check_select_fit_kwargs,
|
|
22
|
+
)
|
|
23
|
+
from spotforecast2.utils.convert_to_utc import convert_to_utc
|
|
24
|
+
from spotforecast2.utils.generate_holiday import create_holiday_df
|
|
25
|
+
|
|
26
|
+
__all__ = [
|
|
27
|
+
"check_y",
|
|
28
|
+
"check_exog",
|
|
29
|
+
"get_exog_dtypes",
|
|
30
|
+
"check_interval",
|
|
31
|
+
"MissingValuesWarning",
|
|
32
|
+
"DataTypeWarning",
|
|
33
|
+
"input_to_frame",
|
|
34
|
+
"initialize_lags",
|
|
35
|
+
"expand_index",
|
|
36
|
+
"initialize_weights",
|
|
37
|
+
"check_select_fit_kwargs",
|
|
38
|
+
"check_exog_dtypes",
|
|
39
|
+
"check_predict_input",
|
|
40
|
+
"transform_dataframe",
|
|
41
|
+
"convert_to_utc",
|
|
42
|
+
"create_holiday_df",
|
|
43
|
+
]
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""Utility functions for timezone conversion."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
import pandas as pd
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def convert_to_utc(df: pd.DataFrame, timezone: Optional[str]) -> pd.DataFrame:
    """Convert DataFrame index timezone to UTC.

    The input frame is not modified; a copy with a UTC-converted
    DatetimeIndex is returned.

    Args:
        df: DataFrame with DatetimeIndex.
        timezone: Optional timezone string. Required if index has no timezone.

    Returns:
        DataFrame with UTC timezone index.

    Raises:
        ValueError: If index is not DatetimeIndex or has no timezone and
            timezone is None.

    Examples:
        >>> from spotforecast2.utils.convert_to_utc import convert_to_utc
        >>> df = pd.DataFrame({"value": [1, 2, 3]}, index=pd.to_datetime(["2022-01-01", "2022-01-02", "2022-01-03"]))
        >>> convert_to_utc(df, "Europe/Berlin")
                                   value
        2021-12-31 23:00:00+00:00      1
        2022-01-01 23:00:00+00:00      2
        2022-01-02 23:00:00+00:00      3
    """
    if not isinstance(df.index, pd.DatetimeIndex):
        raise ValueError(
            "No DatetimeIndex found. Please specify the time column via 'index_col'"
        )
    # Work on a copy: rebinding `df.index` on the argument would otherwise
    # mutate the caller's DataFrame as a side effect.
    df = df.copy()
    if df.index.tz is None:
        if timezone is not None:
            df.index = df.index.tz_localize(timezone)
        else:
            raise ValueError(
                "Index has no timezone information. Please provide a timezone."
            )

    df.index = df.index.tz_convert("UTC")

    return df
|
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Data transformation utilities for time series forecasting.
|
|
3
|
+
|
|
4
|
+
This module provides functions for normalizing and transforming data formats.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Union
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def input_to_frame(
    data: Union[pd.Series, pd.DataFrame], input_name: str
) -> pd.DataFrame:
    """
    Normalize input data to a pandas DataFrame.

    A DataFrame is returned unchanged. A Series is converted to a
    single-column DataFrame, keeping its own name when it has one and
    falling back to a default derived from ``input_name`` otherwise.

    Args:
        data: Input data as pandas Series or DataFrame.
        input_name: Kind of input, one of:
            - 'y': Target time series (default column name 'y')
            - 'last_window': Last window for prediction (default column name 'y')
            - 'exog': Exogenous variables (default column name 'exog')

    Returns:
        DataFrame version of the input data.

    Examples:
        >>> import pandas as pd
        >>> from spotforecast2.utils.data_transform import input_to_frame
        >>> input_to_frame(pd.Series([1, 2, 3], name="sales"), input_name="y").columns.tolist()
        ['sales']
        >>> input_to_frame(pd.Series([1, 2, 3]), input_name="y").columns.tolist()
        ['y']
        >>> input_to_frame(pd.Series([10, 20]), input_name="exog").columns.tolist()
        ['exog']
        >>> input_to_frame(pd.DataFrame({"temp": [20], "humidity": [50]}), input_name="exog").columns.tolist()
        ['temp', 'humidity']
    """
    default_names = {"y": "y", "last_window": "y", "exog": "exog"}

    if not isinstance(data, pd.Series):
        # Already a DataFrame: pass through untouched.
        return data

    column_name = data.name
    if column_name is None:
        column_name = default_names[input_name]
    return data.to_frame(name=column_name)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def expand_index(index: Union[pd.Index, None], steps: int) -> pd.Index:
    """
    Build the index of the `steps` periods following the end of `index`.

    Used to label forecast horizons: a DatetimeIndex is extended using its
    own frequency, a RangeIndex continues counting from its last value, and
    a missing index yields a fresh RangeIndex starting at 0.

    Args:
        index: Original pandas Index (DatetimeIndex or RangeIndex). If None,
            creates a RangeIndex starting from 0.
        steps: Number of future steps to generate.

    Returns:
        New pandas Index with `steps` future periods.

    Raises:
        TypeError: If steps is not an integer, or if index is neither
            DatetimeIndex nor RangeIndex.

    Examples:
        >>> import pandas as pd
        >>> from spotforecast2.utils.data_transform import expand_index
        >>> expand_index(pd.date_range("2023-01-01", periods=5, freq="D"), 3)
        DatetimeIndex(['2023-01-06', '2023-01-07', '2023-01-08'], dtype='datetime64[ns]', freq='D')
        >>> expand_index(pd.RangeIndex(start=0, stop=10), 5)
        RangeIndex(start=10, stop=15, step=1)
        >>> expand_index(None, 3)
        RangeIndex(start=0, stop=3, step=1)
    """
    if not isinstance(steps, (int, np.integer)):
        raise TypeError(f"`steps` must be an integer. Got {type(steps)}.")

    # Normalize numpy integers to plain Python int for the index constructors.
    steps = int(steps)

    if not isinstance(index, pd.Index):
        # No usable index supplied: start counting from zero.
        return pd.RangeIndex(start=0, stop=steps)

    if isinstance(index, pd.DatetimeIndex):
        # Continue the series one frequency step past its last timestamp.
        return pd.date_range(
            start=index[-1] + index.freq, periods=steps, freq=index.freq
        )

    if isinstance(index, pd.RangeIndex):
        next_value = index[-1] + 1
        return pd.RangeIndex(start=next_value, stop=next_value + steps)

    raise TypeError(
        "Argument `index` must be a pandas DatetimeIndex or RangeIndex."
    )
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def transform_dataframe(
    df: pd.DataFrame,
    transformer: object,
    fit: bool = False,
    inverse_transform: bool = False,
) -> pd.DataFrame:
    """
    Apply a scikit-learn-like transformer to a pandas DataFrame.

    The transformer must implement fit, transform, fit_transform and
    inverse_transform. ColumnTransformers are rejected for inverse
    transformation since they lack an inverse_transform method.

    Args:
        df: DataFrame to be transformed.
        transformer: Scikit-learn alike transformer, preprocessor, or ColumnTransformer.
            Must implement fit, transform, fit_transform and inverse_transform.
        fit: Train the transformer before applying it. Defaults to False.
        inverse_transform: Transform back the data to the original representation.
            This is not available when using transformers of class
            scikit-learn ColumnTransformers. Defaults to False.

    Returns:
        Transformed DataFrame.

    Raises:
        TypeError: If df is not a pandas DataFrame.
        ValueError: If inverse_transform is requested for ColumnTransformer.
    """
    if not isinstance(df, pd.DataFrame):
        raise TypeError(f"`df` argument must be a pandas DataFrame. Got {type(df)}")

    if transformer is None:
        # No-op: hand the frame back untouched.
        return df

    # Duck-typed ColumnTransformer detection so sklearn need not be imported.
    is_column_transformer = (
        type(transformer).__name__ == "ColumnTransformer"
        or hasattr(transformer, "transformers")
    )
    if inverse_transform and is_column_transformer:
        raise ValueError(
            "`inverse_transform` is not available when using ColumnTransformers."
        )

    if inverse_transform:
        values = transformer.inverse_transform(df)
    elif fit:
        values = transformer.fit_transform(df)
    else:
        values = transformer.transform(df)

    if hasattr(values, "toarray"):
        # Sparse matrix output: densify before building the DataFrame.
        values = values.toarray()

    if isinstance(values, pd.DataFrame):
        return values
    # Raw ndarray output: restore the original index and column labels.
    return pd.DataFrame(values, index=df.index, columns=df.columns)
|