oracle-ads 2.13.0__py3-none-any.whl → 2.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. ads/aqua/__init__.py +7 -1
  2. ads/aqua/app.py +24 -23
  3. ads/aqua/client/client.py +48 -11
  4. ads/aqua/common/entities.py +28 -1
  5. ads/aqua/common/enums.py +13 -7
  6. ads/aqua/common/utils.py +8 -13
  7. ads/aqua/config/container_config.py +203 -0
  8. ads/aqua/config/evaluation/evaluation_service_config.py +5 -181
  9. ads/aqua/constants.py +0 -1
  10. ads/aqua/evaluation/evaluation.py +4 -4
  11. ads/aqua/extension/base_handler.py +4 -0
  12. ads/aqua/extension/model_handler.py +19 -28
  13. ads/aqua/finetuning/finetuning.py +2 -3
  14. ads/aqua/model/entities.py +2 -3
  15. ads/aqua/model/model.py +25 -30
  16. ads/aqua/modeldeployment/deployment.py +6 -14
  17. ads/aqua/modeldeployment/entities.py +2 -2
  18. ads/aqua/server/__init__.py +4 -0
  19. ads/aqua/server/__main__.py +24 -0
  20. ads/aqua/server/app.py +47 -0
  21. ads/aqua/server/aqua_spec.yml +1291 -0
  22. ads/aqua/ui.py +5 -199
  23. ads/common/auth.py +20 -11
  24. ads/common/utils.py +91 -11
  25. ads/config.py +3 -0
  26. ads/llm/__init__.py +1 -0
  27. ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +32 -23
  28. ads/model/artifact_downloader.py +4 -1
  29. ads/model/common/utils.py +15 -3
  30. ads/model/datascience_model.py +339 -8
  31. ads/model/model_metadata.py +54 -14
  32. ads/model/model_version_set.py +5 -3
  33. ads/model/service/oci_datascience_model.py +477 -5
  34. ads/opctl/anomaly_detection.py +11 -0
  35. ads/opctl/forecast.py +11 -0
  36. ads/opctl/operator/common/utils.py +16 -0
  37. ads/opctl/operator/lowcode/common/data.py +5 -2
  38. ads/opctl/operator/lowcode/common/transformations.py +2 -12
  39. ads/opctl/operator/lowcode/forecast/__main__.py +5 -5
  40. ads/opctl/operator/lowcode/forecast/model/arima.py +6 -3
  41. ads/opctl/operator/lowcode/forecast/model/automlx.py +61 -31
  42. ads/opctl/operator/lowcode/forecast/model/base_model.py +66 -40
  43. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +79 -13
  44. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +5 -2
  45. ads/opctl/operator/lowcode/forecast/model/prophet.py +28 -15
  46. ads/opctl/operator/lowcode/forecast/model_evaluator.py +13 -15
  47. ads/opctl/operator/lowcode/forecast/schema.yaml +1 -1
  48. ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +7 -0
  49. ads/opctl/operator/lowcode/forecast/whatifserve/score.py +19 -11
  50. {oracle_ads-2.13.0.dist-info → oracle_ads-2.13.1.dist-info}/METADATA +18 -15
  51. {oracle_ads-2.13.0.dist-info → oracle_ads-2.13.1.dist-info}/RECORD +54 -48
  52. {oracle_ads-2.13.0.dist-info → oracle_ads-2.13.1.dist-info}/WHEEL +1 -1
  53. ads/aqua/config/evaluation/evaluation_service_model_config.py +0 -8
  54. {oracle_ads-2.13.0.dist-info → oracle_ads-2.13.1.dist-info}/entry_points.txt +0 -0
  55. {oracle_ads-2.13.0.dist-info → oracle_ads-2.13.1.dist-info/licenses}/LICENSE.txt +0 -0
@@ -116,7 +116,10 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
116
116
  lower_bound=self.get_horizon(forecast["yhat_lower"]).values,
117
117
  )
118
118
 
