oracle-ads 2.10.0__py3-none-any.whl → 2.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117) hide show
  1. ads/aqua/__init__.py +12 -0
  2. ads/aqua/base.py +324 -0
  3. ads/aqua/cli.py +19 -0
  4. ads/aqua/config/deployment_config_defaults.json +9 -0
  5. ads/aqua/config/resource_limit_names.json +7 -0
  6. ads/aqua/constants.py +45 -0
  7. ads/aqua/data.py +40 -0
  8. ads/aqua/decorator.py +101 -0
  9. ads/aqua/deployment.py +643 -0
  10. ads/aqua/dummy_data/icon.txt +1 -0
  11. ads/aqua/dummy_data/oci_model_deployments.json +56 -0
  12. ads/aqua/dummy_data/oci_models.json +1 -0
  13. ads/aqua/dummy_data/readme.md +26 -0
  14. ads/aqua/evaluation.py +1751 -0
  15. ads/aqua/exception.py +82 -0
  16. ads/aqua/extension/__init__.py +40 -0
  17. ads/aqua/extension/base_handler.py +138 -0
  18. ads/aqua/extension/common_handler.py +21 -0
  19. ads/aqua/extension/deployment_handler.py +202 -0
  20. ads/aqua/extension/evaluation_handler.py +135 -0
  21. ads/aqua/extension/finetune_handler.py +66 -0
  22. ads/aqua/extension/model_handler.py +59 -0
  23. ads/aqua/extension/ui_handler.py +201 -0
  24. ads/aqua/extension/utils.py +23 -0
  25. ads/aqua/finetune.py +579 -0
  26. ads/aqua/job.py +29 -0
  27. ads/aqua/model.py +819 -0
  28. ads/aqua/training/__init__.py +4 -0
  29. ads/aqua/training/exceptions.py +459 -0
  30. ads/aqua/ui.py +453 -0
  31. ads/aqua/utils.py +715 -0
  32. ads/cli.py +37 -6
  33. ads/common/auth.py +7 -0
  34. ads/common/decorator/__init__.py +7 -3
  35. ads/common/decorator/require_nonempty_arg.py +65 -0
  36. ads/common/object_storage_details.py +166 -7
  37. ads/common/oci_client.py +18 -1
  38. ads/common/oci_logging.py +2 -2
  39. ads/common/oci_mixin.py +4 -5
  40. ads/common/serializer.py +34 -5
  41. ads/common/utils.py +75 -10
  42. ads/config.py +40 -1
  43. ads/dataset/correlation_plot.py +10 -12
  44. ads/jobs/ads_job.py +43 -25
  45. ads/jobs/builders/infrastructure/base.py +4 -2
  46. ads/jobs/builders/infrastructure/dsc_job.py +49 -39
  47. ads/jobs/builders/runtimes/base.py +71 -1
  48. ads/jobs/builders/runtimes/container_runtime.py +4 -4
  49. ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
  50. ads/jobs/templates/driver_pytorch.py +27 -10
  51. ads/model/artifact_downloader.py +84 -14
  52. ads/model/artifact_uploader.py +25 -23
  53. ads/model/datascience_model.py +388 -38
  54. ads/model/deployment/model_deployment.py +10 -2
  55. ads/model/generic_model.py +8 -0
  56. ads/model/model_file_description_schema.json +68 -0
  57. ads/model/model_metadata.py +1 -1
  58. ads/model/service/oci_datascience_model.py +34 -5
  59. ads/opctl/config/merger.py +2 -2
  60. ads/opctl/operator/__init__.py +3 -1
  61. ads/opctl/operator/cli.py +7 -1
  62. ads/opctl/operator/cmd.py +3 -3
  63. ads/opctl/operator/common/errors.py +2 -1
  64. ads/opctl/operator/common/operator_config.py +22 -3
  65. ads/opctl/operator/common/utils.py +16 -0
  66. ads/opctl/operator/lowcode/anomaly/MLoperator +15 -0
  67. ads/opctl/operator/lowcode/anomaly/README.md +209 -0
  68. ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
  69. ads/opctl/operator/lowcode/anomaly/__main__.py +104 -0
  70. ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
  71. ads/opctl/operator/lowcode/anomaly/const.py +88 -0
  72. ads/opctl/operator/lowcode/anomaly/environment.yaml +12 -0
  73. ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
  74. ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +147 -0
  75. ads/opctl/operator/lowcode/anomaly/model/automlx.py +89 -0
  76. ads/opctl/operator/lowcode/anomaly/model/autots.py +103 -0
  77. ads/opctl/operator/lowcode/anomaly/model/base_model.py +354 -0
  78. ads/opctl/operator/lowcode/anomaly/model/factory.py +67 -0
  79. ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
  80. ads/opctl/operator/lowcode/anomaly/operator_config.py +105 -0
  81. ads/opctl/operator/lowcode/anomaly/schema.yaml +359 -0
  82. ads/opctl/operator/lowcode/anomaly/utils.py +81 -0
  83. ads/opctl/operator/lowcode/common/__init__.py +5 -0
  84. ads/opctl/operator/lowcode/common/const.py +10 -0
  85. ads/opctl/operator/lowcode/common/data.py +96 -0
  86. ads/opctl/operator/lowcode/common/errors.py +41 -0
  87. ads/opctl/operator/lowcode/common/transformations.py +191 -0
  88. ads/opctl/operator/lowcode/common/utils.py +250 -0
  89. ads/opctl/operator/lowcode/forecast/README.md +3 -2
  90. ads/opctl/operator/lowcode/forecast/__main__.py +18 -2
  91. ads/opctl/operator/lowcode/forecast/cmd.py +8 -7
  92. ads/opctl/operator/lowcode/forecast/const.py +17 -1
  93. ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
  94. ads/opctl/operator/lowcode/forecast/model/arima.py +106 -117
  95. ads/opctl/operator/lowcode/forecast/model/automlx.py +204 -180
  96. ads/opctl/operator/lowcode/forecast/model/autots.py +144 -253
  97. ads/opctl/operator/lowcode/forecast/model/base_model.py +326 -259
  98. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +325 -176
  99. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +293 -237
  100. ads/opctl/operator/lowcode/forecast/model/prophet.py +191 -208
  101. ads/opctl/operator/lowcode/forecast/operator_config.py +24 -33
  102. ads/opctl/operator/lowcode/forecast/schema.yaml +116 -29
  103. ads/opctl/operator/lowcode/forecast/utils.py +186 -356
  104. ads/opctl/operator/lowcode/pii/model/guardrails.py +18 -15
  105. ads/opctl/operator/lowcode/pii/model/report.py +7 -7
  106. ads/opctl/operator/lowcode/pii/operator_config.py +1 -8
  107. ads/opctl/operator/lowcode/pii/utils.py +0 -82
  108. ads/opctl/operator/runtime/runtime.py +3 -2
  109. ads/telemetry/base.py +62 -0
  110. ads/telemetry/client.py +105 -0
  111. ads/telemetry/telemetry.py +6 -3
  112. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/METADATA +44 -7
  113. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/RECORD +116 -59
  114. ads/opctl/operator/lowcode/forecast/model/transformations.py +0 -125
  115. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/LICENSE.txt +0 -0
  116. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/WHEEL +0 -0
  117. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/entry_points.txt +0 -0
