oracle-ads 2.10.0__py3-none-any.whl → 2.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. ads/aqua/__init__.py +12 -0
  2. ads/aqua/base.py +324 -0
  3. ads/aqua/cli.py +19 -0
  4. ads/aqua/config/deployment_config_defaults.json +9 -0
  5. ads/aqua/config/resource_limit_names.json +7 -0
  6. ads/aqua/constants.py +45 -0
  7. ads/aqua/data.py +40 -0
  8. ads/aqua/decorator.py +101 -0
  9. ads/aqua/deployment.py +643 -0
  10. ads/aqua/dummy_data/icon.txt +1 -0
  11. ads/aqua/dummy_data/oci_model_deployments.json +56 -0
  12. ads/aqua/dummy_data/oci_models.json +1 -0
  13. ads/aqua/dummy_data/readme.md +26 -0
  14. ads/aqua/evaluation.py +1751 -0
  15. ads/aqua/exception.py +82 -0
  16. ads/aqua/extension/__init__.py +40 -0
  17. ads/aqua/extension/base_handler.py +138 -0
  18. ads/aqua/extension/common_handler.py +21 -0
  19. ads/aqua/extension/deployment_handler.py +202 -0
  20. ads/aqua/extension/evaluation_handler.py +135 -0
  21. ads/aqua/extension/finetune_handler.py +66 -0
  22. ads/aqua/extension/model_handler.py +59 -0
  23. ads/aqua/extension/ui_handler.py +201 -0
  24. ads/aqua/extension/utils.py +23 -0
  25. ads/aqua/finetune.py +579 -0
  26. ads/aqua/job.py +29 -0
  27. ads/aqua/model.py +819 -0
  28. ads/aqua/training/__init__.py +4 -0
  29. ads/aqua/training/exceptions.py +459 -0
  30. ads/aqua/ui.py +453 -0
  31. ads/aqua/utils.py +715 -0
  32. ads/cli.py +37 -6
  33. ads/common/auth.py +7 -0
  34. ads/common/decorator/__init__.py +7 -3
  35. ads/common/decorator/require_nonempty_arg.py +65 -0
  36. ads/common/object_storage_details.py +166 -7
  37. ads/common/oci_client.py +18 -1
  38. ads/common/oci_logging.py +2 -2
  39. ads/common/oci_mixin.py +4 -5
  40. ads/common/serializer.py +34 -5
  41. ads/common/utils.py +75 -10
  42. ads/config.py +40 -1
  43. ads/dataset/correlation_plot.py +10 -12
  44. ads/jobs/ads_job.py +43 -25
  45. ads/jobs/builders/infrastructure/base.py +4 -2
  46. ads/jobs/builders/infrastructure/dsc_job.py +49 -39
  47. ads/jobs/builders/runtimes/base.py +71 -1
  48. ads/jobs/builders/runtimes/container_runtime.py +4 -4
  49. ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
  50. ads/jobs/templates/driver_pytorch.py +27 -10
  51. ads/model/artifact_downloader.py +84 -14
  52. ads/model/artifact_uploader.py +25 -23
  53. ads/model/datascience_model.py +388 -38
  54. ads/model/deployment/model_deployment.py +10 -2
  55. ads/model/generic_model.py +8 -0
  56. ads/model/model_file_description_schema.json +68 -0
  57. ads/model/model_metadata.py +1 -1
  58. ads/model/service/oci_datascience_model.py +34 -5
  59. ads/opctl/config/merger.py +2 -2
  60. ads/opctl/operator/__init__.py +3 -1
  61. ads/opctl/operator/cli.py +7 -1
  62. ads/opctl/operator/cmd.py +3 -3
  63. ads/opctl/operator/common/errors.py +2 -1
  64. ads/opctl/operator/common/operator_config.py +22 -3
  65. ads/opctl/operator/common/utils.py +16 -0
  66. ads/opctl/operator/lowcode/anomaly/MLoperator +15 -0
  67. ads/opctl/operator/lowcode/anomaly/README.md +209 -0
  68. ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
  69. ads/opctl/operator/lowcode/anomaly/__main__.py +104 -0
  70. ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
  71. ads/opctl/operator/lowcode/anomaly/const.py +88 -0
  72. ads/opctl/operator/lowcode/anomaly/environment.yaml +12 -0
  73. ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
  74. ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +147 -0
  75. ads/opctl/operator/lowcode/anomaly/model/automlx.py +89 -0
  76. ads/opctl/operator/lowcode/anomaly/model/autots.py +103 -0
  77. ads/opctl/operator/lowcode/anomaly/model/base_model.py +354 -0
  78. ads/opctl/operator/lowcode/anomaly/model/factory.py +67 -0
  79. ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
  80. ads/opctl/operator/lowcode/anomaly/operator_config.py +105 -0
  81. ads/opctl/operator/lowcode/anomaly/schema.yaml +359 -0
  82. ads/opctl/operator/lowcode/anomaly/utils.py +81 -0
  83. ads/opctl/operator/lowcode/common/__init__.py +5 -0
  84. ads/opctl/operator/lowcode/common/const.py +10 -0
  85. ads/opctl/operator/lowcode/common/data.py +96 -0
  86. ads/opctl/operator/lowcode/common/errors.py +41 -0
  87. ads/opctl/operator/lowcode/common/transformations.py +191 -0
  88. ads/opctl/operator/lowcode/common/utils.py +250 -0
  89. ads/opctl/operator/lowcode/forecast/README.md +3 -2
  90. ads/opctl/operator/lowcode/forecast/__main__.py +18 -2
  91. ads/opctl/operator/lowcode/forecast/cmd.py +8 -7
  92. ads/opctl/operator/lowcode/forecast/const.py +17 -1
  93. ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
  94. ads/opctl/operator/lowcode/forecast/model/arima.py +106 -117
  95. ads/opctl/operator/lowcode/forecast/model/automlx.py +204 -180
  96. ads/opctl/operator/lowcode/forecast/model/autots.py +144 -253
  97. ads/opctl/operator/lowcode/forecast/model/base_model.py +326 -259
  98. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +325 -176
  99. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +293 -237
  100. ads/opctl/operator/lowcode/forecast/model/prophet.py +191 -208
  101. ads/opctl/operator/lowcode/forecast/operator_config.py +24 -33
  102. ads/opctl/operator/lowcode/forecast/schema.yaml +116 -29
  103. ads/opctl/operator/lowcode/forecast/utils.py +186 -356
  104. ads/opctl/operator/lowcode/pii/model/guardrails.py +18 -15
  105. ads/opctl/operator/lowcode/pii/model/report.py +7 -7
  106. ads/opctl/operator/lowcode/pii/operator_config.py +1 -8
  107. ads/opctl/operator/lowcode/pii/utils.py +0 -82
  108. ads/opctl/operator/runtime/runtime.py +3 -2
  109. ads/telemetry/base.py +62 -0
  110. ads/telemetry/client.py +105 -0
  111. ads/telemetry/telemetry.py +6 -3
  112. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/METADATA +44 -7
  113. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/RECORD +116 -59
  114. ads/opctl/operator/lowcode/forecast/model/transformations.py +0 -125
  115. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/LICENSE.txt +0 -0
  116. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/WHEEL +0 -0
  117. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/entry_points.txt +0 -0
