oracle-ads 2.13.2__py3-none-any.whl → 2.13.2rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -71,7 +71,7 @@ class AnomalyOperatorBaseModel(ABC):
71
71
  try:
72
72
  anomaly_output = self._build_model()
73
73
  except Exception as e:
74
- logger.warn(f"Found exception: {e}")
74
+ logger.warning(f"Found exception: {e}")
75
75
  if self.spec.datetime_column:
76
76
  anomaly_output = self._fallback_build_model()
77
77
  raise e
@@ -347,7 +347,7 @@ class AnomalyOperatorBaseModel(ABC):
347
347
  storage_options=storage_options,
348
348
  )
349
349
 
350
- logger.warn(
350
+ logger.warning(
351
351
  f"The report has been successfully "
352
352
  f"generated and placed to the: {unique_output_dir}."
353
353
  )
@@ -356,7 +356,7 @@ class AnomalyOperatorBaseModel(ABC):
356
356
  """
357
357
  Fallback method for the sub model _build_model method.
358
358
  """
359
- logger.warn(
359
+ logger.warning(
360
360
  f"The build_model method has failed for the model: {self.spec.model}. "
361
361
  "A fallback model will be built."
362
362
  )
@@ -95,7 +95,7 @@ class RandomCutForestOperatorModel(AnomalyOperatorBaseModel):
95
95
 
96
96
  anomaly_output.add_output(target, anomaly, score)
97
97
  except Exception as e:
98
- logger.warn(f"Encountered Error: {e}. Skipping series {target}.")
98
+ logger.warning(f"Encountered Error: {e}. Skipping series {target}.")
99
99
 
100
100
  return anomaly_output
101
101
 
@@ -44,7 +44,7 @@ def _build_metrics_df(y_true, y_pred, column_name):
44
44
  # Throws exception if y_true has only one class
45
45
  metrics[SupportedMetrics.ROC_AUC] = roc_auc_score(y_true, y_pred)
46
46
  except Exception as e:
47
- logger.warn(f"An exception occurred: {e}")
47
+ logger.warning(f"An exception occurred: {e}")
48
48
  metrics[SupportedMetrics.ROC_AUC] = None
49
49
  precision, recall, thresholds = precision_recall_curve(y_true, y_pred)
50
50
  metrics[SupportedMetrics.PRC_AUC] = auc(recall, precision)
@@ -98,7 +98,11 @@ class Transformations(ABC):
98
98
  return clean_df
99
99
 
100
100
  def _remove_trailing_whitespace(self, df):
101
- return df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)
101
+ return df.apply(
102
+ lambda x: x.str.strip()
103
+ if hasattr(x, "dtype") and x.dtype == "object"
104
+ else x
105
+ )
102
106
 
103
107
  def _clean_column_names(self, df):
104
108
  """
@@ -3,6 +3,7 @@
3
3
  # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
4
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
 
6
+ import json
6
7
  import logging
7
8
  import os
8
9
  import shutil
@@ -12,7 +13,6 @@ from typing import List, Union
12
13
 
13
14
  import fsspec
14
15
  import oracledb
15
- import json
16
16
  import pandas as pd
17
17
 
18
18
  from ads.common.object_storage_details import ObjectStorageDetails
@@ -142,6 +142,11 @@ def write_data(data, filename, format, storage_options=None, index=False, **kwar
142
142
  )
143
143
 
144
144
 
145
+ def write_json(json_dict, filename, storage_options=None):
146
+ with fsspec.open(filename, mode="w", **storage_options) as f:
147
+ f.write(json.dumps(json_dict))
148
+
149
+
145
150
  def write_simple_json(data, path):
146
151
  if ObjectStorageDetails.is_oci_path(path):
147
152
  storage_options = default_signer()
@@ -265,7 +270,7 @@ def find_output_dirname(output_dir: OutputDirectory):
265
270
  while os.path.exists(unique_output_dir):
266
271
  unique_output_dir = f"{output_dir}_{counter}"
267
272
  counter += 1
268
- logger.warn(
273
+ logger.warning(
269
274
  f"Since the output directory was not specified, the output will be saved to {unique_output_dir} directory."
270
275
  )
271
276
  return unique_output_dir
@@ -1,6 +1,6 @@
1
1
  #!/usr/bin/env python
2
2
 
3
- # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
4
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
 
6
6
  import logging
@@ -132,13 +132,14 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
132
132
 
133
133
  logger.debug("===========Done===========")
134
134
  except Exception as e:
135
- self.errors_dict[s_id] = {
135
+ new_error = {
136
136
  "model_name": self.spec.model,
137
137
  "error": str(e),
138
138
  "error_trace": traceback.format_exc(),
139
139
  }
140
- logger.warn(f"Encountered Error: {e}. Skipping.")
141
- logger.warn(traceback.format_exc())
140
+ self.errors_dict[s_id] = new_error
141
+ logger.warning(f"Encountered Error: {e}. Skipping.")
142
+ logger.warning(traceback.format_exc())
142
143
 
143
144
  def _build_model(self) -> pd.DataFrame:
144
145
  full_data_dict = self.datasets.get_data_by_series()
@@ -166,7 +167,7 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
166
167
  sec5_text = rc.Heading("ARIMA Model Parameters", level=2)
167
168
  blocks = [
168
169
  rc.Html(
169
- m['model'].summary().as_html(),
170
+ m["model"].summary().as_html(),
170
171
  label=s_id if self.target_cat_col else None,
171
172
  )
172
173
  for i, (s_id, m) in enumerate(self.models.items())
@@ -201,11 +202,15 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
201
202
  self.formatted_local_explanation = aggregate_local_explanations
202
203
 
203
204
  if not self.target_cat_col:
204
- self.formatted_global_explanation = self.formatted_global_explanation.rename(
205
- {"Series 1": self.original_target_column},
206
- axis=1,
205
+ self.formatted_global_explanation = (
206
+ self.formatted_global_explanation.rename(
207
+ {"Series 1": self.original_target_column},
208
+ axis=1,
209
+ )
210
+ )
211
+ self.formatted_local_explanation.drop(
212
+ "Series", axis=1, inplace=True
207
213
  )
208
- self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
209
214
 
210
215
  # Create a markdown section for the global explainability
211
216
  global_explanation_section = rc.Block(
@@ -235,7 +240,7 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
235
240
  local_explanation_section,
236
241
  ]
237
242
  except Exception as e:
238
- logger.warn(f"Failed to generate Explanations with error: {e}.")
243
+ logger.warning(f"Failed to generate Explanations with error: {e}.")
239
244
  logger.debug(f"Full Traceback: {traceback.format_exc()}")
240
245
 
241
246
  model_description = rc.Text(
@@ -184,13 +184,18 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
184
184
  "selected_model_params": model.selected_model_params_,
185
185
  }
186
186
  except Exception as e:
187
- self.errors_dict[s_id] = {
187
+ new_error = {
188
188
  "model_name": self.spec.model,
189
189
  "error": str(e),
190
190
  "error_trace": traceback.format_exc(),
191
191
  }
192
- logger.warn(f"Encountered Error: {e}. Skipping.")
193
- logger.warn(traceback.format_exc())
192
+ if s_id in self.errors_dict:
193
+ self.errors_dict[s_id]["model_fitting"] = new_error
194
+ else:
195
+ self.errors_dict[s_id] = {"model_fitting": new_error}
196
+ logger.warning(f"Encountered Error: {e}. Skipping.")
197
+ logger.warning(f"self.errors_dict[s_id]: {self.errors_dict[s_id]}")
198
+ logger.warning(traceback.format_exc())
194
199
 
195
200
  logger.debug("===========Forecast Generated===========")
196
201
 
@@ -257,7 +262,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
257
262
  )
258
263
 
259
264
  self.formatted_global_explanation.rename(
260
- columns={self.spec.datetime_column.name: ForecastOutputColumns.DATE},
265
+ columns={
266
+ self.spec.datetime_column.name: ForecastOutputColumns.DATE
267
+ },
261
268
  inplace=True,
262
269
  )
263
270
 
@@ -312,7 +319,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
312
319
  local_explanation_section,
313
320
  ]
314
321
  except Exception as e:
315
- logger.warn(f"Failed to generate Explanations with error: {e}.")
322
+ logger.warning(f"Failed to generate Explanations with error: {e}.")
316
323
  logger.debug(f"Full Traceback: {traceback.format_exc()}")
317
324
 
318
325
  model_description = rc.Text(
@@ -462,14 +469,27 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
462
469
  index="row", columns="Feature", values="Attribution"
463
470
  )
464
471
  explanations_df = explanations_df.reset_index(drop=True)
465
-
472
+ explanations_df[ForecastOutputColumns.DATE] = (
473
+ self.datasets.get_horizon_at_series(
474
+ s_id=s_id
475
+ )[self.spec.datetime_column.name].reset_index(drop=True)
476
+ )
466
477
  # Store the explanations in the local_explanation dictionary
467
478
  self.local_explanation[s_id] = explanations_df
468
479
 
469
480
  self.global_explanation[s_id] = dict(
470
481
  zip(
471
- self.local_explanation[s_id].columns,
472
- np.nanmean(np.abs(self.local_explanation[s_id]), axis=0),
482
+ self.local_explanation[s_id]
483
+ .drop(ForecastOutputColumns.DATE, axis=1)
484
+ .columns,
485
+ np.nanmean(
486
+ np.abs(
487
+ self.local_explanation[s_id].drop(
488
+ ForecastOutputColumns.DATE, axis=1
489
+ )
490
+ ),
491
+ axis=0,
492
+ ),
473
493
  )
474
494
  )
475
495
  else:
@@ -478,7 +498,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
478
498
  except Exception as e:
479
499
  if s_id in self.errors_dict:
480
500
  self.errors_dict[s_id]["explainer_error"] = str(e)
481
- self.errors_dict[s_id]["explainer_error_trace"] = traceback.format_exc()
501
+ self.errors_dict[s_id]["explainer_error_trace"] = (
502
+ traceback.format_exc()
503
+ )
482
504
  else:
483
505
  self.errors_dict[s_id] = {
484
506
  "model_name": self.spec.model,
@@ -211,8 +211,8 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
211
211
  "error": str(e),
212
212
  "error_trace": traceback.format_exc(),
213
213
  }
214
- logger.warn(f"Encountered Error: {e}. Skipping.")
215
- logger.warn(traceback.format_exc())
214
+ logger.warning(f"Encountered Error: {e}. Skipping.")
215
+ logger.warning(traceback.format_exc())
216
216
 
217
217
  logger.debug("===========Done===========")
218
218
 
@@ -242,7 +242,7 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
242
242
  self.models.df_wide_numeric, series=s_id
243
243
  ),
244
244
  self.datasets.list_series_ids(),
245
- target_category_column=self.target_cat_col
245
+ target_category_column=self.target_cat_col,
246
246
  )
247
247
  section_1 = rc.Block(
248
248
  rc.Heading("Forecast Overview", level=2),
@@ -260,7 +260,9 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
260
260
  )
261
261
 
262
262
  except KeyError:
263
- logger.warn("Issue generating Model Parameters Table Section. Skipping")
263
+ logger.warning(
264
+ "Issue generating Model Parameters Table Section. Skipping"
265
+ )
264
266
  sec2 = rc.Text("Error generating model parameters.")
265
267
 
266
268
  section_2 = rc.Block(sec2_text, sec2)
@@ -268,7 +270,7 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
268
270
  all_sections = [section_1, section_2]
269
271
 
270
272
  if self.spec.generate_explanations:
271
- logger.warn("Explanations not yet supported for the AutoTS Module")
273
+ logger.warning("Explanations not yet supported for the AutoTS Module")
272
274
 
273
275
  # Model Description
274
276
  model_description = rc.Text(
@@ -28,8 +28,8 @@ from ads.opctl.operator.lowcode.common.utils import (
28
28
  merged_category_column_name,
29
29
  seconds_to_datetime,
30
30
  write_data,
31
+ write_json,
31
32
  )
32
- from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import TestData
33
33
  from ads.opctl.operator.lowcode.forecast.utils import (
34
34
  _build_metrics_df,
35
35
  _build_metrics_per_horizon,
@@ -46,6 +46,7 @@ from ..const import (
46
46
  AUTO_SELECT,
47
47
  BACKTEST_REPORT_NAME,
48
48
  SUMMARY_METRICS_HORIZON_LIMIT,
49
+ ForecastOutputColumns,
49
50
  SpeedAccuracyMode,
50
51
  SupportedMetrics,
51
52
  SupportedModels,
@@ -132,11 +133,10 @@ class ForecastOperatorBaseModel(ABC):
132
133
 
133
134
  if self.datasets.test_data is not None:
134
135
  try:
135
- (
136
- self.test_eval_metrics,
137
- summary_metrics
138
- ) = self._test_evaluate_metrics(
139
- elapsed_time=elapsed_time,
136
+ (self.test_eval_metrics, summary_metrics) = (
137
+ self._test_evaluate_metrics(
138
+ elapsed_time=elapsed_time,
139
+ )
140
140
  )
141
141
  if not self.target_cat_col:
142
142
  self.test_eval_metrics.rename(
@@ -145,7 +145,7 @@ class ForecastOperatorBaseModel(ABC):
145
145
  inplace=True,
146
146
  )
147
147
  except Exception:
148
- logger.warn("Unable to generate Test Metrics.")
148
+ logger.warning("Unable to generate Test Metrics.")
149
149
  logger.debug(f"Full Traceback: {traceback.format_exc()}")
150
150
  report_sections = []
151
151
 
@@ -155,9 +155,8 @@ class ForecastOperatorBaseModel(ABC):
155
155
  model_description,
156
156
  other_sections,
157
157
  ) = self._generate_report()
158
-
159
158
  header_section = rc.Block(
160
- rc.Heading("Forecast Report", level=1),
159
+ rc.Heading(self.spec.report_title, level=1),
161
160
  rc.Text(
162
161
  f"You selected the {self.spec.model} model.\nBased on your dataset, you could have also selected any of the models: {SupportedModels.keys()}."
163
162
  ),
@@ -369,7 +368,7 @@ class ForecastOperatorBaseModel(ABC):
369
368
  -self.spec.horizon :
370
369
  ]
371
370
  except KeyError as ke:
372
- logger.warn(
371
+ logger.warning(
373
372
  f"Error Generating Metrics: Unable to find {s_id} in the test data. Error: {ke.args}"
374
373
  )
375
374
  y_pred = self.forecast_output.get_forecast(s_id)["forecast_value"].values[
@@ -478,10 +477,11 @@ class ForecastOperatorBaseModel(ABC):
478
477
  unique_output_dir = self.spec.output_directory.url
479
478
  results = ForecastResults()
480
479
 
481
- if ObjectStorageDetails.is_oci_path(unique_output_dir):
482
- storage_options = default_signer()
483
- else:
484
- storage_options = {}
480
+ storage_options = (
481
+ default_signer()
482
+ if ObjectStorageDetails.is_oci_path(unique_output_dir)
483
+ else {}
484
+ )
485
485
 
486
486
  # report-creator html report
487
487
  if self.spec.generate_report:
@@ -512,12 +512,13 @@ class ForecastOperatorBaseModel(ABC):
512
512
  if self.target_cat_col
513
513
  else result_df.drop(DataColumns.Series, axis=1)
514
514
  )
515
- write_data(
516
- data=result_df,
517
- filename=os.path.join(unique_output_dir, self.spec.forecast_filename),
518
- format="csv",
519
- storage_options=storage_options,
520
- )
515
+ if self.spec.generate_forecast_file:
516
+ write_data(
517
+ data=result_df,
518
+ filename=os.path.join(unique_output_dir, self.spec.forecast_filename),
519
+ format="csv",
520
+ storage_options=storage_options,
521
+ )
521
522
  results.set_forecast(result_df)
522
523
 
523
524
  # metrics csv report
@@ -531,18 +532,19 @@ class ForecastOperatorBaseModel(ABC):
531
532
  metrics_df_formatted = metrics_df.reset_index().rename(
532
533
  {"index": "metrics", "Series 1": metrics_col_name}, axis=1
533
534
  )
534
- write_data(
535
- data=metrics_df_formatted,
536
- filename=os.path.join(
537
- unique_output_dir, self.spec.metrics_filename
538
- ),
539
- format="csv",
540
- storage_options=storage_options,
541
- index=False,
542
- )
535
+ if self.spec.generate_metrics_file:
536
+ write_data(
537
+ data=metrics_df_formatted,
538
+ filename=os.path.join(
539
+ unique_output_dir, self.spec.metrics_filename
540
+ ),
541
+ format="csv",
542
+ storage_options=storage_options,
543
+ index=False,
544
+ )
543
545
  results.set_metrics(metrics_df_formatted)
544
546
  else:
545
- logger.warn(
547
+ logger.warning(
546
548
  f"Attempted to generate the {self.spec.metrics_filename} file with the training metrics, however the training metrics could not be properly generated."
547
549
  )
548
550
 
@@ -552,56 +554,59 @@ class ForecastOperatorBaseModel(ABC):
552
554
  test_metrics_df_formatted = test_metrics_df.reset_index().rename(
553
555
  {"index": "metrics", "Series 1": metrics_col_name}, axis=1
554
556
  )
555
- write_data(
556
- data=test_metrics_df_formatted,
557
- filename=os.path.join(
558
- unique_output_dir, self.spec.test_metrics_filename
559
- ),
560
- format="csv",
561
- storage_options=storage_options,
562
- index=False,
563
- )
557
+ if self.spec.generate_metrics_file:
558
+ write_data(
559
+ data=test_metrics_df_formatted,
560
+ filename=os.path.join(
561
+ unique_output_dir, self.spec.test_metrics_filename
562
+ ),
563
+ format="csv",
564
+ storage_options=storage_options,
565
+ index=False,
566
+ )
564
567
  results.set_test_metrics(test_metrics_df_formatted)
565
568
  else:
566
- logger.warn(
569
+ logger.warning(
567
570
  f"Attempted to generate the {self.spec.test_metrics_filename} file with the test metrics, however the test metrics could not be properly generated."
568
571
  )
569
572
  # explanations csv reports
570
573
  if self.spec.generate_explanations:
571
574
  try:
572
575
  if not self.formatted_global_explanation.empty:
573
- write_data(
574
- data=self.formatted_global_explanation,
575
- filename=os.path.join(
576
- unique_output_dir, self.spec.global_explanation_filename
577
- ),
578
- format="csv",
579
- storage_options=storage_options,
580
- index=True,
581
- )
576
+ if self.spec.generate_explanation_files:
577
+ write_data(
578
+ data=self.formatted_global_explanation,
579
+ filename=os.path.join(
580
+ unique_output_dir, self.spec.global_explanation_filename
581
+ ),
582
+ format="csv",
583
+ storage_options=storage_options,
584
+ index=True,
585
+ )
582
586
  results.set_global_explanations(self.formatted_global_explanation)
583
587
  else:
584
- logger.warn(
588
+ logger.warning(
585
589
  f"Attempted to generate global explanations for the {self.spec.global_explanation_filename} file, but an issue occured in formatting the explanations."
586
590
  )
587
591
 
588
592
  if not self.formatted_local_explanation.empty:
589
- write_data(
590
- data=self.formatted_local_explanation,
591
- filename=os.path.join(
592
- unique_output_dir, self.spec.local_explanation_filename
593
- ),
594
- format="csv",
595
- storage_options=storage_options,
596
- index=True,
597
- )
593
+ if self.spec.generate_explanation_files:
594
+ write_data(
595
+ data=self.formatted_local_explanation,
596
+ filename=os.path.join(
597
+ unique_output_dir, self.spec.local_explanation_filename
598
+ ),
599
+ format="csv",
600
+ storage_options=storage_options,
601
+ index=True,
602
+ )
598
603
  results.set_local_explanations(self.formatted_local_explanation)
599
604
  else:
600
- logger.warn(
605
+ logger.warning(
601
606
  f"Attempted to generate local explanations for the {self.spec.local_explanation_filename} file, but an issue occured in formatting the explanations."
602
607
  )
603
608
  except AttributeError as e:
604
- logger.warn(
609
+ logger.warning(
605
610
  "Unable to generate explanations for this model type or for this dataset."
606
611
  )
607
612
  logger.debug(f"Got error: {e.args}")
@@ -631,15 +636,12 @@ class ForecastOperatorBaseModel(ABC):
631
636
  f"The outputs have been successfully generated and placed into the directory: {unique_output_dir}."
632
637
  )
633
638
  if self.errors_dict:
634
- write_data(
635
- data=pd.DataFrame.from_dict(self.errors_dict),
639
+ write_json(
640
+ json_dict=self.errors_dict,
636
641
  filename=os.path.join(
637
642
  unique_output_dir, self.spec.errors_dict_filename
638
643
  ),
639
- format="json",
640
644
  storage_options=storage_options,
641
- index=True,
642
- indent=4,
643
645
  )
644
646
  results.set_errors_dict(self.errors_dict)
645
647
  else:
@@ -742,45 +744,62 @@ class ForecastOperatorBaseModel(ABC):
742
744
  include_horizon=False
743
745
  ).items():
744
746
  if s_id in self.models:
745
- explain_predict_fn = self.get_explain_predict_fn(series_id=s_id)
746
- data_trimmed = data_i.tail(
747
- max(int(len(data_i) * ratio), 5)
748
- ).reset_index(drop=True)
749
- data_trimmed[datetime_col_name] = data_trimmed[datetime_col_name].apply(
750
- lambda x: x.timestamp()
751
- )
752
-
753
- # Explainer fails when boolean columns are passed
754
-
755
- _, data_trimmed_encoded = _label_encode_dataframe(
756
- data_trimmed,
757
- no_encode={datetime_col_name, self.original_target_column},
758
- )
759
-
760
- kernel_explnr = PermutationExplainer(
761
- model=explain_predict_fn, masker=data_trimmed_encoded
762
- )
763
- kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded)
764
- exp_end_time = time.time()
765
- global_ex_time = global_ex_time + exp_end_time - exp_start_time
766
- self.local_explainer(
767
- kernel_explnr, series_id=s_id, datetime_col_name=datetime_col_name
768
- )
769
- local_ex_time = local_ex_time + time.time() - exp_end_time
747
+ try:
748
+ explain_predict_fn = self.get_explain_predict_fn(series_id=s_id)
749
+ data_trimmed = data_i.tail(
750
+ max(int(len(data_i) * ratio), 5)
751
+ ).reset_index(drop=True)
752
+ data_trimmed[datetime_col_name] = data_trimmed[
753
+ datetime_col_name
754
+ ].apply(lambda x: x.timestamp())
755
+
756
+ # Explainer fails when boolean columns are passed
757
+
758
+ _, data_trimmed_encoded = _label_encode_dataframe(
759
+ data_trimmed,
760
+ no_encode={datetime_col_name, self.original_target_column},
761
+ )
770
762
 
771
- if not len(kernel_explnr_vals):
772
- logger.warn(
773
- "No explanations generated. Ensure that additional data has been provided."
763
+ kernel_explnr = PermutationExplainer(
764
+ model=explain_predict_fn, masker=data_trimmed_encoded
774
765
  )
775
- else:
776
- self.global_explanation[s_id] = dict(
777
- zip(
778
- data_trimmed.columns[1:],
779
- np.average(np.absolute(kernel_explnr_vals[:, 1:]), axis=0),
780
- )
766
+ kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded)
767
+ exp_end_time = time.time()
768
+ global_ex_time = global_ex_time + exp_end_time - exp_start_time
769
+ self.local_explainer(
770
+ kernel_explnr,
771
+ series_id=s_id,
772
+ datetime_col_name=datetime_col_name,
781
773
  )
774
+ local_ex_time = local_ex_time + time.time() - exp_end_time
775
+
776
+ if not len(kernel_explnr_vals):
777
+ logger.warning(
778
+ "No explanations generated. Ensure that additional data has been provided."
779
+ )
780
+ else:
781
+ self.global_explanation[s_id] = dict(
782
+ zip(
783
+ data_trimmed.columns[1:],
784
+ np.average(
785
+ np.absolute(kernel_explnr_vals[:, 1:]), axis=0
786
+ ),
787
+ )
788
+ )
789
+ except Exception as e:
790
+ if s_id in self.errors_dict:
791
+ self.errors_dict[s_id]["explainer_error"] = str(e)
792
+ self.errors_dict[s_id]["explainer_error_trace"] = (
793
+ traceback.format_exc()
794
+ )
795
+ else:
796
+ self.errors_dict[s_id] = {
797
+ "model_name": self.spec.model,
798
+ "explainer_error": str(e),
799
+ "explainer_error_trace": traceback.format_exc(),
800
+ }
782
801
  else:
783
- logger.warn(
802
+ logger.warning(
784
803
  f"Skipping explanations for {s_id}, as forecast was not generated."
785
804
  )
786
805
 
@@ -815,6 +834,13 @@ class ForecastOperatorBaseModel(ABC):
815
834
  local_kernel_explnr_df = pd.DataFrame(
816
835
  local_kernel_explnr_vals, columns=data.columns
817
836
  )
837
+
838
+ # Add date column to local explanation DataFrame
839
+ local_kernel_explnr_df[ForecastOutputColumns.DATE] = (
840
+ self.datasets.get_horizon_at_series(
841
+ s_id=series_id
842
+ )[self.spec.datetime_column.name].reset_index(drop=True)
843
+ )
818
844
  self.local_explanation[series_id] = local_kernel_explnr_df
819
845
 
820
846
  def get_explain_predict_fn(self, series_id, fcst_col_name="yhat"):