oracle-ads 2.10.0__py3-none-any.whl → 2.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +12 -0
- ads/aqua/base.py +324 -0
- ads/aqua/cli.py +19 -0
- ads/aqua/config/deployment_config_defaults.json +9 -0
- ads/aqua/config/resource_limit_names.json +7 -0
- ads/aqua/constants.py +45 -0
- ads/aqua/data.py +40 -0
- ads/aqua/decorator.py +101 -0
- ads/aqua/deployment.py +643 -0
- ads/aqua/dummy_data/icon.txt +1 -0
- ads/aqua/dummy_data/oci_model_deployments.json +56 -0
- ads/aqua/dummy_data/oci_models.json +1 -0
- ads/aqua/dummy_data/readme.md +26 -0
- ads/aqua/evaluation.py +1751 -0
- ads/aqua/exception.py +82 -0
- ads/aqua/extension/__init__.py +40 -0
- ads/aqua/extension/base_handler.py +138 -0
- ads/aqua/extension/common_handler.py +21 -0
- ads/aqua/extension/deployment_handler.py +202 -0
- ads/aqua/extension/evaluation_handler.py +135 -0
- ads/aqua/extension/finetune_handler.py +66 -0
- ads/aqua/extension/model_handler.py +59 -0
- ads/aqua/extension/ui_handler.py +201 -0
- ads/aqua/extension/utils.py +23 -0
- ads/aqua/finetune.py +579 -0
- ads/aqua/job.py +29 -0
- ads/aqua/model.py +819 -0
- ads/aqua/training/__init__.py +4 -0
- ads/aqua/training/exceptions.py +459 -0
- ads/aqua/ui.py +453 -0
- ads/aqua/utils.py +715 -0
- ads/cli.py +37 -6
- ads/common/auth.py +7 -0
- ads/common/decorator/__init__.py +7 -3
- ads/common/decorator/require_nonempty_arg.py +65 -0
- ads/common/object_storage_details.py +166 -7
- ads/common/oci_client.py +18 -1
- ads/common/oci_logging.py +2 -2
- ads/common/oci_mixin.py +4 -5
- ads/common/serializer.py +34 -5
- ads/common/utils.py +75 -10
- ads/config.py +40 -1
- ads/dataset/correlation_plot.py +10 -12
- ads/jobs/ads_job.py +43 -25
- ads/jobs/builders/infrastructure/base.py +4 -2
- ads/jobs/builders/infrastructure/dsc_job.py +49 -39
- ads/jobs/builders/runtimes/base.py +71 -1
- ads/jobs/builders/runtimes/container_runtime.py +4 -4
- ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
- ads/jobs/templates/driver_pytorch.py +27 -10
- ads/model/artifact_downloader.py +84 -14
- ads/model/artifact_uploader.py +25 -23
- ads/model/datascience_model.py +388 -38
- ads/model/deployment/model_deployment.py +10 -2
- ads/model/generic_model.py +8 -0
- ads/model/model_file_description_schema.json +68 -0
- ads/model/model_metadata.py +1 -1
- ads/model/service/oci_datascience_model.py +34 -5
- ads/opctl/config/merger.py +2 -2
- ads/opctl/operator/__init__.py +3 -1
- ads/opctl/operator/cli.py +7 -1
- ads/opctl/operator/cmd.py +3 -3
- ads/opctl/operator/common/errors.py +2 -1
- ads/opctl/operator/common/operator_config.py +22 -3
- ads/opctl/operator/common/utils.py +16 -0
- ads/opctl/operator/lowcode/anomaly/MLoperator +15 -0
- ads/opctl/operator/lowcode/anomaly/README.md +209 -0
- ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/__main__.py +104 -0
- ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
- ads/opctl/operator/lowcode/anomaly/const.py +88 -0
- ads/opctl/operator/lowcode/anomaly/environment.yaml +12 -0
- ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +147 -0
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +89 -0
- ads/opctl/operator/lowcode/anomaly/model/autots.py +103 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +354 -0
- ads/opctl/operator/lowcode/anomaly/model/factory.py +67 -0
- ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
- ads/opctl/operator/lowcode/anomaly/operator_config.py +105 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +359 -0
- ads/opctl/operator/lowcode/anomaly/utils.py +81 -0
- ads/opctl/operator/lowcode/common/__init__.py +5 -0
- ads/opctl/operator/lowcode/common/const.py +10 -0
- ads/opctl/operator/lowcode/common/data.py +96 -0
- ads/opctl/operator/lowcode/common/errors.py +41 -0
- ads/opctl/operator/lowcode/common/transformations.py +191 -0
- ads/opctl/operator/lowcode/common/utils.py +250 -0
- ads/opctl/operator/lowcode/forecast/README.md +3 -2
- ads/opctl/operator/lowcode/forecast/__main__.py +18 -2
- ads/opctl/operator/lowcode/forecast/cmd.py +8 -7
- ads/opctl/operator/lowcode/forecast/const.py +17 -1
- ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
- ads/opctl/operator/lowcode/forecast/model/arima.py +106 -117
- ads/opctl/operator/lowcode/forecast/model/automlx.py +204 -180
- ads/opctl/operator/lowcode/forecast/model/autots.py +144 -253
- ads/opctl/operator/lowcode/forecast/model/base_model.py +326 -259
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +325 -176
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +293 -237
- ads/opctl/operator/lowcode/forecast/model/prophet.py +191 -208
- ads/opctl/operator/lowcode/forecast/operator_config.py +24 -33
- ads/opctl/operator/lowcode/forecast/schema.yaml +116 -29
- ads/opctl/operator/lowcode/forecast/utils.py +186 -356
- ads/opctl/operator/lowcode/pii/model/guardrails.py +18 -15
- ads/opctl/operator/lowcode/pii/model/report.py +7 -7
- ads/opctl/operator/lowcode/pii/operator_config.py +1 -8
- ads/opctl/operator/lowcode/pii/utils.py +0 -82
- ads/opctl/operator/runtime/runtime.py +3 -2
- ads/telemetry/base.py +62 -0
- ads/telemetry/client.py +105 -0
- ads/telemetry/telemetry.py +6 -3
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/METADATA +44 -7
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/RECORD +116 -59
- ads/opctl/operator/lowcode/forecast/model/transformations.py +0 -125
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/WHEEL +0 -0
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/entry_points.txt +0 -0
@@ -4,14 +4,121 @@
|
|
4
4
|
# Copyright (c) 2023 Oracle and/or its affiliates.
|
5
5
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
6
|
|
7
|
+
import time
|
7
8
|
import pandas as pd
|
9
|
+
from pandas.api.types import is_datetime64_any_dtype, is_string_dtype, is_numeric_dtype
|
10
|
+
|
8
11
|
from ..operator_config import ForecastOperatorConfig
|
9
|
-
from .. import utils
|
10
|
-
from .transformations import Transformations
|
11
12
|
from ads.opctl import logger
|
12
|
-
import pandas as pd
|
13
13
|
from ..const import ForecastOutputColumns, PROPHET_INTERNAL_DATE_COL
|
14
|
-
from
|
14
|
+
from ads.common.object_storage_details import ObjectStorageDetails
|
15
|
+
from ads.opctl.operator.lowcode.common.utils import (
|
16
|
+
get_frequency_in_seconds,
|
17
|
+
get_frequency_of_datetime,
|
18
|
+
)
|
19
|
+
from ads.opctl.operator.lowcode.common.data import AbstractData
|
20
|
+
from ads.opctl.operator.lowcode.forecast.utils import (
|
21
|
+
default_signer,
|
22
|
+
)
|
23
|
+
from ads.opctl.operator.lowcode.common.errors import (
|
24
|
+
InputDataError,
|
25
|
+
InvalidParameterError,
|
26
|
+
PermissionsError,
|
27
|
+
DataMismatchError,
|
28
|
+
)
|
29
|
+
from ..const import SupportedModels
|
30
|
+
from abc import ABC, abstractmethod
|
31
|
+
|
32
|
+
|
33
|
+
class HistoricalData(AbstractData):
|
34
|
+
def __init__(self, spec: dict):
|
35
|
+
super().__init__(spec=spec, name="historical_data")
|
36
|
+
|
37
|
+
def _ingest_data(self, spec):
|
38
|
+
try:
|
39
|
+
self.freq = get_frequency_of_datetime(self.data.index.get_level_values(0))
|
40
|
+
except TypeError as e:
|
41
|
+
logger.warn(
|
42
|
+
f"Error determining frequency: {e.args}. Setting Frequency to None"
|
43
|
+
)
|
44
|
+
logger.debug(f"Full traceback: {e}")
|
45
|
+
self.freq = None
|
46
|
+
self._verify_dt_col(spec)
|
47
|
+
super()._ingest_data(spec)
|
48
|
+
|
49
|
+
def _verify_dt_col(self, spec):
|
50
|
+
# Check frequency is compatible with model type
|
51
|
+
self.freq_in_secs = get_frequency_in_seconds(
|
52
|
+
self.data.index.get_level_values(0)
|
53
|
+
)
|
54
|
+
if spec.model == SupportedModels.AutoMLX:
|
55
|
+
if abs(self.freq_in_secs) < 3600:
|
56
|
+
message = (
|
57
|
+
"{} requires data with a frequency of at least one hour. Please try using a different model,"
|
58
|
+
" or select the 'auto' option.".format(SupportedModels.AutoMLX)
|
59
|
+
)
|
60
|
+
raise InvalidParameterError(message)
|
61
|
+
|
62
|
+
|
63
|
+
class AdditionalData(AbstractData):
|
64
|
+
def __init__(self, spec, historical_data):
|
65
|
+
if spec.additional_data is not None:
|
66
|
+
super().__init__(spec=spec, name="additional_data")
|
67
|
+
add_dates = self.data.index.get_level_values(0).unique().tolist()
|
68
|
+
add_dates.sort()
|
69
|
+
if historical_data.get_max_time() > add_dates[-spec.horizon]:
|
70
|
+
raise DataMismatchError(
|
71
|
+
f"The Historical Data ends on {historical_data.get_max_time()}. The additional data horizon starts on {add_dates[-spec.horizon]}. The horizon should have exactly {spec.horizon} dates after the Hisotrical at a frequency of {historical_data.freq}"
|
72
|
+
)
|
73
|
+
elif historical_data.get_max_time() != add_dates[-(spec.horizon + 1)]:
|
74
|
+
raise DataMismatchError(
|
75
|
+
f"The Additional Data must be present for all historical data and the entire horizon. The Historical Data ends on {historical_data.get_max_time()}. The additonal data horizon starts after {add_dates[-(spec.horizon+1)]}. These should be the same date."
|
76
|
+
)
|
77
|
+
else:
|
78
|
+
self.name = "additional_data"
|
79
|
+
self.data = None
|
80
|
+
self._data_dict = dict()
|
81
|
+
self.create_horizon(spec, historical_data)
|
82
|
+
|
83
|
+
def create_horizon(self, spec, historical_data):
|
84
|
+
logger.debug(f"No additional data provided. Constructing horizon.")
|
85
|
+
future_dates = pd.Series(
|
86
|
+
pd.date_range(
|
87
|
+
start=historical_data.get_max_time(),
|
88
|
+
periods=spec.horizon + 1,
|
89
|
+
freq=historical_data.freq,
|
90
|
+
),
|
91
|
+
name=spec.datetime_column.name,
|
92
|
+
)
|
93
|
+
add_dfs = []
|
94
|
+
for s_id in historical_data.list_series_ids():
|
95
|
+
df_i = historical_data.get_data_for_series(s_id)[spec.datetime_column.name]
|
96
|
+
df_i = pd.DataFrame(pd.concat([df_i, future_dates[1:]]))
|
97
|
+
df_i[ForecastOutputColumns.SERIES] = s_id
|
98
|
+
df_i = df_i.set_index(
|
99
|
+
[spec.datetime_column.name, ForecastOutputColumns.SERIES]
|
100
|
+
)
|
101
|
+
add_dfs.append(df_i)
|
102
|
+
data = pd.concat(add_dfs, axis=1)
|
103
|
+
self.data = data.sort_values(
|
104
|
+
[spec.datetime_column.name, ForecastOutputColumns.SERIES], ascending=True
|
105
|
+
)
|
106
|
+
self.additional_regressors = []
|
107
|
+
|
108
|
+
def _ingest_data(self, spec):
|
109
|
+
self.additional_regressors = list(self.data.columns)
|
110
|
+
if not self.additional_regressors:
|
111
|
+
logger.warn(
|
112
|
+
f"No additional variables found in the additional_data. Only columns found: {self.data.columns}. Skipping for now."
|
113
|
+
)
|
114
|
+
# Check that datetime column matches historical datetime column
|
115
|
+
|
116
|
+
|
117
|
+
class TestData(AbstractData):
|
118
|
+
def __init__(self, spec):
|
119
|
+
super().__init__(spec=spec, name="test_data")
|
120
|
+
self.dt_column_name = spec.datetime_column.name
|
121
|
+
self.target_name = spec.target_column
|
15
122
|
|
16
123
|
|
17
124
|
class ForecastDatasets:
|
@@ -23,223 +130,259 @@ class ForecastDatasets:
|
|
23
130
|
config: ForecastOperatorConfig
|
24
131
|
The forecast operator configuration.
|
25
132
|
"""
|
26
|
-
self.
|
27
|
-
self.
|
28
|
-
|
29
|
-
self.
|
30
|
-
self.
|
31
|
-
self.categories = None
|
32
|
-
self.datetime_col = PROPHET_INTERNAL_DATE_COL
|
33
|
-
self.datetime_format = config.spec.datetime_column.format
|
133
|
+
self.historical_data: HistoricalData = None
|
134
|
+
self.additional_data: AdditionalData = None
|
135
|
+
|
136
|
+
self._horizon = config.spec.horizon
|
137
|
+
self._datetime_column_name = config.spec.datetime_column.name
|
34
138
|
self._load_data(config.spec)
|
35
139
|
|
36
140
|
def _load_data(self, spec):
|
37
141
|
"""Loads forecasting input data."""
|
142
|
+
self.historical_data = HistoricalData(spec)
|
143
|
+
self.additional_data = AdditionalData(spec, self.historical_data)
|
38
144
|
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
145
|
+
if spec.generate_explanations:
|
146
|
+
if spec.additional_data is None:
|
147
|
+
logger.warn(
|
148
|
+
f"Unable to generate explanations as there is no additional data passed in. Either set generate_explanations to False, or pass in additional data."
|
149
|
+
)
|
150
|
+
spec.generate_explanations = False
|
151
|
+
|
152
|
+
def get_all_data_long(self, include_horizon=True):
|
153
|
+
how = "outer" if include_horizon else "left"
|
154
|
+
return pd.merge(
|
155
|
+
self.historical_data.data,
|
156
|
+
self.additional_data.data,
|
157
|
+
how=how,
|
158
|
+
on=[self._datetime_column_name, ForecastOutputColumns.SERIES],
|
159
|
+
).reset_index()
|
160
|
+
|
161
|
+
def get_data_multi_indexed(self):
|
162
|
+
return pd.concat(
|
163
|
+
[
|
164
|
+
self.historical_data.data,
|
165
|
+
self.additional_data.data,
|
166
|
+
],
|
167
|
+
axis=1,
|
43
168
|
)
|
44
|
-
self.original_user_data = raw_data.copy()
|
45
|
-
data_transformer = Transformations(raw_data, spec)
|
46
|
-
data = data_transformer.run()
|
47
|
-
try:
|
48
|
-
spec.freq = utils.get_frequency_of_datetime(data, spec)
|
49
|
-
except TypeError as e:
|
50
|
-
logger.warn(
|
51
|
-
f"Error determining frequency: {e.args}. Setting Frequency to None"
|
52
|
-
)
|
53
|
-
logger.debug(f"Full traceback: {e}")
|
54
|
-
spec.freq = None
|
55
169
|
|
56
|
-
|
57
|
-
|
170
|
+
def get_data_by_series(self, include_horizon=True):
|
171
|
+
total_dict = dict()
|
172
|
+
hist_data = self.historical_data.get_dict_by_series()
|
173
|
+
add_data = self.additional_data.get_dict_by_series()
|
174
|
+
how = "outer" if include_horizon else "left"
|
175
|
+
for s_id in self.list_series_ids():
|
176
|
+
# Note: ensure no duplicate column names
|
177
|
+
total_dict[s_id] = pd.merge(
|
178
|
+
hist_data[s_id],
|
179
|
+
add_data[s_id],
|
180
|
+
how=how,
|
181
|
+
on=[self._datetime_column_name],
|
182
|
+
)
|
183
|
+
return total_dict
|
58
184
|
|
185
|
+
def get_data_at_series(self, s_id, include_horizon=True):
|
186
|
+
all_data = self.get_data_by_series(include_horizon=include_horizon)
|
59
187
|
try:
|
60
|
-
|
61
|
-
data[spec.datetime_column.name], format=self.datetime_format
|
62
|
-
)
|
188
|
+
return all_data[s_id]
|
63
189
|
except:
|
64
|
-
raise
|
65
|
-
f"Unable to
|
190
|
+
raise InvalidParameterError(
|
191
|
+
f"Unable to retrieve series id: {s_id} from data. Available series ids are: {self.list_series_ids()}"
|
66
192
|
)
|
67
193
|
|
68
|
-
|
69
|
-
|
70
|
-
filename=spec.additional_data.url,
|
71
|
-
format=spec.additional_data.format,
|
72
|
-
columns=spec.additional_data.columns,
|
73
|
-
)
|
74
|
-
additional_data = data_transformer._sort_by_datetime_col(additional_data)
|
75
|
-
try:
|
76
|
-
additional_data[spec.datetime_column.name] = pd.to_datetime(
|
77
|
-
additional_data[spec.datetime_column.name],
|
78
|
-
format=self.datetime_format,
|
79
|
-
)
|
80
|
-
except:
|
81
|
-
raise ValueError(
|
82
|
-
f"Unable to determine the datetime type for column: {spec.datetime_column.name}. Please specify the format explicitly."
|
83
|
-
)
|
194
|
+
def get_horizon_at_series(self, s_id):
|
195
|
+
return self.get_data_at_series(s_id)[-self._horizon :]
|
84
196
|
|
85
|
-
|
86
|
-
|
87
|
-
else:
|
88
|
-
# Need to add the horizon to the data for compatibility
|
89
|
-
additional_data_small = data[
|
90
|
-
[spec.datetime_column.name] + spec.target_category_columns
|
91
|
-
].set_index(spec.datetime_column.name)
|
92
|
-
if is_datetime64_any_dtype(additional_data_small.index):
|
93
|
-
horizon_index = pd.date_range(
|
94
|
-
start=additional_data_small.index.values[-1],
|
95
|
-
freq=spec.freq,
|
96
|
-
periods=spec.horizon + 1,
|
97
|
-
)[1:]
|
98
|
-
elif is_numeric_dtype(additional_data_small.index):
|
99
|
-
# If datetime column is just ints
|
100
|
-
assert (
|
101
|
-
len(additional_data_small.index.values) > 1
|
102
|
-
), "Dataset is too small to infer frequency. Please pass in the horizon explicitly through the additional data."
|
103
|
-
start = additional_data_small.index.values[-1]
|
104
|
-
step = (
|
105
|
-
additional_data_small.index.values[-1]
|
106
|
-
- additional_data_small.index.values[-2]
|
107
|
-
)
|
108
|
-
horizon_index = pd.RangeIndex(
|
109
|
-
start, start + step * (spec.horizon + 1), step=step
|
110
|
-
)[1:]
|
111
|
-
else:
|
112
|
-
raise ValueError(
|
113
|
-
f"Unable to determine the datetime type for column: {spec.datetime_column.name}. Please specify the format explicitly."
|
114
|
-
)
|
197
|
+
def has_artificial_series(self):
|
198
|
+
return self.historical_data._data_transformer.has_artificial_series
|
115
199
|
|
116
|
-
|
117
|
-
|
118
|
-
for cat_col in spec.target_category_columns:
|
119
|
-
for cat in additional_data_small[cat_col].unique():
|
120
|
-
add_data_i = additional_data_small[
|
121
|
-
additional_data_small[cat_col] == cat
|
122
|
-
]
|
123
|
-
horizon_df_i = pd.DataFrame([], index=horizon_index)
|
124
|
-
horizon_df_i[cat_col] = cat
|
125
|
-
additional_data = pd.concat(
|
126
|
-
[additional_data, add_data_i, horizon_df_i]
|
127
|
-
)
|
128
|
-
additional_data = additional_data.reset_index().rename(
|
129
|
-
{"index": spec.datetime_column.name}, axis=1
|
130
|
-
)
|
200
|
+
def get_earliest_timestamp(self):
|
201
|
+
return self.historical_data.get_min_time()
|
131
202
|
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
)
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
203
|
+
def get_latest_timestamp(self):
|
204
|
+
return self.historical_data.get_max_time()
|
205
|
+
|
206
|
+
def get_additional_data_column_names(self):
|
207
|
+
return self.additional_data.additional_regressors
|
208
|
+
|
209
|
+
def get_datetime_frequency(self):
|
210
|
+
return self.historical_data.freq
|
211
|
+
|
212
|
+
def get_datetime_frequency_in_seconds(self):
|
213
|
+
return self.historical_data.freq_in_secs
|
214
|
+
|
215
|
+
def get_num_rows(self):
|
216
|
+
return self.historical_data.get_num_rows()
|
217
|
+
|
218
|
+
def list_series_ids(self, sorted=True):
|
219
|
+
series_ids = self.historical_data.list_series_ids()
|
220
|
+
if sorted:
|
221
|
+
try:
|
222
|
+
series_ids.sort()
|
223
|
+
except:
|
224
|
+
pass
|
225
|
+
return series_ids
|
152
226
|
|
153
227
|
def format_wide(self):
|
154
228
|
data_merged = pd.concat(
|
155
229
|
[
|
156
|
-
v[v[k].notna()].set_index(self.
|
157
|
-
for k, v in self.
|
230
|
+
v[v[k].notna()].set_index(self._datetime_column_name)
|
231
|
+
for k, v in self.get_data_by_series().items()
|
158
232
|
],
|
159
233
|
axis=1,
|
160
234
|
).reset_index()
|
161
235
|
return data_merged
|
162
236
|
|
163
237
|
def get_longest_datetime_column(self):
|
164
|
-
return
|
165
|
-
self.format_wide()[self.datetime_col], format=self.datetime_format
|
166
|
-
)
|
238
|
+
return self.format_wide()[self._datetime_column_name]
|
167
239
|
|
168
240
|
|
169
241
|
class ForecastOutput:
|
170
|
-
def __init__(
|
242
|
+
def __init__(
|
243
|
+
self,
|
244
|
+
confidence_interval_width: float,
|
245
|
+
horizon: int,
|
246
|
+
target_column: str,
|
247
|
+
dt_column: str,
|
248
|
+
):
|
171
249
|
"""Forecast Output contains all of the details required to generate the forecast.csv output file.
|
172
250
|
|
173
|
-
|
174
|
-
|
175
|
-
|
251
|
+
init
|
252
|
+
-------
|
253
|
+
confidence_interval_width: float value from OperatorSpec
|
254
|
+
horizon: int length of horizon
|
255
|
+
target_column: str the name of the original target column
|
256
|
+
dt_column: the name of the original datetime column
|
176
257
|
"""
|
177
|
-
self.
|
178
|
-
self.
|
179
|
-
self.
|
180
|
-
self.
|
181
|
-
self.
|
258
|
+
self.series_id_map = dict()
|
259
|
+
self._set_ci_column_names(confidence_interval_width)
|
260
|
+
self.horizon = horizon
|
261
|
+
self.target_column_name = target_column
|
262
|
+
self.dt_column_name = dt_column
|
182
263
|
|
183
|
-
def
|
264
|
+
def add_series_id(
|
184
265
|
self,
|
185
|
-
|
186
|
-
target_category_column: str,
|
266
|
+
series_id: str,
|
187
267
|
forecast: pd.DataFrame,
|
188
268
|
overwrite: bool = False,
|
189
269
|
):
|
190
|
-
if not overwrite and
|
270
|
+
if not overwrite and series_id in self.series_id_map.keys():
|
191
271
|
raise ValueError(
|
192
|
-
f"Attempting to update ForecastOutput for
|
272
|
+
f"Attempting to update ForecastOutput for series_id {series_id} when this already exists. Set overwrite to True."
|
193
273
|
)
|
194
274
|
forecast = self._check_forecast_format(forecast)
|
195
|
-
|
196
|
-
self.category_map[category] = forecast
|
197
|
-
self.category_to_target[category] = target_category_column
|
275
|
+
self.series_id_map[series_id] = forecast
|
198
276
|
|
199
|
-
def
|
200
|
-
|
277
|
+
def init_series_output(self, series_id, data_at_series):
|
278
|
+
output_i = pd.DataFrame()
|
201
279
|
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
list(self.category_to_target.values()).index(target_category_column)
|
206
|
-
]
|
207
|
-
return self.category_map[category]
|
280
|
+
output_i["Date"] = data_at_series[self.dt_column_name]
|
281
|
+
output_i["Series"] = series_id
|
282
|
+
output_i["input_value"] = data_at_series[self.target_column_name]
|
208
283
|
|
209
|
-
|
210
|
-
|
284
|
+
output_i["fitted_value"] = float("nan")
|
285
|
+
output_i["forecast_value"] = float("nan")
|
286
|
+
output_i[self.lower_bound_name] = float("nan")
|
287
|
+
output_i[self.upper_bound_name] = float("nan")
|
288
|
+
self.series_id_map[series_id] = output_i
|
211
289
|
|
212
|
-
def
|
213
|
-
|
290
|
+
def populate_series_output(
|
291
|
+
self, series_id, fit_val, forecast_val, upper_bound, lower_bound
|
292
|
+
):
|
293
|
+
"""
|
294
|
+
This method should be run after init_series_output has been run on this series_id
|
214
295
|
|
215
|
-
|
216
|
-
|
296
|
+
Parameters:
|
297
|
+
-----------
|
298
|
+
series_id: [str, int] the series being forecasted
|
299
|
+
fit_val: numpy.array of length input_value - horizon
|
300
|
+
forecast_val: numpy.array of length horizon containing the forecasted values
|
301
|
+
upper_bound: numpy.array of length horizon containing the upper_bound values
|
302
|
+
lower_bound: numpy.array of length horizon containing the lower_bound values
|
217
303
|
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
)
|
304
|
+
Returns:
|
305
|
+
--------
|
306
|
+
None
|
307
|
+
"""
|
308
|
+
try:
|
309
|
+
output_i = self.series_id_map[series_id]
|
310
|
+
except KeyError:
|
311
|
+
raise ValueError(
|
312
|
+
f"Attempting to update output for series: {series_id}, however no series output has been initialized."
|
313
|
+
)
|
229
314
|
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
315
|
+
if (output_i.shape[0] - self.horizon) == len(fit_val):
|
316
|
+
output_i["fitted_value"].iloc[
|
317
|
+
: -self.horizon
|
318
|
+
] = fit_val # Note: may need to do len(output_i) - (len(fit_val) + horizon) : -horizon
|
319
|
+
elif (output_i.shape[0] - self.horizon) > len(fit_val):
|
320
|
+
logger.debug(
|
321
|
+
f"Fitted Values were only generated on a subset ({len(fit_val)}/{(output_i.shape[0] - self.horizon)}) of the data for Series: {series_id}."
|
322
|
+
)
|
323
|
+
start_idx = output_i.shape[0] - self.horizon - len(fit_val)
|
324
|
+
output_i["fitted_value"].iloc[start_idx : -self.horizon] = fit_val
|
325
|
+
else:
|
326
|
+
output_i["fitted_value"].iloc[start_idx : -self.horizon] = fit_val[
|
327
|
+
-(output_i.shape[0] - self.horizon) :
|
328
|
+
]
|
240
329
|
|
241
|
-
|
242
|
-
|
330
|
+
if len(forecast_val) != self.horizon:
|
331
|
+
raise ValueError(
|
332
|
+
f"Attempting to set forecast along horizon ({self.horizon}) for series: {series_id}, however forecast is only length {len(forecast_val)}"
|
333
|
+
)
|
334
|
+
output_i["forecast_value"].iloc[-self.horizon :] = forecast_val
|
335
|
+
|
336
|
+
if len(upper_bound) != self.horizon:
|
337
|
+
raise ValueError(
|
338
|
+
f"Attempting to set upper_bound along horizon ({self.horizon}) for series: {series_id}, however upper_bound is only length {len(upper_bound)}"
|
339
|
+
)
|
340
|
+
output_i[self.upper_bound_name].iloc[-self.horizon :] = upper_bound
|
341
|
+
|
342
|
+
if len(lower_bound) != self.horizon:
|
343
|
+
raise ValueError(
|
344
|
+
f"Attempting to set lower_bound along horizon ({self.horizon}) for series: {series_id}, however lower_bound is only length {len(lower_bound)}"
|
345
|
+
)
|
346
|
+
output_i[self.lower_bound_name].iloc[-self.horizon :] = lower_bound
|
347
|
+
|
348
|
+
self.series_id_map[series_id] = output_i
|
349
|
+
self.verify_series_output(series_id)
|
350
|
+
|
351
|
+
def verify_series_output(self, series_id):
|
352
|
+
forecast = self.series_id_map[series_id]
|
353
|
+
self._check_forecast_format(forecast)
|
354
|
+
|
355
|
+
def get_horizon_by_series(self, series_id):
|
356
|
+
return self.series_id_map[series_id][-self.horizon :]
|
357
|
+
|
358
|
+
def get_horizon_long(self):
|
359
|
+
df = pd.DataFrame()
|
360
|
+
for s_id in self.list_series_ids():
|
361
|
+
df = pd.concat([df, self.get_horizon_by_series(s_id)])
|
362
|
+
return df.reset_index(drop=True)
|
363
|
+
|
364
|
+
def get_forecast(self, series_id):
|
365
|
+
try:
|
366
|
+
return self.series_id_map[series_id]
|
367
|
+
except KeyError as ke:
|
368
|
+
logger.debug(
|
369
|
+
f"No Forecast found for series_id: {series_id}. Returning empty DataFrame."
|
370
|
+
)
|
371
|
+
return pd.DataFrame()
|
372
|
+
|
373
|
+
def list_series_ids(self, sorted=True):
|
374
|
+
series_ids = list(self.series_id_map.keys())
|
375
|
+
if sorted:
|
376
|
+
try:
|
377
|
+
series_ids.sort()
|
378
|
+
except:
|
379
|
+
pass
|
380
|
+
return series_ids
|
381
|
+
|
382
|
+
def _set_ci_column_names(self, confidence_interval_width):
|
383
|
+
yhat_lower_percentage = (100 - confidence_interval_width * 100) // 2
|
384
|
+
self.upper_bound_name = "p" + str(int(100 - yhat_lower_percentage))
|
385
|
+
self.lower_bound_name = "p" + str(int(yhat_lower_percentage))
|
243
386
|
|
244
387
|
def _check_forecast_format(self, forecast):
|
245
388
|
assert isinstance(forecast, pd.DataFrame)
|
@@ -251,8 +394,8 @@ class ForecastOutput:
|
|
251
394
|
assert ForecastOutputColumns.INPUT_VALUE in forecast.columns
|
252
395
|
assert ForecastOutputColumns.FITTED_VALUE in forecast.columns
|
253
396
|
assert ForecastOutputColumns.FORECAST_VALUE in forecast.columns
|
254
|
-
assert
|
255
|
-
assert
|
397
|
+
assert self.upper_bound_name in forecast.columns
|
398
|
+
assert self.lower_bound_name in forecast.columns
|
256
399
|
assert not forecast.empty
|
257
400
|
# forecast.columns = pd.Index([
|
258
401
|
# ForecastOutputColumns.DATE,
|
@@ -264,3 +407,9 @@ class ForecastOutput:
|
|
264
407
|
# ForecastOutputColumns.LOWER_BOUND,
|
265
408
|
# ])
|
266
409
|
return forecast
|
410
|
+
|
411
|
+
def get_forecast_long(self):
|
412
|
+
output = pd.DataFrame()
|
413
|
+
for df in self.series_id_map.values():
|
414
|
+
output = pd.concat([output, df])
|
415
|
+
return output.reset_index(drop=True)
|