@@ -11,13 +11,17 @@ from ads.common.decorator.runtime_dependency import runtime_dependency
11
11
  from ads.opctl.operator.lowcode.forecast.const import (
12
12
  AUTOMLX_METRIC_MAP,
13
13
  ForecastOutputColumns,
14
+ SupportedModels,
14
15
  )
15
16
  from ads.opctl import logger
16
17
 
17
- from .. import utils
18
18
  from .base_model import ForecastOperatorBaseModel
19
19
  from ..operator_config import ForecastOperatorConfig
20
20
  from .forecast_datasets import ForecastDatasets, ForecastOutput
21
+ from ads.opctl.operator.lowcode.common.utils import (
22
+ seconds_to_datetime,
23
+ datetime_to_seconds,
24
+ )
21
25
 
22
26
  AUTOMLX_N_ALGOS_TUNED = 4
23
27
  AUTOMLX_DEFAULT_SCORE_METRIC = "neg_sym_mean_abs_percent_error"
@@ -30,12 +34,32 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
30
34
  super().__init__(config, datasets)
31
35
  self.global_explanation = {}
32
36
  self.local_explanation = {}
33
- self.train_metrics = True
37
+
38
+ def set_kwargs(self):
39
+ model_kwargs_cleaned = self.spec.model_kwargs
40
+ model_kwargs_cleaned["n_algos_tuned"] = model_kwargs_cleaned.get(
41
+ "n_algos_tuned", AUTOMLX_N_ALGOS_TUNED
42
+ )
43
+ model_kwargs_cleaned["score_metric"] = AUTOMLX_METRIC_MAP.get(
44
+ self.spec.metric,
45
+ model_kwargs_cleaned.get("score_metric", AUTOMLX_DEFAULT_SCORE_METRIC),
46
+ )
47
+ model_kwargs_cleaned.pop("task", None)
48
+ time_budget = model_kwargs_cleaned.pop("time_budget", -1)
49
+ model_kwargs_cleaned[
50
+ "preprocessing"
51
+ ] = self.spec.preprocessing or model_kwargs_cleaned.get("preprocessing", True)
52
+ return model_kwargs_cleaned, time_budget
53
+
54
+ def preprocess(self, data, series_id=None):
55
+ return data.set_index(self.spec.datetime_column.name)
34
56
 