119
- self.models[s_id] = model
119
+ self.models[s_id] = {}
120
+ self.models[s_id]["model"] = model
121
+ self.models[s_id]["le"] = self.le[s_id]
122
+ self.models[s_id]["predict_component_cols"] = X_pred.columns
120
123
 
121
124
  params = vars(model).copy()
122
125
  for param in ["arima_res_", "endog_index_"]:
@@ -163,7 +166,7 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
163
166
  sec5_text = rc.Heading("ARIMA Model Parameters", level=2)
164
167
  blocks = [
165
168
  rc.Html(
166
- m.summary().as_html(),
169
+ m['model'].summary().as_html(),
167
170
  label=s_id if self.target_cat_col else None,
168
171
  )
169
172
  for i, (s_id, m) in enumerate(self.models.items())
@@ -251,7 +254,7 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
251
254
  def get_explain_predict_fn(self, series_id):
252
255
  def _custom_predict(
253
256
  data,
254
- model=self.models[series_id],
257
+ model=self.models[series_id]["model"],
255
258
  dt_column_name=self.datasets._datetime_column_name,
256
259
  target_col=self.original_target_column,
257
260
  ):
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
  import logging
5
5
  import os
@@ -56,8 +56,8 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
56
56
  )
57
57
  return model_kwargs_cleaned, time_budget
58
58
 
