oracle-ads 2.10.0__py3-none-any.whl → 2.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +12 -0
- ads/aqua/base.py +324 -0
- ads/aqua/cli.py +19 -0
- ads/aqua/config/deployment_config_defaults.json +9 -0
- ads/aqua/config/resource_limit_names.json +7 -0
- ads/aqua/constants.py +45 -0
- ads/aqua/data.py +40 -0
- ads/aqua/decorator.py +101 -0
- ads/aqua/deployment.py +643 -0
- ads/aqua/dummy_data/icon.txt +1 -0
- ads/aqua/dummy_data/oci_model_deployments.json +56 -0
- ads/aqua/dummy_data/oci_models.json +1 -0
- ads/aqua/dummy_data/readme.md +26 -0
- ads/aqua/evaluation.py +1751 -0
- ads/aqua/exception.py +82 -0
- ads/aqua/extension/__init__.py +40 -0
- ads/aqua/extension/base_handler.py +138 -0
- ads/aqua/extension/common_handler.py +21 -0
- ads/aqua/extension/deployment_handler.py +202 -0
- ads/aqua/extension/evaluation_handler.py +135 -0
- ads/aqua/extension/finetune_handler.py +66 -0
- ads/aqua/extension/model_handler.py +59 -0
- ads/aqua/extension/ui_handler.py +201 -0
- ads/aqua/extension/utils.py +23 -0
- ads/aqua/finetune.py +579 -0
- ads/aqua/job.py +29 -0
- ads/aqua/model.py +819 -0
- ads/aqua/training/__init__.py +4 -0
- ads/aqua/training/exceptions.py +459 -0
- ads/aqua/ui.py +453 -0
- ads/aqua/utils.py +715 -0
- ads/cli.py +37 -6
- ads/common/auth.py +7 -0
- ads/common/decorator/__init__.py +7 -3
- ads/common/decorator/require_nonempty_arg.py +65 -0
- ads/common/object_storage_details.py +166 -7
- ads/common/oci_client.py +18 -1
- ads/common/oci_logging.py +2 -2
- ads/common/oci_mixin.py +4 -5
- ads/common/serializer.py +34 -5
- ads/common/utils.py +75 -10
- ads/config.py +40 -1
- ads/dataset/correlation_plot.py +10 -12
- ads/jobs/ads_job.py +43 -25
- ads/jobs/builders/infrastructure/base.py +4 -2
- ads/jobs/builders/infrastructure/dsc_job.py +49 -39
- ads/jobs/builders/runtimes/base.py +71 -1
- ads/jobs/builders/runtimes/container_runtime.py +4 -4
- ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
- ads/jobs/templates/driver_pytorch.py +27 -10
- ads/model/artifact_downloader.py +84 -14
- ads/model/artifact_uploader.py +25 -23
- ads/model/datascience_model.py +388 -38
- ads/model/deployment/model_deployment.py +10 -2
- ads/model/generic_model.py +8 -0
- ads/model/model_file_description_schema.json +68 -0
- ads/model/model_metadata.py +1 -1
- ads/model/service/oci_datascience_model.py +34 -5
- ads/opctl/config/merger.py +2 -2
- ads/opctl/operator/__init__.py +3 -1
- ads/opctl/operator/cli.py +7 -1
- ads/opctl/operator/cmd.py +3 -3
- ads/opctl/operator/common/errors.py +2 -1
- ads/opctl/operator/common/operator_config.py +22 -3
- ads/opctl/operator/common/utils.py +16 -0
- ads/opctl/operator/lowcode/anomaly/MLoperator +15 -0
- ads/opctl/operator/lowcode/anomaly/README.md +209 -0
- ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/__main__.py +104 -0
- ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
- ads/opctl/operator/lowcode/anomaly/const.py +88 -0
- ads/opctl/operator/lowcode/anomaly/environment.yaml +12 -0
- ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +147 -0
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +89 -0
- ads/opctl/operator/lowcode/anomaly/model/autots.py +103 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +354 -0
- ads/opctl/operator/lowcode/anomaly/model/factory.py +67 -0
- ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
- ads/opctl/operator/lowcode/anomaly/operator_config.py +105 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +359 -0
- ads/opctl/operator/lowcode/anomaly/utils.py +81 -0
- ads/opctl/operator/lowcode/common/__init__.py +5 -0
- ads/opctl/operator/lowcode/common/const.py +10 -0
- ads/opctl/operator/lowcode/common/data.py +96 -0
- ads/opctl/operator/lowcode/common/errors.py +41 -0
- ads/opctl/operator/lowcode/common/transformations.py +191 -0
- ads/opctl/operator/lowcode/common/utils.py +250 -0
- ads/opctl/operator/lowcode/forecast/README.md +3 -2
- ads/opctl/operator/lowcode/forecast/__main__.py +18 -2
- ads/opctl/operator/lowcode/forecast/cmd.py +8 -7
- ads/opctl/operator/lowcode/forecast/const.py +17 -1
- ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
- ads/opctl/operator/lowcode/forecast/model/arima.py +106 -117
- ads/opctl/operator/lowcode/forecast/model/automlx.py +204 -180
- ads/opctl/operator/lowcode/forecast/model/autots.py +144 -253
- ads/opctl/operator/lowcode/forecast/model/base_model.py +326 -259
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +325 -176
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +293 -237
- ads/opctl/operator/lowcode/forecast/model/prophet.py +191 -208
- ads/opctl/operator/lowcode/forecast/operator_config.py +24 -33
- ads/opctl/operator/lowcode/forecast/schema.yaml +116 -29
- ads/opctl/operator/lowcode/forecast/utils.py +186 -356
- ads/opctl/operator/lowcode/pii/model/guardrails.py +18 -15
- ads/opctl/operator/lowcode/pii/model/report.py +7 -7
- ads/opctl/operator/lowcode/pii/operator_config.py +1 -8
- ads/opctl/operator/lowcode/pii/utils.py +0 -82
- ads/opctl/operator/runtime/runtime.py +3 -2
- ads/telemetry/base.py +62 -0
- ads/telemetry/client.py +105 -0
- ads/telemetry/telemetry.py +6 -3
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/METADATA +44 -7
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/RECORD +116 -59
- ads/opctl/operator/lowcode/forecast/model/transformations.py +0 -125
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/WHEEL +0 -0
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/entry_points.txt +0 -0
ads/opctl/operator/lowcode/forecast/model/prophet.py

@@ -1,25 +1,42 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*--

-# Copyright (c)
+# Copyright (c) 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

 import numpy as np
 import optuna
 import pandas as pd
+import logging
+from joblib import Parallel, delayed
 from ads.common.decorator.runtime_dependency import runtime_dependency
 from ads.opctl import logger
 from ads.opctl.operator.lowcode.forecast.operator_config import ForecastOperatorConfig

-from ..const import
-
+from ..const import (
+    DEFAULT_TRIALS,
+    PROPHET_INTERNAL_DATE_COL,
+    ForecastOutputColumns,
+    SupportedModels,
+)
+from ads.opctl.operator.lowcode.forecast.utils import (
+    _select_plot_list,
+    _label_encode_dataframe,
+)
+from ads.opctl.operator.lowcode.common.utils import set_log_level
 from .base_model import ForecastOperatorBaseModel
 from ..operator_config import ForecastOperatorConfig
 from .forecast_datasets import ForecastDatasets, ForecastOutput
 import traceback
 import matplotlib as mpl

-
+
+try:
+    set_log_level("prophet", logger.level)
+    set_log_level("cmdstanpy", logger.level)
+    mpl.rcParams["figure.max_open_warning"] = 100
+except:
+    pass


 def _add_unit(num, unit):
@@ -41,150 +58,45 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):

     def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
         super().__init__(config=config, datasets=datasets)
-        self.train_metrics = True
         self.global_explanation = {}
         self.local_explanation = {}

-    def
-        from prophet import Prophet
-        from prophet.diagnostics import cross_validation, performance_metrics
-
-        full_data_dict = self.datasets.full_data_dict
-        models = []
-        outputs = dict()
-        outputs_legacy = []
-
+    def set_kwargs(self):
         # Extract the Confidence Interval Width and convert to prophet's equivalent - interval_width
         if self.spec.confidence_interval_width is None:
             self.spec.confidence_interval_width = 1 - self.spec.model_kwargs.get(
                 "alpha", 0.90
             )
-
         model_kwargs = self.spec.model_kwargs
         model_kwargs["interval_width"] = self.spec.confidence_interval_width
+        return model_kwargs

-
-
-
-
-        for i, (target, df) in enumerate(full_data_dict.items()):
-            le, df_encoded = utils._label_encode_dataframe(
-                df, no_encode={self.spec.datetime_column.name, target}
-            )
+    def _train_model(self, i, series_id, df, model_kwargs):
+        try:
+            from prophet import Prophet
+            from prophet.diagnostics import cross_validation, performance_metrics

-
-
-            df_clean = self._preprocess(
-                df_encoded,
-                self.spec.datetime_column.name,
-                self.spec.datetime_column.format,
+            self.forecast_output.init_series_output(
+                series_id=series_id, data_at_series=df
             )
-            data_i = df_clean[df_clean[target].notna()]
-            data_i.rename({target: "y"}, axis=1, inplace=True)
-
-            # Assume that all columns passed in should be used as additional data
-            additional_regressors = set(data_i.columns) - {
-                "y",
-                PROPHET_INTERNAL_DATE_COL,
-            }
-
-            if self.perform_tuning:
-
-                def objective(trial):
-                    params = {
-                        "seasonality_mode": trial.suggest_categorical(
-                            "seasonality_mode", ["additive", "multiplicative"]
-                        ),
-                        "changepoint_prior_scale": trial.suggest_float(
-                            "changepoint_prior_scale", 0.001, 0.5, log=True
-                        ),
-                        "seasonality_prior_scale": trial.suggest_float(
-                            "seasonality_prior_scale", 0.01, 10, log=True
-                        ),
-                        "holidays_prior_scale": trial.suggest_float(
-                            "holidays_prior_scale", 0.01, 10, log=True
-                        ),
-                        "changepoint_range": trial.suggest_float(
-                            "changepoint_range", 0.8, 0.95
-                        ),
-                    }
-                    params.update(model_kwargs_i)
-
-                    model = _fit_model(
-                        data=data_i,
-                        params=params,
-                        additional_regressors=additional_regressors,
-                    )

-
-
-
-
-
-
-
-                    unit = "D"
-                    interval = interval * 365.25
-                    horizon = _add_unit(int(self.spec.horizon * interval), unit=unit)
-                    initial = _add_unit((data_i.shape[0] * interval) // 2, unit=unit)
-                    period = _add_unit((data_i.shape[0] * interval) // 4, unit=unit)
-
-                    logger.debug(
-                        f"using: horizon: {horizon}. initial:{initial}, period: {period}"
-                    )
+            data = self.preprocess(df, series_id)
+            data_i = self.drop_horizon(data)
+            if self.loaded_models is not None:
+                model = self.loaded_models[series_id]
+            else:
+                if self.perform_tuning:
+                    model_kwargs = self.run_tuning(data_i, model_kwargs)

-
-
-
-                        period=period,
-                        parallel="threads",
-                    )
-                    df_p = performance_metrics(df_cv)
-                    try:
-                        return np.mean(df_p[self.spec.metric])
-                    except KeyError:
-                        logger.warn(
-                            f"Could not find the metric {self.spec.metric} within "
-                            f"the performance metrics: {df_p.columns}. Defaulting to `rmse`"
-                        )
-                        return np.mean(df_p["rmse"])
-
-                study = optuna.create_study(direction="minimize")
-                m_temp = Prophet()
-                study.enqueue_trial(
-                    {
-                        "seasonality_mode": m_temp.seasonality_mode,
-                        "changepoint_prior_scale": m_temp.changepoint_prior_scale,
-                        "seasonality_prior_scale": m_temp.seasonality_prior_scale,
-                        "holidays_prior_scale": m_temp.holidays_prior_scale,
-                        "changepoint_range": m_temp.changepoint_range,
-                    }
-                )
-                study.optimize(
-                    objective,
-                    n_trials=self.spec.tuning.n_trials
-                    if self.spec.tuning
-                    else DEFAULT_TRIALS,
-                    n_jobs=-1,
+                model = _fit_model(
+                    data=data,
+                    params=model_kwargs,
+                    additional_regressors=self.additional_regressors,
                 )

-
-
-            model = _fit_model(
-                data=data_i,
-                params=model_kwargs_i,
-                additional_regressors=additional_regressors,
-            )
+            # Get future df for prediction
+            future = data.drop("y", axis=1)

-            # Make future df for prediction
-            if len(additional_regressors):
-                future = df_clean.drop(target, axis=1)
-            else:
-                future = model.make_future_dataframe(
-                    periods=self.spec.horizon,
-                    freq=self.spec.freq,
-                )
             # Make Prediction
             forecast = model.predict(future)
             logger.debug(f"-----------------Model {i}----------------------")
@@ -194,107 +106,186 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
                 ].tail()
             )

-
-
-
-
-
-
-
-
-            logger.debug("===========Done===========")
-
-        # Merge the outputs from each model into 1 df with all outputs by target and category
-        col = self.original_target_column
-        output_col = pd.DataFrame()
-        yhat_upper_name = ForecastOutputColumns.UPPER_BOUND
-        yhat_lower_name = ForecastOutputColumns.LOWER_BOUND
-        for cat in self.categories:
-            output_i = pd.DataFrame()
-
-            output_i["Date"] = outputs[f"{col}_{cat}"][PROPHET_INTERNAL_DATE_COL]
-            output_i["Series"] = cat
-            output_i["input_value"] = full_data_dict[f"{col}_{cat}"][f"{col}_{cat}"]
-
-            output_i[f"fitted_value"] = float("nan")
-            output_i[f"forecast_value"] = float("nan")
-            output_i[yhat_upper_name] = float("nan")
-            output_i[yhat_lower_name] = float("nan")
-
-            output_i.iloc[
-                : -self.spec.horizon, output_i.columns.get_loc(f"fitted_value")
-            ] = (outputs[f"{col}_{cat}"]["yhat"].iloc[: -self.spec.horizon].values)
-            output_i.iloc[
-                -self.spec.horizon :,
-                output_i.columns.get_loc(f"forecast_value"),
-            ] = (
-                outputs[f"{col}_{cat}"]["yhat"].iloc[-self.spec.horizon :].values
+            self.outputs[series_id] = forecast
+            self.forecast_output.populate_series_output(
+                series_id=series_id,
+                fit_val=self.drop_horizon(forecast["yhat"]).values,
+                forecast_val=self.get_horizon(forecast["yhat"]).values,
+                upper_bound=self.get_horizon(forecast["yhat_upper"]).values,
+                lower_bound=self.get_horizon(forecast["yhat_lower"]).values,
             )
-
-
-
-
+            self.models[series_id] = model
+
+            params = vars(model).copy()
+            for param in ["history", "history_dates", "stan_fit"]:
+                if param in params:
+                    params.pop(param)
+            self.model_parameters[series_id] = {
+                "framework": SupportedModels.Prophet,
+                **params,
+            }
+
+            logger.debug("===========Done===========")
+        except Exception as e:
+            self.errors_dict[series_id] = {
+                "model_name": self.spec.model,
+                "error": str(e),
+            }
+
+    def _build_model(self) -> pd.DataFrame:
+        from prophet import Prophet
+        from prophet.diagnostics import cross_validation, performance_metrics
+
+        full_data_dict = self.datasets.get_data_by_series()
+        self.models = dict()
+        self.outputs = dict()
+        self.additional_regressors = self.datasets.get_additional_data_column_names()
+        model_kwargs = self.set_kwargs()
+        self.forecast_output = ForecastOutput(
+            confidence_interval_width=self.spec.confidence_interval_width,
+            horizon=self.spec.horizon,
+            target_column=self.original_target_column,
+            dt_column=self.spec.datetime_column.name,
+        )
+
+        Parallel(n_jobs=-1, require="sharedmem")(
+            delayed(ProphetOperatorModel._train_model)(
+                self, i, series_id, df, model_kwargs.copy()
             )
-
-
-            ] = (
-                outputs[f"{col}_{cat}"]["yhat_lower"].iloc[-self.spec.horizon :].values
+            for self, (i, (series_id, df)) in zip(
+                [self] * len(full_data_dict), enumerate(full_data_dict.items())
             )
-
-
-
+        )
+
+        return self.forecast_output.get_forecast_long()
+
+    def run_tuning(self, data_i, model_kwargs_i):
+        def objective(trial):
+            params = {
+                "seasonality_mode": trial.suggest_categorical(
+                    "seasonality_mode", ["additive", "multiplicative"]
+                ),
+                "changepoint_prior_scale": trial.suggest_float(
+                    "changepoint_prior_scale", 0.001, 0.5, log=True
+                ),
+                "seasonality_prior_scale": trial.suggest_float(
+                    "seasonality_prior_scale", 0.01, 10, log=True
+                ),
+                "holidays_prior_scale": trial.suggest_float(
+                    "holidays_prior_scale", 0.01, 10, log=True
+                ),
+                "changepoint_range": trial.suggest_float(
+                    "changepoint_range", 0.8, 0.95
+                ),
+            }
+            params.update(model_kwargs_i)
+
+            model = _fit_model(
+                data=data_i,
+                params=params,
+                additional_regressors=self.additional_regressors,
             )

-
+            # Manual workaround because pandas 1.x dropped support for M and Y
+            interval = self.spec.horizon
+            freq = self.datasets.get_datetime_frequency()
+            unit = freq.split("-")[0] if freq else None
+            if unit == "M":
+                unit = "D"
+                interval = interval * 30.5
+            elif unit == "Y":
+                unit = "D"
+                interval = interval * 365.25
+            horizon = _add_unit(int(self.spec.horizon * interval), unit=unit)
+            initial = _add_unit((data_i.shape[0] * interval) // 2, unit=unit)
+            period = _add_unit((data_i.shape[0] * interval) // 4, unit=unit)

-
+            logger.debug(
+                f"using: horizon: {horizon}. initial:{initial}, period: {period}"
+            )
+
+            df_cv = cross_validation(
+                model,
+                horizon=horizon,
+                initial=initial,
+                period=period,
+                parallel="threads",
+            )
+            df_p = performance_metrics(df_cv)
+            try:
+                return np.mean(df_p[self.spec.metric])
+            except KeyError:
+                logger.warn(
+                    f"Could not find the metric {self.spec.metric} within "
+                    f"the performance metrics: {df_p.columns}. Defaulting to `rmse`"
+                )
+                return np.mean(df_p["rmse"])
+
+        study = optuna.create_study(direction="minimize")
+        m_temp = Prophet()
+        study.enqueue_trial(
+            {
+                "seasonality_mode": m_temp.seasonality_mode,
+                "changepoint_prior_scale": m_temp.changepoint_prior_scale,
+                "seasonality_prior_scale": m_temp.seasonality_prior_scale,
+                "holidays_prior_scale": m_temp.holidays_prior_scale,
+                "changepoint_range": m_temp.changepoint_range,
+            }
+        )
+        study.optimize(
+            objective,
+            n_trials=self.spec.tuning.n_trials if self.spec.tuning else DEFAULT_TRIALS,
+            n_jobs=-1,
+        )
+
+        study.best_params.update(model_kwargs_i)
+        model_kwargs_i = study.best_params
+        return model_kwargs_i

     def _generate_report(self):
         import datapane as dp
         from prophet.plot import add_changepoints_to_plot

+        series_ids = self.datasets.list_series_ids()
+
         sec1_text = dp.Text(
             "## Forecast Overview \n"
             "These plots show your forecast in the context of historical data."
         )
-        sec1 =
-            lambda
-                self.outputs[
+        sec1 = _select_plot_list(
+            lambda s_id: self.models[s_id].plot(
+                self.outputs[s_id], include_legend=True
             ),
-
+            series_ids=series_ids,
         )

         sec2_text = dp.Text(f"## Forecast Broken Down by Trend Component")
-        sec2 =
-            lambda
-
+        sec2 = _select_plot_list(
+            lambda s_id: self.models[s_id].plot_components(self.outputs[s_id]),
+            series_ids=series_ids,
         )

         sec3_text = dp.Text(f"## Forecast Changepoints")
-        sec3_figs =
-            self.models[
-
-
-        [
+        sec3_figs = {
+            s_id: self.models[s_id].plot(self.outputs[s_id]) for s_id in series_ids
+        }
+        for s_id in series_ids:
             add_changepoints_to_plot(
-                sec3_figs[
+                sec3_figs[s_id].gca(), self.models[s_id], self.outputs[s_id]
             )
-
-        ]
-        sec3 = utils._select_plot_list(
-            lambda idx, *args: sec3_figs[idx], target_columns=self.target_columns
-        )
+        sec3 = _select_plot_list(lambda s_id: sec3_figs[s_id], series_ids=series_ids)

         all_sections = [sec1_text, sec1, sec2_text, sec2, sec3_text, sec3]

         sec5_text = dp.Text(f"## Prophet Model Seasonality Components")
         model_states = []
-        for
+        for s_id in series_ids:
+            m = self.models[s_id]
             model_states.append(
                 pd.Series(
                     m.seasonalities,
                     index=pd.Index(m.seasonalities.keys(), dtype="object"),
-                    name=
+                    name=s_id,
                     dtype="object",
                 )
             )
@@ -306,10 +297,7 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
         if self.spec.generate_explanations:
             try:
                 # If the key is present, call the "explain_model" method
-                self.explain_model(
-                    datetime_col_name=PROPHET_INTERNAL_DATE_COL,
-                    explain_predict_fn=self._custom_predict_prophet,
-                )
+                self.explain_model()

                 # Create a markdown text block for the global explanation section
                 global_explanation_text = dp.Text(
@@ -333,7 +321,7 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
                 aggregate_local_explanations = pd.DataFrame()
                 for s_id, local_ex_df in self.local_explanation.items():
                     local_ex_df_copy = local_ex_df.copy()
-                    local_ex_df_copy[
+                    local_ex_df_copy[ForecastOutputColumns.SERIES] = s_id
                     aggregate_local_explanations = pd.concat(
                         [aggregate_local_explanations, local_ex_df_copy], axis=0
                     )
@@ -376,8 +364,3 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
             model_description,
             other_sections,
         )
-
-    def _custom_predict_prophet(self, data):
-        return self.models[self.target_columns.index(self.series_id)].predict(
-            data.reset_index()
-        )["yhat"]
ads/opctl/operator/lowcode/forecast/operator_config.py

@@ -10,44 +10,16 @@ from typing import Dict, List

 from ads.common.serializer import DataClassSerializable
 from ads.opctl.operator.common.utils import _load_yaml_from_uri
-from ads.opctl.operator.common.operator_config import OperatorConfig
+from ads.opctl.operator.common.operator_config import OperatorConfig, OutputDirectory, InputData

-from .const import SupportedMetrics
+from .const import SupportedMetrics, SpeedAccuracyMode
 from .const import SupportedModels

-@dataclass(repr=True)
-class InputData(DataClassSerializable):
-    """Class representing operator specification input data details."""
-
-    format: str = None
-    columns: List[str] = None
-    url: str = None
-    options: Dict = None
-    limit: int = None
-

 @dataclass(repr=True)
-class TestData(
+class TestData(InputData):
     """Class representing operator specification test data details."""

-    connect_args: Dict = None
-    format: str = None
-    columns: List[str] = None
-    url: str = None
-    name: str = None
-    options: Dict = None
-
-
-@dataclass(repr=True)
-class OutputDirectory(DataClassSerializable):
-    """Class representing operator specification output directory details."""
-
-    connect_args: Dict = None
-    format: str = None
-    url: str = None
-    name: str = None
-    options: Dict = None
-

 @dataclass(repr=True)
 class DateTimeColumn(DataClassSerializable):

@@ -88,10 +60,14 @@ class ForecastOperatorSpec(DataClassSerializable):
     generate_report: bool = None
     generate_metrics: bool = None
     generate_explanations: bool = None
+    explanations_accuracy_mode: str = None
     horizon: int = None
-    freq: str = None
     model: str = None
     model_kwargs: Dict = field(default_factory=dict)
+    model_parameters: str = None
+    previous_output_dir: str = None
+    generate_model_parameters: bool = None
+    generate_model_pickle: bool = None
     confidence_interval_width: float = None
     metric: str = None
     tuning: Tuning = field(default_factory=Tuning)

@@ -99,7 +75,7 @@ class ForecastOperatorSpec(DataClassSerializable):
     def __post_init__(self):
         """Adjusts the specification details."""
         self.metric = (self.metric or "").lower() or SupportedMetrics.SMAPE.lower()
-        self.model =
+        self.model = self.model or SupportedModels.Auto
         self.confidence_interval_width = self.confidence_interval_width or 0.80
         self.report_filename = self.report_filename or "report.html"
         self.preprocessing = (

@@ -119,6 +95,20 @@ class ForecastOperatorSpec(DataClassSerializable):
             if self.generate_explanations is not None
             else False
         )
+        self.explanations_accuracy_mode = (
+            self.explanations_accuracy_mode or SpeedAccuracyMode.FAST_APPROXIMATE
+        )
+
+        self.generate_model_parameters = (
+            self.generate_model_parameters
+            if self.generate_model_parameters is not None
+            else False
+        )
+        self.generate_model_pickle = (
+            self.generate_model_pickle
+            if self.generate_model_pickle is not None
+            else False
+        )
         self.report_theme = self.report_theme or "light"
         self.metrics_filename = self.metrics_filename or "metrics.csv"
         self.test_metrics_filename = self.test_metrics_filename or "test_metrics.csv"

@@ -130,6 +120,7 @@ class ForecastOperatorSpec(DataClassSerializable):
             self.local_explanation_filename or "local_explanation.csv"
         )
         self.target_column = self.target_column or "Sales"
+        self.errors_dict_filename = "errors.json"
         self.model_kwargs = self.model_kwargs or dict()

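
The operator_config.py change reuses `InputData` and `OutputDirectory` from the common operator config module and adds several spec fields whose defaults are resolved in `__post_init__`. The sketch below is a simplified, hypothetical dataclass (not the ADS `ForecastOperatorSpec`) showing the two default idioms used there: an `or` fallback for falsy values, and an explicit `is not None` check so that a user-supplied `False` is preserved:

```python
from dataclasses import dataclass

# Hypothetical stand-in for SpeedAccuracyMode.FAST_APPROXIMATE
FAST_APPROXIMATE = "FAST_APPROXIMATE"


@dataclass(repr=True)
class SpecSketch:
    explanations_accuracy_mode: str = None
    generate_model_parameters: bool = None
    generate_model_pickle: bool = None

    def __post_init__(self):
        # Falsy fallback: None or "" becomes the default constant.
        self.explanations_accuracy_mode = (
            self.explanations_accuracy_mode or FAST_APPROXIMATE
        )
        # Explicit None check: only "not set" becomes False; an explicit
        # False (or True) passed by the user is kept as-is.
        self.generate_model_parameters = (
            self.generate_model_parameters
            if self.generate_model_parameters is not None
            else False
        )
        self.generate_model_pickle = (
            self.generate_model_pickle
            if self.generate_model_pickle is not None
            else False
        )


print(SpecSketch())                            # all defaults resolved
print(SpecSketch(generate_model_pickle=True))  # explicit value kept
```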