35
57
  @runtime_dependency(
36
- module="automl",
58
+ module="automlx",
37
59
  err_msg=(
38
- "Please run `pip3 install oracle-automlx==23.2.3` to install the required dependencies for automlx."
60
+ "Please run `pip3 install oracle-automlx==23.4.1` and "
61
+ "`pip3 install oracle-automlx[forecasting]==23.4.1` "
62
+ "to install the required dependencies for automlx."
39
63
  ),
40
64
  )
41
65
  @runtime_dependency(
@@ -45,145 +69,99 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
45
69
  ),
46
70
  )
47
71
  def _build_model(self) -> pd.DataFrame:
48
- from automl import init
72
+ from automlx import init
49
73
  from sktime.forecasting.model_selection import temporal_train_test_split
74
+ try:
75
+ init(engine="ray", engine_opts={"ray_setup": {"_temp_dir": "/tmp/ray-temp"}})
76
+ except Exception as e:
77
+ logger.info("Ray already initialized")
50
78
 
51
- init(engine="local", check_deprecation_warnings=False)
52
79
 
53
- full_data_dict = self.datasets.full_data_dict
80
+ full_data_dict = self.datasets.get_data_by_series()
54
81
 
55
- models = dict()
56
- outputs = dict()
57
- outputs_legacy = dict()
58
- selected_models = dict()
82
+ self.models = dict()
59
83
  date_column = self.spec.datetime_column.name
60
84
  horizon = self.spec.horizon
61
- self.datasets.datetime_col = date_column
62
85
  self.spec.confidence_interval_width = self.spec.confidence_interval_width or 0.8
63
86
  self.forecast_output = ForecastOutput(
64
- confidence_interval_width=self.spec.confidence_interval_width
87
+ confidence_interval_width=self.spec.confidence_interval_width,
88
+ horizon=self.spec.horizon,
89
+ target_column=self.original_target_column,
90
+ dt_column=self.spec.datetime_column.name,
65
91
  )
66
92
 
67
93
  # Clean up kwargs for pass through
