spotforecast2 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spotforecast2/.DS_Store +0 -0
- spotforecast2/__init__.py +2 -0
- spotforecast2/data/__init__.py +0 -0
- spotforecast2/data/data.py +130 -0
- spotforecast2/data/fetch_data.py +209 -0
- spotforecast2/exceptions.py +681 -0
- spotforecast2/forecaster/.DS_Store +0 -0
- spotforecast2/forecaster/__init__.py +7 -0
- spotforecast2/forecaster/base.py +448 -0
- spotforecast2/forecaster/metrics.py +527 -0
- spotforecast2/forecaster/recursive/__init__.py +4 -0
- spotforecast2/forecaster/recursive/_forecaster_equivalent_date.py +1075 -0
- spotforecast2/forecaster/recursive/_forecaster_recursive.py +939 -0
- spotforecast2/forecaster/recursive/_warnings.py +15 -0
- spotforecast2/forecaster/utils.py +954 -0
- spotforecast2/model_selection/__init__.py +5 -0
- spotforecast2/model_selection/bayesian_search.py +453 -0
- spotforecast2/model_selection/grid_search.py +314 -0
- spotforecast2/model_selection/random_search.py +151 -0
- spotforecast2/model_selection/split_base.py +357 -0
- spotforecast2/model_selection/split_one_step.py +245 -0
- spotforecast2/model_selection/split_ts_cv.py +634 -0
- spotforecast2/model_selection/utils_common.py +718 -0
- spotforecast2/model_selection/utils_metrics.py +103 -0
- spotforecast2/model_selection/validation.py +685 -0
- spotforecast2/preprocessing/__init__.py +30 -0
- spotforecast2/preprocessing/_binner.py +378 -0
- spotforecast2/preprocessing/_common.py +123 -0
- spotforecast2/preprocessing/_differentiator.py +123 -0
- spotforecast2/preprocessing/_rolling.py +136 -0
- spotforecast2/preprocessing/curate_data.py +254 -0
- spotforecast2/preprocessing/imputation.py +92 -0
- spotforecast2/preprocessing/outlier.py +114 -0
- spotforecast2/preprocessing/split.py +139 -0
- spotforecast2/py.typed +0 -0
- spotforecast2/utils/__init__.py +43 -0
- spotforecast2/utils/convert_to_utc.py +44 -0
- spotforecast2/utils/data_transform.py +208 -0
- spotforecast2/utils/forecaster_config.py +344 -0
- spotforecast2/utils/generate_holiday.py +70 -0
- spotforecast2/utils/validation.py +569 -0
- spotforecast2/weather/__init__.py +0 -0
- spotforecast2/weather/weather_client.py +288 -0
- spotforecast2-0.0.1.dist-info/METADATA +47 -0
- spotforecast2-0.0.1.dist-info/RECORD +46 -0
- spotforecast2-0.0.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Forecaster configuration utilities.
|
|
3
|
+
|
|
4
|
+
This module provides functions for initializing and validating forecaster
|
|
5
|
+
configuration parameters like lags and weights.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Any, Union, List, Tuple, Optional
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def initialize_lags(
    forecaster_name: str, lags: Any
) -> Tuple[Optional[np.ndarray], Optional[List[str]], Optional[int]]:
    """
    Validate and normalize lag specification for forecasting.

    This function converts various lag specifications (int, list, tuple, range, ndarray)
    into a standardized format: sorted, de-duplicated numpy array, lag names, and
    maximum lag value.

    Args:
        forecaster_name: Name of the forecaster class, used to tailor error messages.
        lags: Lag specification in one of several formats:
            - int (Python or numpy integer): Creates lags from 1 to lags
              (e.g., 5 → [1,2,3,4,5])
            - list/tuple/range: Converted to numpy array
            - numpy.ndarray: Validated and used directly
            - None: Returns (None, None, None)
            Empty list/tuple/range/array also returns (None, None, None).
            Duplicate lag values are collapsed into a single lag.

    Returns:
        Tuple containing:
        - lags: Sorted numpy array of unique lag values (or None)
        - lags_names: List of lag names like ['lag_1', 'lag_2', ...] (or None)
        - max_lag: Maximum lag value as a Python int (or None)

    Raises:
        ValueError: If any lag is < 1, or the array is not 1-dimensional.
        TypeError: If lags is not one of the accepted types (the message depends
            on the forecaster), or the array contains non-integer values.

    Examples:
        >>> import numpy as np
        >>> from spotforecast2.utils.forecaster_config import initialize_lags
        >>>
        >>> # Integer input
        >>> lags, names, max_lag = initialize_lags("ForecasterRecursive", 3)
        >>> lags
        array([1, 2, 3])
        >>> names
        ['lag_1', 'lag_2', 'lag_3']
        >>> max_lag
        3
        >>>
        >>> # List input (unsorted, with a duplicate)
        >>> lags, names, max_lag = initialize_lags("ForecasterRecursive", [5, 1, 3, 3])
        >>> lags
        array([1, 3, 5])
        >>> names
        ['lag_1', 'lag_3', 'lag_5']
        >>>
        >>> # Range input
        >>> lags, names, max_lag = initialize_lags("ForecasterRecursive", range(1, 4))
        >>> lags
        array([1, 2, 3])
        >>>
        >>> # None input
        >>> lags, names, max_lag = initialize_lags("ForecasterRecursive", None)
        >>> lags is None
        True
        >>>
        >>> # Invalid: lags < 1
        >>> try:
        ...     initialize_lags("ForecasterRecursive", 0)
        ... except ValueError as e:
        ...     print(e)
        Minimum value of lags allowed is 1.
        >>>
        >>> # Invalid: negative lags
        >>> try:
        ...     initialize_lags("ForecasterRecursive", [1, -2, 3])
        ... except ValueError as e:
        ...     print(e)
        Minimum value of lags allowed is 1.
    """
    if lags is None:
        return None, None, None

    # numpy integer scalars (e.g. values drawn from an ndarray or a parameter
    # grid) are not instances of ``int``; accept them like Python ints.
    if isinstance(lags, (int, np.integer)):
        if lags < 1:
            raise ValueError("Minimum value of lags allowed is 1.")
        lags = np.arange(1, int(lags) + 1)
    elif isinstance(lags, (list, tuple, range)):
        lags = np.array(lags)

    if not isinstance(lags, np.ndarray):
        if forecaster_name == "ForecasterDirectMultiVariate":
            raise TypeError(
                f"`lags` argument must be a dict, int, 1d numpy ndarray, range, "
                f"tuple or list. Got {type(lags)}."
            )
        raise TypeError(
            f"`lags` argument must be an int, 1d numpy ndarray, range, "
            f"tuple or list. Got {type(lags)}."
        )

    # An empty specification means "no lags" rather than an error.
    if lags.size == 0:
        return None, None, None
    if lags.ndim != 1:
        raise ValueError("`lags` must be a 1-dimensional array.")
    if not np.issubdtype(lags.dtype, np.integer):
        raise TypeError("All values in `lags` must be integers.")
    if np.any(lags < 1):
        raise ValueError("Minimum value of lags allowed is 1.")

    # np.unique sorts and removes duplicates, so each lag yields exactly one
    # column name (duplicates would create identical, clashing lag columns).
    lags = np.unique(lags)
    lags_names = [f"lag_{i}" for i in lags]
    max_lag = int(lags[-1])

    return lags, lags_names, max_lag
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def initialize_weights(
    forecaster_name: str, estimator: Any, weight_func: Any, series_weights: Any
) -> Tuple[Any, Optional[Union[str, dict]], Any]:
    """
    Validate and initialize weight function configuration for forecasting.

    This function validates weight_func and series_weights, extracts source code
    from weight functions for serialization, and checks if the estimator supports
    sample weights in its fit method.

    Args:
        forecaster_name: Name of the forecaster class. Only
            "ForecasterRecursiveMultiSeries" accepts a dict of weight functions.
        estimator: Scikit-learn compatible estimator or pipeline.
        weight_func: Weight function specification:
            - Callable: Single weight function
            - dict: Dictionary of weight functions (for MultiSeries forecasters);
              every value must be callable
            - None: No weighting
        series_weights: Dictionary of series-level weights (for MultiSeries forecasters).
            - dict: Maps series names to weight values
            - None: No series weighting

    Returns:
        Tuple containing:
        - weight_func: The validated weight function(s), or None if none was given
          or the estimator's ``fit`` does not accept ``sample_weight``.
        - source_code_weight_func: Source code of the weight function(s) for
          serialization (a dict keyed like ``weight_func`` when it is a dict), or
          None. When the source cannot be retrieved (builtins, ``functools.partial``
          objects, callable instances, interactively defined functions) a
          ``"<source unavailable: ...>"`` placeholder is stored instead.
        - series_weights: The validated series weights, or None if none were given
          or the estimator's ``fit`` does not accept ``sample_weight``.

    Raises:
        TypeError: If weight_func is not Callable/dict (depending on forecaster type),
            if a dict of weight functions contains non-callable values,
            or if series_weights is not a dict.

    Warnings:
        IgnoredArgumentWarning: If estimator doesn't support sample_weight.

    Examples:
        >>> import numpy as np
        >>> from sklearn.linear_model import Ridge
        >>> from spotforecast2.utils.forecaster_config import initialize_weights
        >>>
        >>> # Simple weight function
        >>> def custom_weights(index):
        ...     return np.ones(len(index))
        >>>
        >>> estimator = Ridge()
        >>> wf, source, sw = initialize_weights(
        ...     "ForecasterRecursive", estimator, custom_weights, None
        ... )
        >>> wf is not None
        True
        >>> isinstance(source, str)
        True
        >>>
        >>> # No weight function
        >>> wf, source, sw = initialize_weights(
        ...     "ForecasterRecursive", estimator, None, None
        ... )
        >>> wf is None
        True
        >>> source is None
        True
        >>>
        >>> # Invalid type for non-MultiSeries forecaster
        >>> try:
        ...     initialize_weights("ForecasterRecursive", estimator, "invalid", None)
        ... except TypeError as e:
        ...     print(e)
        Argument `weight_func` must be a Callable. Got <class 'str'>.
    """
    import inspect
    import warnings
    from collections.abc import Callable

    # Import IgnoredArgumentWarning if available, otherwise define locally
    try:
        from spotforecast2.exceptions import IgnoredArgumentWarning
    except ImportError:

        class IgnoredArgumentWarning(UserWarning):
            """Warning for ignored arguments."""

            pass

    def _source_of(func: Any) -> str:
        # inspect.getsource raises OSError when the source file cannot be read
        # and TypeError for callables that are not plain functions/classes
        # (builtins, functools.partial, instances defining __call__). Both are
        # valid weight functions, so fall back to a placeholder instead of failing.
        try:
            return inspect.getsource(func)
        except (OSError, TypeError):
            return f"<source unavailable: {func!r}>"

    source_code_weight_func = None

    if weight_func is not None:
        if forecaster_name in ["ForecasterRecursiveMultiSeries"]:
            if not isinstance(weight_func, (Callable, dict)):
                raise TypeError(
                    f"Argument `weight_func` must be a Callable or a dict of "
                    f"Callables. Got {type(weight_func)}."
                )
        elif not isinstance(weight_func, Callable):
            raise TypeError(
                f"Argument `weight_func` must be a Callable. Got {type(weight_func)}."
            )

        if isinstance(weight_func, dict):
            non_callable_keys = [
                key for key, func in weight_func.items() if not isinstance(func, Callable)
            ]
            if non_callable_keys:
                raise TypeError(
                    f"All values in `weight_func` must be Callables. "
                    f"Got non-callable values for keys {non_callable_keys}."
                )
            source_code_weight_func = {
                key: _source_of(func) for key, func in weight_func.items()
            }
        else:
            source_code_weight_func = _source_of(weight_func)

        if "sample_weight" not in inspect.signature(estimator.fit).parameters:
            warnings.warn(
                f"Argument `weight_func` is ignored since estimator {estimator} "
                f"does not accept `sample_weight` in its `fit` method.",
                IgnoredArgumentWarning,
            )
            weight_func = None
            source_code_weight_func = None

    if series_weights is not None:
        if not isinstance(series_weights, dict):
            raise TypeError(
                f"Argument `series_weights` must be a dict of floats or ints. "
                f"Got {type(series_weights)}."
            )
        if "sample_weight" not in inspect.signature(estimator.fit).parameters:
            warnings.warn(
                f"Argument `series_weights` is ignored since estimator {estimator} "
                f"does not accept `sample_weight` in its `fit` method.",
                IgnoredArgumentWarning,
            )
            series_weights = None

    return weight_func, source_code_weight_func, series_weights
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def check_select_fit_kwargs(estimator: Any, fit_kwargs: Optional[dict] = None) -> dict:
    """
    Check if `fit_kwargs` is a dict and select only keys used by estimator's `fit`.

    This function validates that fit_kwargs is a dictionary, warns about unused arguments,
    drops 'sample_weight' (which should be handled via weight_func), and returns
    a new dictionary containing only the arguments accepted by the estimator's fit
    method. The caller's ``fit_kwargs`` dict is never modified.

    Args:
        estimator: Scikit-learn compatible estimator.
        fit_kwargs: Dictionary of arguments to pass to the estimator's fit method.
            None is treated as an empty dict.

    Returns:
        New dictionary with only the arguments accepted by the estimator's fit
        method, always excluding 'sample_weight'.

    Raises:
        TypeError: If fit_kwargs is not a dict (or None).

    Warnings:
        IgnoredArgumentWarning: If fit_kwargs contains keys not used by fit method,
            or if 'sample_weight' is present (it gets dropped).

    Examples:
        >>> from sklearn.linear_model import Ridge
        >>> from spotforecast2.utils.forecaster_config import check_select_fit_kwargs
        >>>
        >>> estimator = Ridge()
        >>> # `sample_weight` is accepted by Ridge.fit but is always dropped
        >>> # (weights must be supplied via `weight_func` in the forecaster);
        >>> # `invalid_arg` is not a parameter of Ridge.fit and is dropped too.
        >>> kwargs = {"sample_weight": [1, 1], "invalid_arg": 10}
        >>> filtered = check_select_fit_kwargs(estimator, kwargs)
        >>> filtered
        {}
        >>> kwargs  # the caller's dict is left untouched
        {'sample_weight': [1, 1], 'invalid_arg': 10}
    """
    import inspect
    import warnings

    # Import IgnoredArgumentWarning if available, otherwise define locally
    try:
        from spotforecast2.exceptions import IgnoredArgumentWarning
    except ImportError:

        class IgnoredArgumentWarning(UserWarning):
            """Warning for ignored arguments."""

            pass

    if fit_kwargs is None:
        fit_kwargs = {}
    elif not isinstance(fit_kwargs, dict):
        raise TypeError(
            f"Argument `fit_kwargs` must be a dict. Got {type(fit_kwargs)}."
        )

    # Get parameters accepted by estimator.fit
    fit_params = inspect.signature(estimator.fit).parameters

    # Identify unused keys
    non_used_keys = [k for k in fit_kwargs if k not in fit_params]
    if non_used_keys:
        warnings.warn(
            f"Argument/s {non_used_keys} ignored since they are not used by the "
            f"estimator's `fit` method.",
            IgnoredArgumentWarning,
        )

    if "sample_weight" in fit_kwargs:
        warnings.warn(
            "The `sample_weight` argument is ignored. Use `weight_func` to pass "
            "a function that defines the individual weights for each sample "
            "based on its index.",
            IgnoredArgumentWarning,
        )

    # Build a fresh dict rather than deleting keys from the input: callers may
    # reuse their `fit_kwargs` dict, so it must not be mutated here.
    return {
        k: v
        for k, v in fit_kwargs.items()
        if k in fit_params and k != "sample_weight"
    }
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""Utilities for generating holiday dataframe as covariate."""
|
|
2
|
+
|
|
3
|
+
from typing import Union
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import holidays
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def create_holiday_df(
    start: Union[str, pd.Timestamp],
    end: Union[str, pd.Timestamp],
    tz: str = "UTC",
    freq: str = "h",
    country_code: str = "DE",
    state: Union[str, None] = "NW",
) -> pd.DataFrame:
    """Create a DataFrame with datetime index and a binary holiday indicator column.

    Each timestamp of the ``freq``-spaced index covering ``[start, end]`` is
    flagged according to its own calendar date in the index's timezone. Any
    start time of day and any sub-daily frequency therefore gets the flag of the
    day each timestamp falls on.

    Args:
        start: Start date/datetime (string, ``datetime`` or ``pd.Timestamp``).
        end: End date/datetime (string, ``datetime`` or ``pd.Timestamp``).
        tz: Timezone to use if neither ``start`` nor ``end`` carries one. A
            timezone carried by ``start`` (else ``end``) takes precedence.
            Naive endpoints are localized to the effective timezone, and
            timezone-aware ones are converted to it.
        freq: Frequency of the resulting DataFrame.
        country_code: Country code for holidays (e.g. "DE", "US").
        state: State/subdivision code for holidays (e.g. "NW", "CA"); passed to
            ``holidays.country_holidays`` as ``subdiv``.

    Returns:
        pd.DataFrame: DataFrame with index covering [start, end] at `freq`,
        and a 'holiday' column (1 if holiday, 0 otherwise).

    Examples:
        >>> df = create_holiday_df("2023-12-24", "2023-12-26", freq="D")
        >>> df["holiday"].tolist()
        [0, 1, 1]
    """
    start_ts = pd.Timestamp(start)
    end_ts = pd.Timestamp(end)

    # A timezone carried by the endpoints wins over the `tz` default, so
    # timezone-aware inputs keep their own timezone.
    if start_ts.tz is not None:
        effective_tz = start_ts.tz
    elif end_ts.tz is not None:
        effective_tz = end_ts.tz
    else:
        effective_tz = tz

    def _to_effective_tz(ts: pd.Timestamp) -> pd.Timestamp:
        # pd.date_range rejects endpoints with mixed/mismatched timezones, so
        # bring both ends into the same timezone first.
        if ts.tz is None:
            return ts.tz_localize(effective_tz)
        return ts.tz_convert(effective_tz)

    full_index = pd.date_range(
        start=_to_effective_tz(start_ts), end=_to_effective_tz(end_ts), freq=freq
    )

    # Get holidays for the country/state
    country_holidays = holidays.country_holidays(country_code, subdiv=state)

    # Flag each timestamp by its local calendar date. Look up each distinct
    # date only once, then map the result back onto every timestamp.
    local_dates = full_index.date
    holiday_dates = {d for d in set(local_dates) if d in country_holidays}
    flags = [1 if d in holiday_dates else 0 for d in local_dates]

    return pd.DataFrame({"holiday": flags}, index=full_index).astype(int)
|