@@ -1,25 +1,42 @@
1
1
  #!/usr/bin/env python
2
2
  # -*- coding: utf-8 -*--
3
3
 
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
4
+ # Copyright (c) 2024 Oracle and/or its affiliates.
5
5
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
6
 
7
7
  import numpy as np
8
8
  import optuna
9
9
  import pandas as pd
10
+ import logging
11
+ from joblib import Parallel, delayed
10
12
  from ads.common.decorator.runtime_dependency import runtime_dependency
11
13
  from ads.opctl import logger
12
14
  from ads.opctl.operator.lowcode.forecast.operator_config import ForecastOperatorConfig
13
15
 
14
- from ..const import DEFAULT_TRIALS, PROPHET_INTERNAL_DATE_COL, ForecastOutputColumns
15
- from .. import utils
16
+ from ..const import (
17
+ DEFAULT_TRIALS,
18
+ PROPHET_INTERNAL_DATE_COL,
19
+ ForecastOutputColumns,
20
+ SupportedModels,
21
+ )
22
+ from ads.opctl.operator.lowcode.forecast.utils import (
23
+ _select_plot_list,
24
+ _label_encode_dataframe,
25
+ )
26
+ from ads.opctl.operator.lowcode.common.utils import set_log_level
16
27
  from .base_model import ForecastOperatorBaseModel