68
- model_kwargs_cleaned = self.spec.model_kwargs.copy()
69
- model_kwargs_cleaned["n_algos_tuned"] = model_kwargs_cleaned.get(
70
- "n_algos_tuned", AUTOMLX_N_ALGOS_TUNED
71
- )
72
- model_kwargs_cleaned["score_metric"] = AUTOMLX_METRIC_MAP.get(
73
- self.spec.metric,
74
- model_kwargs_cleaned.get("score_metric", AUTOMLX_DEFAULT_SCORE_METRIC),
75
- )
76
- model_kwargs_cleaned.pop("task", None)
77
- time_budget = model_kwargs_cleaned.pop("time_budget", 0)
78
- model_kwargs_cleaned[
79
- "preprocessing"
80
- ] = self.spec.preprocessing or model_kwargs_cleaned.get("preprocessing", True)
94
+ model_kwargs_cleaned, time_budget = self.set_kwargs()
81
95
 
82
- for i, (target, df) in enumerate(full_data_dict.items()):
83
- logger.debug("Running automl for {} at position {}".format(target, i))
84
- series_values = df[df[target].notna()]
85
- # drop NaNs for the time period where data wasn't recorded
86
- series_values.dropna(inplace=True)
87
- df[date_column] = pd.to_datetime(
88
- df[date_column], format=self.spec.datetime_column.format
89
- )
90
- df = df.set_index(date_column)
91
- # if len(df.columns) > 1:
92
- # when additional columns are present
93
- y_train, y_test = temporal_train_test_split(df, test_size=horizon)
94
- forecast_x = y_test.drop(target, axis=1)
95
- # else:
96
- # y_train = df
97
- # forecast_x = None
98
- logger.debug(
99
- "Time Index is" + ""
100
- if y_train.index.is_monotonic
101
- else "NOT" + "monotonic."
102
- )
103
- model = automl.Pipeline(
104
- task="forecasting",
105
- **model_kwargs_cleaned,
106
- )
107
- model.fit(
108
- X=y_train.drop(target, axis=1),
109
- y=pd.DataFrame(y_train[target]),
110
- time_budget=time_budget,
111
- )
112
- logger.debug("Selected model: {}".format(model.selected_model_))
113
- logger.debug(
114
- "Selected model params: {}".format(model.selected_model_params_)
115
- )
116
- summary_frame = model.forecast(
117
- X=forecast_x,
118
- periods=horizon,
119
- alpha=1 - (self.spec.confidence_interval_width / 100),
120
- )
121
- input_values = pd.Series(
122
- y_train[target].values,
123
- name="input_value",
124
- index=y_train.index,
125
- )
126
- fitted_values_raw = model.predict(y_train.drop(target, axis=1))
127
- fitted_values = pd.Series(
128
- fitted_values_raw[target].values,
129
- name="fitted_value",
130
- index=y_train.index,
131
- )
96
+ for i, (s_id, df) in enumerate(full_data_dict.items()):
97
+ try:
98
+ logger.debug(f"Running automlx on series {s_id}")
99
+ model_kwargs = model_kwargs_cleaned.copy()
100
+ target = self.original_target_column
101
+ self.forecast_output.init_series_output(
102
+ series_id=s_id, data_at_series=df
103
+ )
104
+ data = self.preprocess(df)
105
+ data_i = self.drop_horizon(data)
106
+ X_pred = self.get_horizon(data).drop(target, axis=1)
107
+
108
+ logger.debug(f"Time Index Monotonic: {data_i.index.is_monotonic}")
109
+
110
+ if self.loaded_models is not None:
111
+ model = self.loaded_models[s_id]
112
+ else:
113
+ model = automlx.Pipeline(
114
+ task="forecasting",
115
+ **model_kwargs,
116
+ )
117
+ model.fit(
118
+ X=data_i.drop(target, axis=1),
119
+ y=data_i[[target]],
120
+ time_budget=time_budget,
121
+ )
122
+ logger.debug(f"Selected model: {model.selected_model_}")
123
+ logger.debug(f"Selected model params: {model.selected_model_params_}")
124
+ summary_frame = model.forecast(
125
+ X=X_pred,
126
+ periods=horizon,
127
+ alpha=1 - (self.spec.confidence_interval_width / 100),
128
+ )
132
129
 
133
- summary_frame = pd.concat(
134
- [input_values, fitted_values, summary_frame], axis=1
135
- )
130
+ fitted_values = model.predict(data_i.drop(target, axis=1))[
131
+ target
132
+ ].values
136
133
 