59
- def preprocess(self, data): # TODO: re-use self.le for explanations
60
- _, df_encoded = _label_encode_dataframe(
59
+ def preprocess(self, data, series_id): # TODO: re-use self.le for explanations
60
+ self.le[series_id], df_encoded = _label_encode_dataframe(
61
61
  data,
62
62
  no_encode={self.spec.datetime_column.name, self.original_target_column},
63
63
  )
@@ -66,8 +66,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
66
66
  @runtime_dependency(
67
67
  module="automlx",
68
68
  err_msg=(
69
- "Please run `pip3 install oracle-automlx>=23.4.1` and "
70
- "`pip3 install oracle-automlx[forecasting]>=23.4.1` "
69
+ "Please run `pip3 install oracle-automlx[forecasting]>=25.1.1` "
71
70
  "to install the required dependencies for automlx."
72
71
  ),
73
72
  )
@@ -105,7 +104,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
105
104
  engine_opts = (
106
105
  None
107
106
  if engine_type == "local"
108
- else ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
107
+ else {"ray_setup": {"_temp_dir": "/tmp/ray-temp"}}
109
108
  )
110
109
  init(
111
110
  engine=engine_type,
@@ -125,7 +124,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
125
124
  self.forecast_output.init_series_output(
126
125
  series_id=s_id, data_at_series=df
127
126
  )
128
- data = self.preprocess(df)
127
+ data = self.preprocess(df, s_id)
129
128
  data_i = self.drop_horizon(data)
130
129
  X_pred = self.get_horizon(data).drop(target, axis=1)
131
130
 
@@ -157,7 +156,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
157
156
  target
158
157
  ].values
159
158
 
160
- self.models[s_id] = model
159
+ self.models[s_id] = {}
160
+ self.models[s_id]["model"] = model
161
+ self.models[s_id]["le"] = self.le[s_id]
161
162
 
162
163
  # In case of Naive model, model.forecast function call does not return confidence intervals.
163
164
  if f"{target}_ci_upper" not in summary_frame:
@@ -218,7 +219,8 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
218
219
  other_sections = []
219
220
 
220
221
  if len(self.models) > 0:
221
- for s_id, m in models.items():
222
+ for s_id, artifacts in models.items():
223
+ m = artifacts["model"]
222
224
  selected_models[s_id] = {
223
225
  "series_id": s_id,
224
226
  "selected_model": m.selected_model_,
@@ -247,17 +249,17 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
247
249
  self.explain_model()
248
250
 
249
251
  global_explanation_section = None
250
- if self.spec.explanations_accuracy_mode != SpeedAccuracyMode.AUTOMLX:
251
- # Convert the global explanation data to a DataFrame
252
- global_explanation_df = pd.DataFrame(self.global_explanation)
252
+ # Convert the global explanation data to a DataFrame
253
+ global_explanation_df = pd.DataFrame(self.global_explanation)
253
254
 
254
- self.formatted_global_explanation = (
255
- global_explanation_df / global_explanation_df.sum(axis=0) * 100
256
- )
257
- self.formatted_global_explanation = self.formatted_global_explanation.rename(
258
- {self.spec.datetime_column.name: ForecastOutputColumns.DATE},
259
- axis=1,
260
- )
255
+ self.formatted_global_explanation = (
256
+ global_explanation_df / global_explanation_df.sum(axis=0) * 100
257
+ )
258
+
259
+ self.formatted_global_explanation.rename(
260
+ columns={self.spec.datetime_column.name: ForecastOutputColumns.DATE},
261
+ inplace=True,
262
+ )
261
263
 
262
264
  aggregate_local_explanations = pd.DataFrame()
263
265
  for s_id, local_ex_df in self.local_explanation.items():
@@ -269,11 +271,15 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
269
271
  self.formatted_local_explanation = aggregate_local_explanations
270
272
 
271
273
  if not self.target_cat_col:
272
- self.formatted_global_explanation = self.formatted_global_explanation.rename(
273
- {"Series 1": self.original_target_column},
274
- axis=1,
274
+ self.formatted_global_explanation = (
275
+ self.formatted_global_explanation.rename(
276
+ {"Series 1": self.original_target_column},
277
+ axis=1,
278
+ )
279
+ )
280
+ self.formatted_local_explanation.drop(
281
+ "Series", axis=1, inplace=True
275
282
  )
276
- self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
277
283
 
278
284
  # Create a markdown section for the global explainability
279
285
  global_explanation_section = rc.Block(
@@ -320,7 +326,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
320
326
  )
321
327
 
322
328
  def get_explain_predict_fn(self, series_id):
323
- selected_model = self.models[series_id]
329
+ selected_model = self.models[series_id]["model"]
324
330
 
325
331
  # If training date, use method below. If future date, use forecast!
326
332
  def _custom_predict_fn(
@@ -338,12 +344,12 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
338
344
  data[dt_column_name] = seconds_to_datetime(
339
345
  data[dt_column_name], dt_format=self.spec.datetime_column.format
340
346
  )
341
- data = self.preprocess(data)
347
+ data = self.preprocess(data, series_id)
342
348
  horizon_data = horizon_data.drop(target_col, axis=1)
343
349
  horizon_data[dt_column_name] = seconds_to_datetime(
344
350
  horizon_data[dt_column_name], dt_format=self.spec.datetime_column.format
345
351
  )
346
- horizon_data = self.preprocess(horizon_data)
352
+ horizon_data = self.preprocess(horizon_data, series_id)
347
353
 
348
354
  rows = []
349
355
  for i in range(data.shape[0]):
@@ -421,8 +427,10 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
421
427
  if self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
422
428
  # Use the MLExplainer class from AutoMLx to generate explanations
423
429
  explainer = automlx.MLExplainer(
424
- self.models[s_id],
425
- self.datasets.additional_data.get_data_for_series(series_id=s_id)
430
+ self.models[s_id]["model"],
431
+ self.datasets.additional_data.get_data_for_series(
432
+ series_id=s_id
433
+ )
426
434
  .drop(self.spec.datetime_column.name, axis=1)
427
435
  .head(-self.spec.horizon)
428
436
  if self.spec.additional_data
@@ -433,7 +441,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
433
441
 
434
442
  # Generate explanations for the forecast
435
443
  explanations = explainer.explain_prediction(
436
- X=self.datasets.additional_data.get_data_for_series(series_id=s_id)
444
+ X=self.datasets.additional_data.get_data_for_series(
445
+ series_id=s_id
446
+ )
437
447
  .drop(self.spec.datetime_column.name, axis=1)
438
448
  .tail(self.spec.horizon)
439
449
  if self.spec.additional_data
@@ -445,7 +455,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
445
455
  explanations_df = pd.concat(
446
456
  [exp.to_dataframe() for exp in explanations]
447
457
  )
448
- explanations_df["row"] = explanations_df.groupby("Feature").cumcount()
458
+ explanations_df["row"] = explanations_df.groupby(
459
+ "Feature"
460
+ ).cumcount()
449
461
  explanations_df = explanations_df.pivot(
450
462
  index="row", columns="Feature", values="Attribution"
451
463
  )
@@ -453,9 +465,27 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
453
465
 
454
466
  # Store the explanations in the local_explanation dictionary
455
467
  self.local_explanation[s_id] = explanations_df
468
+
469
+ self.global_explanation[s_id] = dict(
470
+ zip(
471
+ self.local_explanation[s_id].columns,
472
+ np.nanmean(np.abs(self.local_explanation[s_id]), axis=0),
473
+ )
474
+ )
456
475
  else:
457
476
  # Fall back to the default explanation generation method
458
477
  super().explain_model()
459
478
  except Exception as e:
460
- logger.warning(f"Failed to generate explanations for series {s_id} with error: {e}.")
479
+ if s_id in self.errors_dict:
480
+ self.errors_dict[s_id]["explainer_error"] = str(e)
481
+ self.errors_dict[s_id]["explainer_error_trace"] = traceback.format_exc()
482
+ else:
483
+ self.errors_dict[s_id] = {
484
+ "model_name": self.spec.model,
485
+ "explainer_error": str(e),
486
+ "explainer_error_trace": traceback.format_exc(),
487
+ }
488
+ logger.warning(
489
+ f"Failed to generate explanations for series {s_id} with error: {e}."
490
+ )
461
491
  logger.debug(f"Full Traceback: {traceback.format_exc()}")
@@ -1,6 +1,6 @@
1
1
  #!/usr/bin/env python
2
2
 
3
- # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
4
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
 
6
6
  import logging
@@ -19,6 +19,7 @@ import report_creator as rc
19
19
  from ads.common.decorator.runtime_dependency import runtime_dependency
20
20
  from ads.common.object_storage_details import ObjectStorageDetails
21
21
  from ads.opctl import logger
22
+ from ads.opctl.operator.lowcode.common.const import DataColumns
22
23
  from ads.opctl.operator.lowcode.common.utils import (
23
24
  datetime_to_seconds,
24
25
  disable_print,
@@ -28,7 +29,6 @@ from ads.opctl.operator.lowcode.common.utils import (
28
29
  seconds_to_datetime,
29
30
  write_data,
30
31
  )
31
- from ads.opctl.operator.lowcode.common.const import DataColumns
32
32
  from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import TestData
33
33
  from ads.opctl.operator.lowcode.forecast.utils import (
34
34
  _build_metrics_df,
@@ -49,10 +49,9 @@ from ..const import (
49
49
  SpeedAccuracyMode,
50
50
  SupportedMetrics,
51
51
  SupportedModels,
52
- BACKTEST_REPORT_NAME,
53
52
  )
54
53
  from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
55
- from .forecast_datasets import ForecastDatasets
54
+ from .forecast_datasets import ForecastDatasets, ForecastResults
56
55
 
57
56
  logging.getLogger("report_creator").setLevel(logging.WARNING)
58
57
 
@@ -121,27 +120,30 @@ class ForecastOperatorBaseModel(ABC):
121
120
 
122
121
  # Generate metrics
123
122
  summary_metrics = None
124
- test_data = None
123
+ test_data = self.datasets.test_data
125
124
  self.eval_metrics = None
126
125
 
127
126
  if self.spec.generate_report or self.spec.generate_metrics:
128
127
  self.eval_metrics = self.generate_train_metrics()
129
128
  if not self.target_cat_col:
130
- self.eval_metrics.rename({"Series 1": self.original_target_column},
131
- axis=1, inplace=True)
129
+ self.eval_metrics.rename(
130
+ {"Series 1": self.original_target_column}, axis=1, inplace=True
131
+ )
132
132
 
133
- if self.spec.test_data:
133
+ if self.datasets.test_data is not None:
134
134
  try:
135
135
  (
136
136
  self.test_eval_metrics,
137
- summary_metrics,
138
- test_data,
137
+ summary_metrics
139
138
  ) = self._test_evaluate_metrics(
140
139
  elapsed_time=elapsed_time,
141
140
  )
142
141
  if not self.target_cat_col:
143
- self.test_eval_metrics.rename({"Series 1": self.original_target_column},
144
- axis=1, inplace=True)
142
+ self.test_eval_metrics.rename(
143
+ {"Series 1": self.original_target_column},
144
+ axis=1,
145
+ inplace=True,
146
+ )
145
147
  except Exception:
146
148
  logger.warn("Unable to generate Test Metrics.")
147
149
  logger.debug(f"Full Traceback: {traceback.format_exc()}")
@@ -223,17 +225,23 @@ class ForecastOperatorBaseModel(ABC):
223
225
  rc.Block(
224
226
  first_10_title,
225
227
  # series_subtext,
226
- rc.Select(blocks=first_5_rows_blocks) if self.target_cat_col else first_5_rows_blocks[0],
228
+ rc.Select(blocks=first_5_rows_blocks)
229
+ if self.target_cat_col
230
+ else first_5_rows_blocks[0],
227
231
  ),
228
232
  rc.Block(
229
233
  last_10_title,
230
234
  # series_subtext,
231
- rc.Select(blocks=last_5_rows_blocks) if self.target_cat_col else last_5_rows_blocks[0],
235
+ rc.Select(blocks=last_5_rows_blocks)
236
+ if self.target_cat_col
237
+ else last_5_rows_blocks[0],
232
238
  ),
233
239
  rc.Block(
234
240
  summary_title,
235
241
  # series_subtext,
236
- rc.Select(blocks=data_summary_blocks) if self.target_cat_col else data_summary_blocks[0],
242
+ rc.Select(blocks=data_summary_blocks)
243
+ if self.target_cat_col
244
+ else data_summary_blocks[0],
237
245
  ),
238
246
  rc.Separator(),
239
247
  )
@@ -308,7 +316,7 @@ class ForecastOperatorBaseModel(ABC):
308
316
  horizon=self.spec.horizon,
309
317
  test_data=test_data,
310
318
  ci_interval_width=self.spec.confidence_interval_width,
311
- target_category_column=self.target_cat_col
319
+ target_category_column=self.target_cat_col,
312
320
  )
313
321
  if (
314
322
  series_name is not None
@@ -341,17 +349,18 @@ class ForecastOperatorBaseModel(ABC):
341
349
  )
342
350
 
343
351
  # save the report and result CSV
344
- self._save_report(
352
+ return self._save_report(
345
353
  report_sections=report_sections,
346
354
  result_df=result_df,
347
355
  metrics_df=self.eval_metrics,
348
356
  test_metrics_df=self.test_eval_metrics,
357
+ test_data=test_data,
349
358
  )
350
359
 
351
360
  def _test_evaluate_metrics(self, elapsed_time=0):
352
361
  total_metrics = pd.DataFrame()
353
362
  summary_metrics = pd.DataFrame()
354
- data = TestData(self.spec)
363
+ data = self.datasets.test_data
355
364
 
356
365
  # Generate y_pred and y_true for each series
357
366
  for s_id in self.forecast_output.list_series_ids():
@@ -388,7 +397,7 @@ class ForecastOperatorBaseModel(ABC):
388
397
  total_metrics = pd.concat([total_metrics, metrics_df], axis=1)
389
398
 
390
399
  if total_metrics.empty:
391
- return total_metrics, summary_metrics, data
400
+ return total_metrics, summary_metrics
392
401
 
393
402
  summary_metrics = pd.DataFrame(
394
403
  {
@@ -454,7 +463,7 @@ class ForecastOperatorBaseModel(ABC):
454
463
  ]
455
464
  summary_metrics = summary_metrics[new_column_order]
456
465
 
457
- return total_metrics, summary_metrics, data
466
+ return total_metrics, summary_metrics
458
467
 
459
468
  def _save_report(
460
469
  self,
@@ -462,10 +471,12 @@ class ForecastOperatorBaseModel(ABC):
462
471
  result_df: pd.DataFrame,
463
472
  metrics_df: pd.DataFrame,
464
473
  test_metrics_df: pd.DataFrame,
474
+ test_data: pd.DataFrame,
465
475
  ):
466
476
  """Saves resulting reports to the given folder."""
467
477
 
468
478
  unique_output_dir = self.spec.output_directory.url
479
+ results = ForecastResults()
469
480
 
470
481
  if ObjectStorageDetails.is_oci_path(unique_output_dir):
471
482
  storage_options = default_signer()
@@ -491,13 +502,23 @@ class ForecastOperatorBaseModel(ABC):
491
502
  f2.write(f1.read())
492
503
 
493
504
  # forecast csv report
494
- result_df = result_df if self.target_cat_col else result_df.drop(DataColumns.Series, axis=1)
505
+ # todo: add test data into forecast.csv
506
+ # if self.spec.test_data is not None:
507
+ # test_data_dict = test_data.get_dict_by_series()
508
+ # for series_id, test_data_values in test_data_dict.items():
509
+ # result_df[DataColumns.Series] = test_data_values[]
510
+ result_df = (
511
+ result_df
512
+ if self.target_cat_col
513
+ else result_df.drop(DataColumns.Series, axis=1)
514
+ )
495
515
  write_data(
496
516
  data=result_df,
497
517
  filename=os.path.join(unique_output_dir, self.spec.forecast_filename),
498
518
  format="csv",
499
519
  storage_options=storage_options,
500
520
  )
521
+ results.set_forecast(result_df)
501
522
 
502
523
  # metrics csv report
503
524
  if self.spec.generate_metrics:
@@ -507,10 +528,11 @@ class ForecastOperatorBaseModel(ABC):
507
528
  else "Series 1"
508
529
  )
509
530
  if metrics_df is not None:
531
+ metrics_df_formatted = metrics_df.reset_index().rename(
532
+ {"index": "metrics", "Series 1": metrics_col_name}, axis=1
533
+ )
510
534
  write_data(
511
- data=metrics_df.reset_index().rename(
512
- {"index": "metrics", "Series 1": metrics_col_name}, axis=1
513
- ),
535
+ data=metrics_df_formatted,
514
536
  filename=os.path.join(
515
537
  unique_output_dir, self.spec.metrics_filename
516
538
  ),
@@ -518,18 +540,20 @@ class ForecastOperatorBaseModel(ABC):
518
540
  storage_options=storage_options,
519
541
  index=False,
520
542
  )
543
+ results.set_metrics(metrics_df_formatted)
521
544
  else:
522
545
  logger.warn(
523
546
  f"Attempted to generate the {self.spec.metrics_filename} file with the training metrics, however the training metrics could not be properly generated."
524
547
  )
525
548
 
526
549
  # test_metrics csv report
527
- if self.spec.test_data is not None:
550
+ if self.datasets.test_data is not None:
528
551
  if test_metrics_df is not None:
552
+ test_metrics_df_formatted = test_metrics_df.reset_index().rename(
553
+ {"index": "metrics", "Series 1": metrics_col_name}, axis=1
554
+ )
529
555
  write_data(
530
- data=test_metrics_df.reset_index().rename(
531
- {"index": "metrics", "Series 1": metrics_col_name}, axis=1
532
- ),
556
+ data=test_metrics_df_formatted,
533
557
  filename=os.path.join(
534
558
  unique_output_dir, self.spec.test_metrics_filename
535
559
  ),
@@ -537,6 +561,7 @@ class ForecastOperatorBaseModel(ABC):
537
561
  storage_options=storage_options,
538
562
  index=False,
539
563
  )
564
+ results.set_test_metrics(test_metrics_df_formatted)
540
565
  else:
541
566
  logger.warn(
542
567
  f"Attempted to generate the {self.spec.test_metrics_filename} file with the test metrics, however the test metrics could not be properly generated."
@@ -544,7 +569,7 @@ class ForecastOperatorBaseModel(ABC):
544
569
  # explanations csv reports
545
570
  if self.spec.generate_explanations:
546
571
  try:
547
- if self.formatted_global_explanation is not None:
572
+ if not self.formatted_global_explanation.empty:
548
573
  write_data(
549
574
  data=self.formatted_global_explanation,
550
575
  filename=os.path.join(
@@ -554,12 +579,13 @@ class ForecastOperatorBaseModel(ABC):
554
579
  storage_options=storage_options,
555
580
  index=True,
556
581
  )
582
+ results.set_global_explanations(self.formatted_global_explanation)
557
583
  else:
558
584
  logger.warn(
559
585
  f"Attempted to generate global explanations for the {self.spec.global_explanation_filename} file, but an issue occured in formatting the explanations."
560
586
  )
561
587
 
562
- if self.formatted_local_explanation is not None:
588
+ if not self.formatted_local_explanation.empty:
563
589
  write_data(
564
590
  data=self.formatted_local_explanation,
565
591
  filename=os.path.join(
@@ -569,6 +595,7 @@ class ForecastOperatorBaseModel(ABC):
569
595
  storage_options=storage_options,
570
596
  index=True,
571
597
  )
598
+ results.set_local_explanations(self.formatted_local_explanation)
572
599
  else:
573
600
  logger.warn(
574
601
  f"Attempted to generate local explanations for the {self.spec.local_explanation_filename} file, but an issue occured in formatting the explanations."
@@ -589,10 +616,12 @@ class ForecastOperatorBaseModel(ABC):
589
616
  index=True,
590
617
  indent=4,
591
618
  )
619
+ results.set_model_parameters(self.model_parameters)
592
620
 
593
621
  # model pickle
594
622
  if self.spec.generate_model_pickle:
595
623
  self._save_model(unique_output_dir, storage_options)
624
+ results.set_models(self.models)
596
625
 
597
626
  logger.info(
598
627
  f"The outputs have been successfully "
@@ -612,8 +641,10 @@ class ForecastOperatorBaseModel(ABC):
612
641
  index=True,
613
642
  indent=4,
614
643
  )
644
+ results.set_errors_dict(self.errors_dict)
615
645
  else:
616
646
  logger.info("All modeling completed successfully.")
647
+ return results
617
648
 
618
649
  def preprocess(self, df, series_id):
619
650
  """The method that needs to be implemented on the particular model level."""
@@ -667,7 +698,10 @@ class ForecastOperatorBaseModel(ABC):
667
698
  )
668
699
 
669
700
  def _validate_automlx_explanation_mode(self):
670
- if self.spec.model != SupportedModels.AutoMLX and self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
701
+ if (
702
+ self.spec.model != SupportedModels.AutoMLX
703
+ and self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX
704
+ ):
671
705
  raise ValueError(
672
706
  "AUTOMLX explanation accuracy mode is only supported for AutoMLX models. "
673
707
  "Please select mode other than AUTOMLX from the available explanations_accuracy_mode options"
@@ -738,14 +772,6 @@ class ForecastOperatorBaseModel(ABC):
738
772
  logger.warn(
739
773
  "No explanations generated. Ensure that additional data has been provided."
740
774
  )
741
- elif (
742
- self.spec.model == SupportedModels.AutoMLX
743
- and self.spec.explanations_accuracy_mode
744
- == SpeedAccuracyMode.AUTOMLX
745
- ):
746
- logger.warning(
747
- "Global explanations not available for AutoMLX models with inherent explainability"
748
- )
749
775
  else:
750
776
  self.global_explanation[s_id] = dict(
751
777
  zip(
@@ -794,7 +820,7 @@ class ForecastOperatorBaseModel(ABC):
794
820
  def get_explain_predict_fn(self, series_id, fcst_col_name="yhat"):
795
821
  def _custom_predict(
796
822
  data,
797
- model=self.models[series_id],
823
+ model=self.models[series_id]["model"],
798
824
  dt_column_name=self.datasets._datetime_column_name,
799
825
  ):
800
826
  """