17
28
  from ..operator_config import ForecastOperatorConfig
18
29
  from .forecast_datasets import ForecastDatasets, ForecastOutput
19
30
  import traceback
20
31
  import matplotlib as mpl
21
32
 
22
- mpl.rcParams["figure.max_open_warning"] = 100
33
+
34
+ try:
35
+ set_log_level("prophet", logger.level)
36
+ set_log_level("cmdstanpy", logger.level)
37
+ mpl.rcParams["figure.max_open_warning"] = 100
38
+ except:
39
+ pass
23
40
 
24
41
 
25
42
  def _add_unit(num, unit):
@@ -41,150 +58,45 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
41
58
 
42
59
  def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
43
60
  super().__init__(config=config, datasets=datasets)
44
- self.train_metrics = True
45
61
  self.global_explanation = {}
46
62
  self.local_explanation = {}
47
63
 
48
- def _build_model(self) -> pd.DataFrame:
49
- from prophet import Prophet
50
- from prophet.diagnostics import cross_validation, performance_metrics
51
-
52
- full_data_dict = self.datasets.full_data_dict
53
- models = []
54
- outputs = dict()
55
- outputs_legacy = []
56
-
64
+ def set_kwargs(self):
57
65
  # Extract the Confidence Interval Width and convert to prophet's equivalent - interval_width
58
66
  if self.spec.confidence_interval_width is None:
59
67
  self.spec.confidence_interval_width = 1 - self.spec.model_kwargs.get(
60
68
  "alpha", 0.90
61
69
  )
62
-
63
70
  model_kwargs = self.spec.model_kwargs
64
71
  model_kwargs["interval_width"] = self.spec.confidence_interval_width
72
+ return model_kwargs
65
73
 
66
- self.forecast_output = ForecastOutput(
67
- confidence_interval_width=self.spec.confidence_interval_width
68
- )
69
-
70
- for i, (target, df) in enumerate(full_data_dict.items()):
71
- le, df_encoded = utils._label_encode_dataframe(
72
- df, no_encode={self.spec.datetime_column.name, target}
73
- )
74
+ def _train_model(self, i, series_id, df, model_kwargs):
75
+ try:
76
+ from prophet import Prophet
77
+ from prophet.diagnostics import cross_validation, performance_metrics
74
78
 