137
- # Collect Outputs
138
- selected_models[target] = {
139
- "series_id": target,
140
- "selected_model": model.selected_model_,
141
- "model_params": model.selected_model_params_,
142
- }
143
- models[target] = model
144
- summary_frame = summary_frame.rename_axis("ds").reset_index()
145
- summary_frame = summary_frame.rename(
146
- columns={
147
- f"{target}_ci_upper": "yhat_upper",
148
- f"{target}_ci_lower": "yhat_lower",
149
- f"{target}": "yhat",
134
+ self.models[s_id] = model
135
+
136
+ # In case of Naive model, model.forecast function call does not return confidence intervals.
137
+ if f"{target}_ci_upper" not in summary_frame:
138
+ summary_frame[f"{target}_ci_upper"] = np.NAN
139
+ if f"{target}_ci_lower" not in summary_frame:
140
+ summary_frame[f"{target}_ci_lower"] = np.NAN
141
+
142
+ self.forecast_output.populate_series_output(
143
+ series_id=s_id,
144
+ fit_val=fitted_values,
145
+ forecast_val=summary_frame[target],
146
+ upper_bound=summary_frame[f"{target}_ci_upper"],
147
+ lower_bound=summary_frame[f"{target}_ci_lower"],
148
+ )
149
+
150
+ self.model_parameters[s_id] = {
151
+ "framework": SupportedModels.AutoMLX,
152
+ "time_series_period": model.time_series_period,
153
+ "selected_model": model.selected_model_,
154
+ "selected_model_params": model.selected_model_params_,
155
+ }
156
+ except Exception as e:
157
+ self.errors_dict[s_id] = {
158
+ "model_name": self.spec.model,
159
+ "error": str(e),
150
160
  }
151
- )
152
- # In case of Naive model, model.forecast function call does not return confidence intervals.
153
- if "yhat_upper" not in summary_frame:
154
- summary_frame["yhat_upper"] = np.NAN
155
- summary_frame["yhat_lower"] = np.NAN
156
- outputs[target] = summary_frame
157
- # outputs_legacy[target] = summary_frame
158
161
 
159
162
  logger.debug("===========Forecast Generated===========")
160
- outputs_merged = pd.DataFrame()
161
-
162
- # Merge the outputs from each model into 1 df with all outputs by target and category
163
- col = self.original_target_column
164
- yhat_upper_name = ForecastOutputColumns.UPPER_BOUND
165
- yhat_lower_name = ForecastOutputColumns.LOWER_BOUND
166
- for cat in self.categories: # Note: add [:2] to restrict
167
- output_i = pd.DataFrame()
168
- output_i["Date"] = outputs[f"{col}_{cat}"]["ds"]
169
- output_i["Series"] = cat
170
- output_i["input_value"] = outputs[f"{col}_{cat}"]["input_value"]
171
- output_i[f"fitted_value"] = outputs[f"{col}_{cat}"]["fitted_value"]
172
- output_i[f"forecast_value"] = outputs[f"{col}_{cat}"]["yhat"]
173
- output_i[yhat_upper_name] = outputs[f"{col}_{cat}"]["yhat_upper"]
174
- output_i[yhat_lower_name] = outputs[f"{col}_{cat}"]["yhat_lower"]
175
- outputs_merged = pd.concat([outputs_merged, output_i])
176
- outputs_legacy[f"{col}_{cat}"] = output_i
177
- self.forecast_output.add_category(
178
- category=cat, target_category_column=f"{col}_{cat}", forecast=output_i
179
- )
180
-
181
- # output_col = output_col.sort_values(self.spec.datetime_column.name).reset_index(drop=True)
182
- # output_col = output_col.reset_index(drop=True)
183
- # outputs_merged = pd.concat([outputs_merged, output_col], axis=1)
184
163
 
185
- self.models = models
186
- return outputs_merged
164
+ return self.forecast_output.get_forecast_long()
187
165
 
188
166
  @runtime_dependency(
189
167
  module="datapane",
@@ -219,11 +197,11 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
219
197
  )
220
198
  selected_models = dict()
221
199
  models = self.models
222
- for i, (target, df) in enumerate(self.full_data_dict.items()):
223
- selected_models[target] = {
224
- "series_id": target,
225
- "selected_model": models[target].selected_model_,
226
- "model_params": models[target].selected_model_params_,
200
+ for i, (s_id, df) in enumerate(self.full_data_dict.items()):
201
+ selected_models[s_id] = {
202
+ "series_id": s_id,
203
+ "selected_model": models[s_id].selected_model_,
204
+ "model_params": models[s_id].selected_model_params_,
227
205
  }
228
206
  selected_models_df = pd.DataFrame(
229
207
  selected_models.items(), columns=["series_id", "best_selected_model"]
@@ -236,63 +214,65 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
236
214
  all_sections = [selected_models_text, selected_models_section]
237
215
 
238
216
  if self.spec.generate_explanations:
239
- try:
240
- # If the key is present, call the "explain_model" method
241
- self.explain_model(
242
- datetime_col_name=self.spec.datetime_column.name,
243
- explain_predict_fn=self._custom_predict_automlx,
244
- )
245
-
246
- # Create a markdown text block for the global explanation section
247
- global_explanation_text = dp.Text(
248
- f"## Global Explanation of Models \n "
249
- "The following tables provide the feature attribution for the global explainability."
250
- )
217
+ # try:
218
+ # If the key is present, call the "explain_model" method
219
+ self.explain_model()
220
+
221
+ # Create a markdown text block for the global explanation section
222
+ global_explanation_text = dp.Text(
223
+ f"## Global Explanation of Models \n "
224
+ "The following tables provide the feature attribution for the global explainability."
225
+ )
251
226
 
252
- # Convert the global explanation data to a DataFrame
253
- global_explanation_df = pd.DataFrame(self.global_explanation)
227
+ # Convert the global explanation data to a DataFrame
228
+ global_explanation_df = pd.DataFrame(self.global_explanation)
254
229
 
255
- self.formatted_global_explanation = (
256
- global_explanation_df / global_explanation_df.sum(axis=0) * 100
230
+ self.formatted_global_explanation = (
231
+ global_explanation_df / global_explanation_df.sum(axis=0) * 100
232
+ )
233
+ self.formatted_global_explanation = (
234
+ self.formatted_global_explanation.rename(
235
+ {self.spec.datetime_column.name: ForecastOutputColumns.DATE}, axis=1
257
236
  )
237
+ )
258
238
 
259
- # Create a markdown section for the global explainability
260
- global_explanation_section = dp.Blocks(
261
- "### Global Explainability ",
262
- dp.DataTable(self.formatted_global_explanation),
263
- )
239
+ # Create a markdown section for the global explainability
240
+ global_explanation_section = dp.Blocks(
241
+ "### Global Explainability ",
242
+ dp.DataTable(self.formatted_global_explanation),
243
+ )
264
244
 
265
- aggregate_local_explanations = pd.DataFrame()
266
- for s_id, local_ex_df in self.local_explanation.items():
267
- local_ex_df_copy = local_ex_df.copy()
268
- local_ex_df_copy["Series"] = s_id
269
- aggregate_local_explanations = pd.concat(
270
- [aggregate_local_explanations, local_ex_df_copy], axis=0
271
- )
272
- self.formatted_local_explanation = aggregate_local_explanations
245
+ aggregate_local_explanations = pd.DataFrame()
246
+ for s_id, local_ex_df in self.local_explanation.items():
247
+ local_ex_df_copy = local_ex_df.copy()
248
+ local_ex_df_copy["Series"] = s_id
249
+ aggregate_local_explanations = pd.concat(
250
+ [aggregate_local_explanations, local_ex_df_copy], axis=0
251
+ )
252
+ self.formatted_local_explanation = aggregate_local_explanations
273
253
 
274
- local_explanation_text = dp.Text(f"## Local Explanation of Models \n ")
275
- blocks = [
276
- dp.DataTable(
277
- local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
278
- label=s_id,
279
- )
280
- for s_id, local_ex_df in self.local_explanation.items()
281
- ]
282
- local_explanation_section = (
283
- dp.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
254
+ local_explanation_text = dp.Text(f"## Local Explanation of Models \n ")
255
+ blocks = [
256
+ dp.DataTable(
257
+ local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
258
+ label=s_id,
284
259
  )
260
+ for s_id, local_ex_df in self.local_explanation.items()
261
+ ]
262
+ local_explanation_section = (
263
+ dp.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
264
+ )
285
265
 
286
- # Append the global explanation text and section to the "all_sections" list
287
- all_sections = all_sections + [
288
- global_explanation_text,
289
- global_explanation_section,
290
- local_explanation_text,
291
- local_explanation_section,
292
- ]
293
- except Exception as e:
294
- logger.warn(f"Failed to generate Explanations with error: {e}.")
295
- logger.debug(f"Full Traceback: {traceback.format_exc()}")
266
+ # Append the global explanation text and section to the "all_sections" list
267
+ all_sections = all_sections + [
268
+ global_explanation_text,
269
+ global_explanation_section,
270
+ local_explanation_text,
271
+ local_explanation_section,
272
+ ]
273
+ # except Exception as e:
274
+ # logger.warn(f"Failed to generate Explanations with error: {e}.")
275
+ # logger.debug(f"Full Traceback: {traceback.format_exc()}")
296
276
 
297
277
  model_description = dp.Text(
298
278
  "The AutoMLx model automatically preprocesses, selects and engineers "
@@ -305,6 +285,51 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
305
285
  other_sections,
306
286
  )
307
287
 
288
+ def get_explain_predict_fn(self, series_id):
289
+ selected_model = self.models[series_id]
290
+
291
+ # If training date, use method below. If future date, use forecast!
292
+ def _custom_predict_fn(
293
+ data,
294
+ model=selected_model,
295
+ dt_column_name=self.datasets._datetime_column_name,
296
+ target_col=self.original_target_column,
297
+ last_train_date=self.datasets.historical_data.get_max_time(),
298
+ horizon_data=self.datasets.get_horizon_at_series(series_id),
299
+ ):
300
+ """
301
+ data: ForecastDatasets.get_data_at_series(s_id)
302
+ """
303
+ data = data.drop(target_col, axis=1)
304
+ data[dt_column_name] = seconds_to_datetime(
305
+ data[dt_column_name], dt_format=self.spec.datetime_column.format
306
+ )
307
+ data = self.preprocess(data)
308
+ horizon_data = horizon_data.drop(target_col, axis=1)
309
+ horizon_data[dt_column_name] = seconds_to_datetime(
310
+ horizon_data[dt_column_name], dt_format=self.spec.datetime_column.format
311
+ )
312
+ horizon_data = self.preprocess(horizon_data)
313
+
314
+ rows = []
315
+ for i in range(data.shape[0]):
316
+ row = data.iloc[i : i + 1]
317
+ if row.index[0] > last_train_date:
318
+ X_new = horizon_data.copy()
319
+ X_new.loc[row.index[0]] = row.iloc[0]
320
+ row_i = (
321
+ model.forecast(X=X_new, periods=self.spec.horizon)[[target_col]]
322
+ .loc[row.index[0]]
323
+ .values[0]
324
+ )
325
+ else:
326
+ row_i = model.predict(X=row).values[0][0]
327
+ rows.append(row_i)
328
+ ret = np.asarray(rows).flatten()
329
+ return ret
330
+
331
+ return _custom_predict_fn
332
+
308
333
  def _custom_predict_automlx(self, data):
309
334
  """
310
335
  Predicts the future values of a time series using the AutoMLX model.
@@ -316,7 +341,6 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
316
341
  -------
317
342
  numpy.ndarray: The predicted future values of the time series.
318
343
  """
319
- temp = 0
320
344
  data_temp = pd.DataFrame(
321
345
  data,
322
346
  columns=[col for col in self.dataset_cols],