75
- model_kwargs_i = model_kwargs.copy()
76
- # format the dataframe for this target. Dropping NA on target[df] will remove all future data
77
- df_clean = self._preprocess(
78
- df_encoded,
79
- self.spec.datetime_column.name,
80
- self.spec.datetime_column.format,
79
+ self.forecast_output.init_series_output(
80
+ series_id=series_id, data_at_series=df
81
81
  )
82
- data_i = df_clean[df_clean[target].notna()]
83
- data_i.rename({target: "y"}, axis=1, inplace=True)
84
-
85
- # Assume that all columns passed in should be used as additional data
86
- additional_regressors = set(data_i.columns) - {
87
- "y",
88
- PROPHET_INTERNAL_DATE_COL,
89
- }
90
-
91
- if self.perform_tuning:
92
-
93
- def objective(trial):
94
- params = {
95
- "seasonality_mode": trial.suggest_categorical(
96
- "seasonality_mode", ["additive", "multiplicative"]
97
- ),
98
- "changepoint_prior_scale": trial.suggest_float(
99
- "changepoint_prior_scale", 0.001, 0.5, log=True
100
- ),
101
- "seasonality_prior_scale": trial.suggest_float(
102
- "seasonality_prior_scale", 0.01, 10, log=True
103
- ),
104
- "holidays_prior_scale": trial.suggest_float(
105
- "holidays_prior_scale", 0.01, 10, log=True
106
- ),
107
- "changepoint_range": trial.suggest_float(
108
- "changepoint_range", 0.8, 0.95
109
- ),
110
- }
111
- params.update(model_kwargs_i)
112
-
113
- model = _fit_model(
114
- data=data_i,
115
- params=params,
116
- additional_regressors=additional_regressors,
117
- )
118
82
 
119
- # Manual workaround because pandas 1.x dropped support for M and Y
120
- interval = self.spec.horizon.interval
121
- unit = self.spec.horizon.interval_unit
122
- if unit == "M":
123
- unit = "D"
124
- interval = interval * 30.5
125
- elif unit == "Y":
126
- unit = "D"
127
- interval = interval * 365.25
128
- horizon = _add_unit(int(self.spec.horizon * interval), unit=unit)
129
- initial = _add_unit((data_i.shape[0] * interval) // 2, unit=unit)
130
- period = _add_unit((data_i.shape[0] * interval) // 4, unit=unit)
131
-
132
- logger.debug(
133
- f"using: horizon: {horizon}. initial:{initial}, period: {period}"
134
- )
83
+ data = self.preprocess(df, series_id)
84
+ data_i = self.drop_horizon(data)
85
+ if self.loaded_models is not None:
86
+ model = self.loaded_models[series_id]
87
+ else:
88
+ if self.perform_tuning:
89
+ model_kwargs = self.run_tuning(data_i, model_kwargs)
135
90
 
136
- df_cv = cross_validation(
137
- model,
138
- horizon=horizon,
139
- initial=initial,
140
- period=period,
141
- parallel="threads",
142
- )
143
- df_p = performance_metrics(df_cv)
144
- try:
145
- return np.mean(df_p[self.spec.metric])
146
- except KeyError:
147
- logger.warn(
148
- f"Could not find the metric {self.spec.metric} within "
149
- f"the performance metrics: {df_p.columns}. Defaulting to `rmse`"
150
- )
151
- return np.mean(df_p["rmse"])
152
-
153
- study = optuna.create_study(direction="minimize")
154
- m_temp = Prophet()
155
- study.enqueue_trial(
156
- {
157
- "seasonality_mode": m_temp.seasonality_mode,
158
- "changepoint_prior_scale": m_temp.changepoint_prior_scale,
159
- "seasonality_prior_scale": m_temp.seasonality_prior_scale,
160
- "holidays_prior_scale": m_temp.holidays_prior_scale,
161
- "changepoint_range": m_temp.changepoint_range,
162
- }
163
- )
164
- study.optimize(
165
- objective,
166
- n_trials=self.spec.tuning.n_trials
167
- if self.spec.tuning
168
- else DEFAULT_TRIALS,
169
- n_jobs=-1,
91
+ model = _fit_model(
92
+ data=data,
93
+ params=model_kwargs,
94
+ additional_regressors=self.additional_regressors,
170
95
  )
171
96
 
172
- study.best_params.update(model_kwargs_i)
173
- model_kwargs_i = study.best_params
174
- model = _fit_model(
175
- data=data_i,
176
- params=model_kwargs_i,
177
- additional_regressors=additional_regressors,
178
- )
97
+ # Get future df for prediction
98
+ future = data.drop("y", axis=1)
179
99
 
180
- # Make future df for prediction
181
- if len(additional_regressors):
182
- future = df_clean.drop(target, axis=1)
183
- else:
184
- future = model.make_future_dataframe(
185
- periods=self.spec.horizon,
186
- freq=self.spec.freq,
187
- )
188
100
  # Make Prediction
189
101
  forecast = model.predict(future)
190
102
  logger.debug(f"-----------------Model {i}----------------------")
@@ -194,107 +106,186 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
194
106
  ].tail()
195
107
  )
196
108
 
197
- # Collect Outputs
198
- models.append(model)
199
- outputs[target] = forecast
200
- outputs_legacy.append(forecast)
201
-
202
- self.models = models
203
- self.outputs = outputs_legacy
204
-
205
- logger.debug("===========Done===========")
206
-
207
- # Merge the outputs from each model into 1 df with all outputs by target and category
208
- col = self.original_target_column
209
- output_col = pd.DataFrame()
210
- yhat_upper_name = ForecastOutputColumns.UPPER_BOUND
211
- yhat_lower_name = ForecastOutputColumns.LOWER_BOUND
212
- for cat in self.categories:
213
- output_i = pd.DataFrame()
214
-
215
- output_i["Date"] = outputs[f"{col}_{cat}"][PROPHET_INTERNAL_DATE_COL]
216
- output_i["Series"] = cat
217
- output_i["input_value"] = full_data_dict[f"{col}_{cat}"][f"{col}_{cat}"]
218
-
219
- output_i[f"fitted_value"] = float("nan")
220
- output_i[f"forecast_value"] = float("nan")
221
- output_i[yhat_upper_name] = float("nan")
222
- output_i[yhat_lower_name] = float("nan")
223
-
224
- output_i.iloc[
225
- : -self.spec.horizon, output_i.columns.get_loc(f"fitted_value")
226
- ] = (outputs[f"{col}_{cat}"]["yhat"].iloc[: -self.spec.horizon].values)
227
- output_i.iloc[
228
- -self.spec.horizon :,
229
- output_i.columns.get_loc(f"forecast_value"),
230
- ] = (
231
- outputs[f"{col}_{cat}"]["yhat"].iloc[-self.spec.horizon :].values
109
+ self.outputs[series_id] = forecast
110
+ self.forecast_output.populate_series_output(
111
+ series_id=series_id,
112
+ fit_val=self.drop_horizon(forecast["yhat"]).values,
113
+ forecast_val=self.get_horizon(forecast["yhat"]).values,
114
+ upper_bound=self.get_horizon(forecast["yhat_upper"]).values,
115
+ lower_bound=self.get_horizon(forecast["yhat_lower"]).values,
232
116
  )
233
- output_i.iloc[
234
- -self.spec.horizon :, output_i.columns.get_loc(yhat_upper_name)
235
- ] = (
236
- outputs[f"{col}_{cat}"]["yhat_upper"].iloc[-self.spec.horizon :].values
117
+ self.models[series_id] = model
118
+
119
+ params = vars(model).copy()
120
+ for param in ["history", "history_dates", "stan_fit"]:
121
+ if param in params:
122
+ params.pop(param)
123
+ self.model_parameters[series_id] = {
124
+ "framework": SupportedModels.Prophet,
125
+ **params,
126
+ }
127
+
128
+ logger.debug("===========Done===========")
129
+ except Exception as e:
130
+ self.errors_dict[series_id] = {
131
+ "model_name": self.spec.model,
132
+ "error": str(e),
133
+ }
134
+
135
+ def _build_model(self) -> pd.DataFrame:
136
+ from prophet import Prophet
137
+ from prophet.diagnostics import cross_validation, performance_metrics
138
+
139
+ full_data_dict = self.datasets.get_data_by_series()
140
+ self.models = dict()
141
+ self.outputs = dict()
142
+ self.additional_regressors = self.datasets.get_additional_data_column_names()
143
+ model_kwargs = self.set_kwargs()
144
+ self.forecast_output = ForecastOutput(
145
+ confidence_interval_width=self.spec.confidence_interval_width,
146
+ horizon=self.spec.horizon,
147
+ target_column=self.original_target_column,
148
+ dt_column=self.spec.datetime_column.name,
149
+ )
150
+
151
+ Parallel(n_jobs=-1, require="sharedmem")(
152
+ delayed(ProphetOperatorModel._train_model)(
153
+ self, i, series_id, df, model_kwargs.copy()
237
154
  )
238
- output_i.iloc[
239
- -self.spec.horizon :, output_i.columns.get_loc(yhat_lower_name)
240
- ] = (
241
- outputs[f"{col}_{cat}"]["yhat_lower"].iloc[-self.spec.horizon :].values
155
+ for self, (i, (series_id, df)) in zip(
156
+ [self] * len(full_data_dict), enumerate(full_data_dict.items())
242
157
  )
243
- output_col = pd.concat([output_col, output_i])
244
- self.forecast_output.add_category(
245
- category=cat, target_category_column=f"{col}_{cat}", forecast=output_i
158
+ )
159
+
160
+ return self.forecast_output.get_forecast_long()
161
+
162
+ def run_tuning(self, data_i, model_kwargs_i):
163
+ def objective(trial):
164
+ params = {
165
+ "seasonality_mode": trial.suggest_categorical(
166
+ "seasonality_mode", ["additive", "multiplicative"]
167
+ ),
168
+ "changepoint_prior_scale": trial.suggest_float(
169
+ "changepoint_prior_scale", 0.001, 0.5, log=True
170
+ ),
171
+ "seasonality_prior_scale": trial.suggest_float(
172
+ "seasonality_prior_scale", 0.01, 10, log=True
173
+ ),
174
+ "holidays_prior_scale": trial.suggest_float(
175
+ "holidays_prior_scale", 0.01, 10, log=True
176
+ ),
177
+ "changepoint_range": trial.suggest_float(
178
+ "changepoint_range", 0.8, 0.95
179
+ ),
180
+ }
181
+ params.update(model_kwargs_i)
182
+
183
+ model = _fit_model(
184
+ data=data_i,
185
+ params=params,
186
+ additional_regressors=self.additional_regressors,
246
187
  )
247
188
 
248
- output_col = output_col.reset_index(drop=True)
189
+ # Manual workaround because pandas 1.x dropped support for M and Y
190
+ interval = self.spec.horizon
191
+ freq = self.datasets.get_datetime_frequency()
192
+ unit = freq.split("-")[0] if freq else None
193
+ if unit == "M":
194
+ unit = "D"
195
+ interval = interval * 30.5
196
+ elif unit == "Y":
197
+ unit = "D"
198
+ interval = interval * 365.25
199
+ horizon = _add_unit(int(self.spec.horizon * interval), unit=unit)
200
+ initial = _add_unit((data_i.shape[0] * interval) // 2, unit=unit)
201
+ period = _add_unit((data_i.shape[0] * interval) // 4, unit=unit)
249
202
 
250
- return output_col
203
+ logger.debug(
204
+ f"using: horizon: {horizon}. initial:{initial}, period: {period}"
205
+ )
206
+
207
+ df_cv = cross_validation(
208
+ model,
209
+ horizon=horizon,
210
+ initial=initial,
211
+ period=period,
212
+ parallel="threads",
213
+ )
214
+ df_p = performance_metrics(df_cv)
215
+ try:
216
+ return np.mean(df_p[self.spec.metric])
217
+ except KeyError:
218
+ logger.warn(
219
+ f"Could not find the metric {self.spec.metric} within "
220
+ f"the performance metrics: {df_p.columns}. Defaulting to `rmse`"
221
+ )
222
+ return np.mean(df_p["rmse"])
223
+
224
+ study = optuna.create_study(direction="minimize")
225
+ m_temp = Prophet()
226
+ study.enqueue_trial(
227
+ {
228
+ "seasonality_mode": m_temp.seasonality_mode,
229
+ "changepoint_prior_scale": m_temp.changepoint_prior_scale,
230
+ "seasonality_prior_scale": m_temp.seasonality_prior_scale,
231
+ "holidays_prior_scale": m_temp.holidays_prior_scale,
232
+ "changepoint_range": m_temp.changepoint_range,
233
+ }
234
+ )
235
+ study.optimize(
236
+ objective,
237
+ n_trials=self.spec.tuning.n_trials if self.spec.tuning else DEFAULT_TRIALS,
238
+ n_jobs=-1,
239
+ )
240
+
241
+ study.best_params.update(model_kwargs_i)
242
+ model_kwargs_i = study.best_params
243
+ return model_kwargs_i
251
244
 
252
245
  def _generate_report(self):
253
246
  import datapane as dp
254
247
  from prophet.plot import add_changepoints_to_plot
255
248
 
249
+ series_ids = self.datasets.list_series_ids()
250
+
256
251
  sec1_text = dp.Text(
257
252
  "## Forecast Overview \n"
258
253
  "These plots show your forecast in the context of historical data."
259
254
  )
260
- sec1 = utils._select_plot_list(
261
- lambda idx, *args: self.models[idx].plot(
262
- self.outputs[idx], include_legend=True
255
+ sec1 = _select_plot_list(
256
+ lambda s_id: self.models[s_id].plot(
257
+ self.outputs[s_id], include_legend=True
263
258
  ),
264
- target_columns=self.target_columns,
259
+ series_ids=series_ids,
265
260
  )
266
261
 
267
262
  sec2_text = dp.Text(f"## Forecast Broken Down by Trend Component")
268
- sec2 = utils._select_plot_list(
269
- lambda idx, *args: self.models[idx].plot_components(self.outputs[idx]),
270
- target_columns=self.target_columns,
263
+ sec2 = _select_plot_list(
264
+ lambda s_id: self.models[s_id].plot_components(self.outputs[s_id]),
265
+ series_ids=series_ids,
271
266
  )
272
267
 
273
268
  sec3_text = dp.Text(f"## Forecast Changepoints")
274
- sec3_figs = [
275
- self.models[idx].plot(self.outputs[idx])
276
- for idx in range(len(self.target_columns))
277
- ]
278
- [
269
+ sec3_figs = {
270
+ s_id: self.models[s_id].plot(self.outputs[s_id]) for s_id in series_ids
271
+ }
272
+ for s_id in series_ids:
279
273
  add_changepoints_to_plot(
280
- sec3_figs[idx].gca(), self.models[idx], self.outputs[idx]
274
+ sec3_figs[s_id].gca(), self.models[s_id], self.outputs[s_id]
281
275
  )
282
- for idx in range(len(self.target_columns))
283
- ]
284
- sec3 = utils._select_plot_list(
285
- lambda idx, *args: sec3_figs[idx], target_columns=self.target_columns
286
- )
276
+ sec3 = _select_plot_list(lambda s_id: sec3_figs[s_id], series_ids=series_ids)
287
277
 
288
278
  all_sections = [sec1_text, sec1, sec2_text, sec2, sec3_text, sec3]
289
279
 
290
280
  sec5_text = dp.Text(f"## Prophet Model Seasonality Components")
291
281
  model_states = []
292
- for i, m in enumerate(self.models):
282
+ for s_id in series_ids:
283
+ m = self.models[s_id]
293
284
  model_states.append(
294
285
  pd.Series(
295
286
  m.seasonalities,
296
287
  index=pd.Index(m.seasonalities.keys(), dtype="object"),
297
- name=self.target_columns[i],
288
+ name=s_id,
298
289
  dtype="object",
299
290
  )
300
291
  )
@@ -306,10 +297,7 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
306
297
  if self.spec.generate_explanations:
307
298
  try:
308
299
  # If the key is present, call the "explain_model" method
309
- self.explain_model(
310
- datetime_col_name=PROPHET_INTERNAL_DATE_COL,
311
- explain_predict_fn=self._custom_predict_prophet,
312
- )
300
+ self.explain_model()
313
301
 
314
302
  # Create a markdown text block for the global explanation section
315
303
  global_explanation_text = dp.Text(
@@ -333,7 +321,7 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
333
321
  aggregate_local_explanations = pd.DataFrame()
334
322
  for s_id, local_ex_df in self.local_explanation.items():
335
323
  local_ex_df_copy = local_ex_df.copy()
336
- local_ex_df_copy["Series"] = s_id
324
+ local_ex_df_copy[ForecastOutputColumns.SERIES] = s_id
337
325
  aggregate_local_explanations = pd.concat(
338
326
  [aggregate_local_explanations, local_ex_df_copy], axis=0
339
327
  )
@@ -376,8 +364,3 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
376
364
  model_description,
377
365
  other_sections,
378
366
  )
379
-
380
- def _custom_predict_prophet(self, data):
381
- return self.models[self.target_columns.index(self.series_id)].predict(
382
- data.reset_index()
383
- )["yhat"]
@@ -10,44 +10,16 @@ from typing import Dict, List
10
10
 
11
11
  from ads.common.serializer import DataClassSerializable
12
12
  from ads.opctl.operator.common.utils import _load_yaml_from_uri
13
- from ads.opctl.operator.common.operator_config import OperatorConfig
13
+ from ads.opctl.operator.common.operator_config import OperatorConfig, OutputDirectory, InputData
14
14
 
15
- from .const import SupportedMetrics
15
+ from .const import SupportedMetrics, SpeedAccuracyMode
16
16
  from .const import SupportedModels
17
17
 
18
- @dataclass(repr=True)
19
- class InputData(DataClassSerializable):
20
- """Class representing operator specification input data details."""
21
-
22
- format: str = None
23
- columns: List[str] = None
24
- url: str = None
25
- options: Dict = None
26
- limit: int = None
27
-
28
18
 
29
19
  @dataclass(repr=True)
30
- class TestData(DataClassSerializable):
20
+ class TestData(InputData):
31
21
  """Class representing operator specification test data details."""
32
22
 
33
- connect_args: Dict = None
34
- format: str = None
35
- columns: List[str] = None
36
- url: str = None
37
- name: str = None
38
- options: Dict = None
39
-
40
-
41
- @dataclass(repr=True)
42
- class OutputDirectory(DataClassSerializable):
43
- """Class representing operator specification output directory details."""
44
-
45
- connect_args: Dict = None
46
- format: str = None
47
- url: str = None
48
- name: str = None
49
- options: Dict = None
50
-
51
23
 
52
24
  @dataclass(repr=True)
53
25
  class DateTimeColumn(DataClassSerializable):
@@ -88,10 +60,14 @@ class ForecastOperatorSpec(DataClassSerializable):
88
60
  generate_report: bool = None
89
61
  generate_metrics: bool = None
90
62
  generate_explanations: bool = None
63
+ explanations_accuracy_mode: str = None
91
64
  horizon: int = None
92
- freq: str = None
93
65
  model: str = None
94
66
  model_kwargs: Dict = field(default_factory=dict)
67
+ model_parameters: str = None
68
+ previous_output_dir: str = None
69
+ generate_model_parameters: bool = None
70
+ generate_model_pickle: bool = None
95
71
  confidence_interval_width: float = None
96
72
  metric: str = None
97
73
  tuning: Tuning = field(default_factory=Tuning)
@@ -99,7 +75,7 @@ class ForecastOperatorSpec(DataClassSerializable):
99
75
  def __post_init__(self):
100
76
  """Adjusts the specification details."""
101
77
  self.metric = (self.metric or "").lower() or SupportedMetrics.SMAPE.lower()
102
- self.model = (self.model or SupportedModels.Auto)
78
+ self.model = self.model or SupportedModels.Auto
103
79
  self.confidence_interval_width = self.confidence_interval_width or 0.80
104
80
  self.report_filename = self.report_filename or "report.html"
105
81
  self.preprocessing = (
@@ -119,6 +95,20 @@ class ForecastOperatorSpec(DataClassSerializable):
119
95
  if self.generate_explanations is not None
120
96
  else False
121
97
  )
98
+ self.explanations_accuracy_mode = (
99
+ self.explanations_accuracy_mode or SpeedAccuracyMode.FAST_APPROXIMATE
100
+ )
101
+
102
+ self.generate_model_parameters = (
103
+ self.generate_model_parameters
104
+ if self.generate_model_parameters is not None
105
+ else False
106
+ )
107
+ self.generate_model_pickle = (
108
+ self.generate_model_pickle
109
+ if self.generate_model_pickle is not None
110
+ else False
111
+ )
122
112
  self.report_theme = self.report_theme or "light"
123
113
  self.metrics_filename = self.metrics_filename or "metrics.csv"
124
114
  self.test_metrics_filename = self.test_metrics_filename or "test_metrics.csv"
@@ -130,6 +120,7 @@ class ForecastOperatorSpec(DataClassSerializable):
130
120
  self.local_explanation_filename or "local_explanation.csv"
131
121
  )
132
122
  self.target_column = self.target_column or "Sales"
123
+ self.errors_dict_filename = "errors.json"
133
124
  self.model_kwargs = self.model_kwargs or dict()
134
125
 
135
126