oracle-ads 2.13.1rc0__py3-none-any.whl → 2.13.2rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. ads/aqua/__init__.py +7 -1
  2. ads/aqua/app.py +24 -23
  3. ads/aqua/client/client.py +48 -11
  4. ads/aqua/common/entities.py +28 -1
  5. ads/aqua/common/enums.py +13 -7
  6. ads/aqua/common/utils.py +8 -13
  7. ads/aqua/config/container_config.py +203 -0
  8. ads/aqua/config/evaluation/evaluation_service_config.py +5 -181
  9. ads/aqua/constants.py +0 -1
  10. ads/aqua/evaluation/evaluation.py +4 -4
  11. ads/aqua/extension/base_handler.py +4 -0
  12. ads/aqua/extension/model_handler.py +19 -28
  13. ads/aqua/finetuning/finetuning.py +2 -3
  14. ads/aqua/model/entities.py +2 -3
  15. ads/aqua/model/model.py +25 -30
  16. ads/aqua/modeldeployment/deployment.py +6 -14
  17. ads/aqua/modeldeployment/entities.py +2 -2
  18. ads/aqua/server/__init__.py +4 -0
  19. ads/aqua/server/__main__.py +24 -0
  20. ads/aqua/server/app.py +47 -0
  21. ads/aqua/server/aqua_spec.yml +1291 -0
  22. ads/aqua/ui.py +5 -199
  23. ads/common/auth.py +20 -11
  24. ads/common/utils.py +91 -11
  25. ads/config.py +3 -0
  26. ads/llm/__init__.py +1 -0
  27. ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +32 -23
  28. ads/model/artifact_downloader.py +4 -1
  29. ads/model/common/utils.py +15 -3
  30. ads/model/datascience_model.py +339 -8
  31. ads/model/model_metadata.py +54 -14
  32. ads/model/model_version_set.py +5 -3
  33. ads/model/service/oci_datascience_model.py +477 -5
  34. ads/opctl/operator/common/utils.py +16 -0
  35. ads/opctl/operator/lowcode/anomaly/model/base_model.py +3 -3
  36. ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +1 -1
  37. ads/opctl/operator/lowcode/anomaly/utils.py +1 -1
  38. ads/opctl/operator/lowcode/common/data.py +5 -2
  39. ads/opctl/operator/lowcode/common/transformations.py +7 -13
  40. ads/opctl/operator/lowcode/common/utils.py +7 -2
  41. ads/opctl/operator/lowcode/forecast/model/arima.py +15 -10
  42. ads/opctl/operator/lowcode/forecast/model/automlx.py +39 -9
  43. ads/opctl/operator/lowcode/forecast/model/autots.py +7 -5
  44. ads/opctl/operator/lowcode/forecast/model/base_model.py +135 -110
  45. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +30 -14
  46. ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +2 -2
  47. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +46 -32
  48. ads/opctl/operator/lowcode/forecast/model/prophet.py +82 -29
  49. ads/opctl/operator/lowcode/forecast/model_evaluator.py +142 -62
  50. ads/opctl/operator/lowcode/forecast/operator_config.py +29 -3
  51. ads/opctl/operator/lowcode/forecast/schema.yaml +1 -1
  52. ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +108 -56
  53. {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2rc1.dist-info}/METADATA +15 -12
  54. {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2rc1.dist-info}/RECORD +57 -53
  55. {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2rc1.dist-info}/WHEEL +1 -1
  56. ads/aqua/config/evaluation/evaluation_service_model_config.py +0 -8
  57. {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2rc1.dist-info}/entry_points.txt +0 -0
  58. {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2rc1.dist-info/licenses}/LICENSE.txt +0 -0
@@ -3,6 +3,7 @@
3
3
  # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
4
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
 
6
+ import json
6
7
  import logging
7
8
  import os
8
9
  import shutil
@@ -12,7 +13,6 @@ from typing import List, Union
12
13
 
13
14
  import fsspec
14
15
  import oracledb
15
- import json
16
16
  import pandas as pd
17
17
 
18
18
  from ads.common.object_storage_details import ObjectStorageDetails
@@ -142,6 +142,11 @@ def write_data(data, filename, format, storage_options=None, index=False, **kwar
142
142
  )
143
143
 
144
144
 
145
+ def write_json(json_dict, filename, storage_options=None):
146
+ with fsspec.open(filename, mode="w", **storage_options) as f:
147
+ f.write(json.dumps(json_dict))
148
+
149
+
145
150
  def write_simple_json(data, path):
146
151
  if ObjectStorageDetails.is_oci_path(path):
147
152
  storage_options = default_signer()
@@ -265,7 +270,7 @@ def find_output_dirname(output_dir: OutputDirectory):
265
270
  while os.path.exists(unique_output_dir):
266
271
  unique_output_dir = f"{output_dir}_{counter}"
267
272
  counter += 1
268
- logger.warn(
273
+ logger.warning(
269
274
  f"Since the output directory was not specified, the output will be saved to {unique_output_dir} directory."
270
275
  )
271
276
  return unique_output_dir
@@ -1,6 +1,6 @@
1
1
  #!/usr/bin/env python
2
2
 
3
- # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
4
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
 
6
6
  import logging
@@ -132,13 +132,14 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
132
132
 
133
133
  logger.debug("===========Done===========")
134
134
  except Exception as e:
135
- self.errors_dict[s_id] = {
135
+ new_error = {
136
136
  "model_name": self.spec.model,
137
137
  "error": str(e),
138
138
  "error_trace": traceback.format_exc(),
139
139
  }
140
- logger.warn(f"Encountered Error: {e}. Skipping.")
141
- logger.warn(traceback.format_exc())
140
+ self.errors_dict[s_id] = new_error
141
+ logger.warning(f"Encountered Error: {e}. Skipping.")
142
+ logger.warning(traceback.format_exc())
142
143
 
143
144
  def _build_model(self) -> pd.DataFrame:
144
145
  full_data_dict = self.datasets.get_data_by_series()
@@ -166,7 +167,7 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
166
167
  sec5_text = rc.Heading("ARIMA Model Parameters", level=2)
167
168
  blocks = [
168
169
  rc.Html(
169
- m['model'].summary().as_html(),
170
+ m["model"].summary().as_html(),
170
171
  label=s_id if self.target_cat_col else None,
171
172
  )
172
173
  for i, (s_id, m) in enumerate(self.models.items())
@@ -201,11 +202,15 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
201
202
  self.formatted_local_explanation = aggregate_local_explanations
202
203
 
203
204
  if not self.target_cat_col:
204
- self.formatted_global_explanation = self.formatted_global_explanation.rename(
205
- {"Series 1": self.original_target_column},
206
- axis=1,
205
+ self.formatted_global_explanation = (
206
+ self.formatted_global_explanation.rename(
207
+ {"Series 1": self.original_target_column},
208
+ axis=1,
209
+ )
210
+ )
211
+ self.formatted_local_explanation.drop(
212
+ "Series", axis=1, inplace=True
207
213
  )
208
- self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
209
214
 
210
215
  # Create a markdown section for the global explainability
211
216
  global_explanation_section = rc.Block(
@@ -235,7 +240,7 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
235
240
  local_explanation_section,
236
241
  ]
237
242
  except Exception as e:
238
- logger.warn(f"Failed to generate Explanations with error: {e}.")
243
+ logger.warning(f"Failed to generate Explanations with error: {e}.")
239
244
  logger.debug(f"Full Traceback: {traceback.format_exc()}")
240
245
 
241
246
  model_description = rc.Text(
@@ -184,13 +184,18 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
184
184
  "selected_model_params": model.selected_model_params_,
185
185
  }
186
186
  except Exception as e:
187
- self.errors_dict[s_id] = {
187
+ new_error = {
188
188
  "model_name": self.spec.model,
189
189
  "error": str(e),
190
190
  "error_trace": traceback.format_exc(),
191
191
  }
192
- logger.warn(f"Encountered Error: {e}. Skipping.")
193
- logger.warn(traceback.format_exc())
192
+ if s_id in self.errors_dict:
193
+ self.errors_dict[s_id]["model_fitting"] = new_error
194
+ else:
195
+ self.errors_dict[s_id] = {"model_fitting": new_error}
196
+ logger.warning(f"Encountered Error: {e}. Skipping.")
197
+ logger.warning(f"self.errors_dict[s_id]: {self.errors_dict[s_id]}")
198
+ logger.warning(traceback.format_exc())
194
199
 
195
200
  logger.debug("===========Forecast Generated===========")
196
201
 
@@ -249,7 +254,6 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
249
254
  self.explain_model()
250
255
 
251
256
  global_explanation_section = None
252
-
253
257
  # Convert the global explanation data to a DataFrame
254
258
  global_explanation_df = pd.DataFrame(self.global_explanation)
255
259
 
@@ -258,7 +262,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
258
262
  )
259
263
 
260
264
  self.formatted_global_explanation.rename(
261
- columns={self.spec.datetime_column.name: ForecastOutputColumns.DATE},
265
+ columns={
266
+ self.spec.datetime_column.name: ForecastOutputColumns.DATE
267
+ },
262
268
  inplace=True,
263
269
  )
264
270
 
@@ -313,7 +319,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
313
319
  local_explanation_section,
314
320
  ]
315
321
  except Exception as e:
316
- logger.warn(f"Failed to generate Explanations with error: {e}.")
322
+ logger.warning(f"Failed to generate Explanations with error: {e}.")
317
323
  logger.debug(f"Full Traceback: {traceback.format_exc()}")
318
324
 
319
325
  model_description = rc.Text(
@@ -463,20 +469,44 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
463
469
  index="row", columns="Feature", values="Attribution"
464
470
  )
465
471
  explanations_df = explanations_df.reset_index(drop=True)
466
-
472
+ explanations_df[ForecastOutputColumns.DATE] = (
473
+ self.datasets.get_horizon_at_series(
474
+ s_id=s_id
475
+ )[self.spec.datetime_column.name].reset_index(drop=True)
476
+ )
467
477
  # Store the explanations in the local_explanation dictionary
468
478
  self.local_explanation[s_id] = explanations_df
469
479
 
470
480
  self.global_explanation[s_id] = dict(
471
481
  zip(
472
- self.local_explanation[s_id].columns,
473
- np.nanmean((self.local_explanation[s_id]), axis=0),
482
+ self.local_explanation[s_id]
483
+ .drop(ForecastOutputColumns.DATE, axis=1)
484
+ .columns,
485
+ np.nanmean(
486
+ np.abs(
487
+ self.local_explanation[s_id].drop(
488
+ ForecastOutputColumns.DATE, axis=1
489
+ )
490
+ ),
491
+ axis=0,
492
+ ),
474
493
  )
475
494
  )
476
495
  else:
477
496
  # Fall back to the default explanation generation method
478
497
  super().explain_model()
479
498
  except Exception as e:
499
+ if s_id in self.errors_dict:
500
+ self.errors_dict[s_id]["explainer_error"] = str(e)
501
+ self.errors_dict[s_id]["explainer_error_trace"] = (
502
+ traceback.format_exc()
503
+ )
504
+ else:
505
+ self.errors_dict[s_id] = {
506
+ "model_name": self.spec.model,
507
+ "explainer_error": str(e),
508
+ "explainer_error_trace": traceback.format_exc(),
509
+ }
480
510
  logger.warning(
481
511
  f"Failed to generate explanations for series {s_id} with error: {e}."
482
512
  )
@@ -211,8 +211,8 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
211
211
  "error": str(e),
212
212
  "error_trace": traceback.format_exc(),
213
213
  }
214
- logger.warn(f"Encountered Error: {e}. Skipping.")
215
- logger.warn(traceback.format_exc())
214
+ logger.warning(f"Encountered Error: {e}. Skipping.")
215
+ logger.warning(traceback.format_exc())
216
216
 
217
217
  logger.debug("===========Done===========")
218
218
 
@@ -242,7 +242,7 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
242
242
  self.models.df_wide_numeric, series=s_id
243
243
  ),
244
244
  self.datasets.list_series_ids(),
245
- target_category_column=self.target_cat_col
245
+ target_category_column=self.target_cat_col,
246
246
  )
247
247
  section_1 = rc.Block(
248
248
  rc.Heading("Forecast Overview", level=2),
@@ -260,7 +260,9 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
260
260
  )
261
261
 
262
262
  except KeyError:
263
- logger.warn("Issue generating Model Parameters Table Section. Skipping")
263
+ logger.warning(
264
+ "Issue generating Model Parameters Table Section. Skipping"
265
+ )
264
266
  sec2 = rc.Text("Error generating model parameters.")
265
267
 
266
268
  section_2 = rc.Block(sec2_text, sec2)
@@ -268,7 +270,7 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
268
270
  all_sections = [section_1, section_2]
269
271
 
270
272
  if self.spec.generate_explanations:
271
- logger.warn("Explanations not yet supported for the AutoTS Module")
273
+ logger.warning("Explanations not yet supported for the AutoTS Module")
272
274
 
273
275
  # Model Description
274
276
  model_description = rc.Text(
@@ -28,8 +28,8 @@ from ads.opctl.operator.lowcode.common.utils import (
28
28
  merged_category_column_name,
29
29
  seconds_to_datetime,
30
30
  write_data,
31
+ write_json,
31
32
  )
32
- from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import TestData
33
33
  from ads.opctl.operator.lowcode.forecast.utils import (
34
34
  _build_metrics_df,
35
35
  _build_metrics_per_horizon,
@@ -46,6 +46,7 @@ from ..const import (
46
46
  AUTO_SELECT,
47
47
  BACKTEST_REPORT_NAME,
48
48
  SUMMARY_METRICS_HORIZON_LIMIT,
49
+ ForecastOutputColumns,
49
50
  SpeedAccuracyMode,
50
51
  SupportedMetrics,
51
52
  SupportedModels,
@@ -120,7 +121,7 @@ class ForecastOperatorBaseModel(ABC):
120
121
 
121
122
  # Generate metrics
122
123
  summary_metrics = None
123
- test_data = None
124
+ test_data = self.datasets.test_data
124
125
  self.eval_metrics = None
125
126
 
126
127
  if self.spec.generate_report or self.spec.generate_metrics:
@@ -130,14 +131,12 @@ class ForecastOperatorBaseModel(ABC):
130
131
  {"Series 1": self.original_target_column}, axis=1, inplace=True
131
132
  )
132
133
 
133
- if self.spec.test_data:
134
+ if self.datasets.test_data is not None:
134
135
  try:
135
- (
136
- self.test_eval_metrics,
137
- summary_metrics,
138
- test_data,
139
- ) = self._test_evaluate_metrics(
140
- elapsed_time=elapsed_time,
136
+ (self.test_eval_metrics, summary_metrics) = (
137
+ self._test_evaluate_metrics(
138
+ elapsed_time=elapsed_time,
139
+ )
141
140
  )
142
141
  if not self.target_cat_col:
143
142
  self.test_eval_metrics.rename(
@@ -146,7 +145,7 @@ class ForecastOperatorBaseModel(ABC):
146
145
  inplace=True,
147
146
  )
148
147
  except Exception:
149
- logger.warn("Unable to generate Test Metrics.")
148
+ logger.warning("Unable to generate Test Metrics.")
150
149
  logger.debug(f"Full Traceback: {traceback.format_exc()}")
151
150
  report_sections = []
152
151
 
@@ -156,9 +155,8 @@ class ForecastOperatorBaseModel(ABC):
156
155
  model_description,
157
156
  other_sections,
158
157
  ) = self._generate_report()
159
-
160
158
  header_section = rc.Block(
161
- rc.Heading("Forecast Report", level=1),
159
+ rc.Heading(self.spec.report_title, level=1),
162
160
  rc.Text(
163
161
  f"You selected the {self.spec.model} model.\nBased on your dataset, you could have also selected any of the models: {SupportedModels.keys()}."
164
162
  ),
@@ -361,7 +359,7 @@ class ForecastOperatorBaseModel(ABC):
361
359
  def _test_evaluate_metrics(self, elapsed_time=0):
362
360
  total_metrics = pd.DataFrame()
363
361
  summary_metrics = pd.DataFrame()
364
- data = TestData(self.spec)
362
+ data = self.datasets.test_data
365
363
 
366
364
  # Generate y_pred and y_true for each series
367
365
  for s_id in self.forecast_output.list_series_ids():
@@ -370,7 +368,7 @@ class ForecastOperatorBaseModel(ABC):
370
368
  -self.spec.horizon :
371
369
  ]
372
370
  except KeyError as ke:
373
- logger.warn(
371
+ logger.warning(
374
372
  f"Error Generating Metrics: Unable to find {s_id} in the test data. Error: {ke.args}"
375
373
  )
376
374
  y_pred = self.forecast_output.get_forecast(s_id)["forecast_value"].values[
@@ -398,7 +396,7 @@ class ForecastOperatorBaseModel(ABC):
398
396
  total_metrics = pd.concat([total_metrics, metrics_df], axis=1)
399
397
 
400
398
  if total_metrics.empty:
401
- return total_metrics, summary_metrics, data
399
+ return total_metrics, summary_metrics
402
400
 
403
401
  summary_metrics = pd.DataFrame(
404
402
  {
@@ -464,7 +462,7 @@ class ForecastOperatorBaseModel(ABC):
464
462
  ]
465
463
  summary_metrics = summary_metrics[new_column_order]
466
464
 
467
- return total_metrics, summary_metrics, data
465
+ return total_metrics, summary_metrics
468
466
 
469
467
  def _save_report(
470
468
  self,
@@ -479,10 +477,11 @@ class ForecastOperatorBaseModel(ABC):
479
477
  unique_output_dir = self.spec.output_directory.url
480
478
  results = ForecastResults()
481
479
 
482
- if ObjectStorageDetails.is_oci_path(unique_output_dir):
483
- storage_options = default_signer()
484
- else:
485
- storage_options = {}
480
+ storage_options = (
481
+ default_signer()
482
+ if ObjectStorageDetails.is_oci_path(unique_output_dir)
483
+ else {}
484
+ )
486
485
 
487
486
  # report-creator html report
488
487
  if self.spec.generate_report:
@@ -513,12 +512,13 @@ class ForecastOperatorBaseModel(ABC):
513
512
  if self.target_cat_col
514
513
  else result_df.drop(DataColumns.Series, axis=1)
515
514
  )
516
- write_data(
517
- data=result_df,
518
- filename=os.path.join(unique_output_dir, self.spec.forecast_filename),
519
- format="csv",
520
- storage_options=storage_options,
521
- )
515
+ if self.spec.generate_forecast_file:
516
+ write_data(
517
+ data=result_df,
518
+ filename=os.path.join(unique_output_dir, self.spec.forecast_filename),
519
+ format="csv",
520
+ storage_options=storage_options,
521
+ )
522
522
  results.set_forecast(result_df)
523
523
 
524
524
  # metrics csv report
@@ -532,77 +532,81 @@ class ForecastOperatorBaseModel(ABC):
532
532
  metrics_df_formatted = metrics_df.reset_index().rename(
533
533
  {"index": "metrics", "Series 1": metrics_col_name}, axis=1
534
534
  )
535
- write_data(
536
- data=metrics_df_formatted,
537
- filename=os.path.join(
538
- unique_output_dir, self.spec.metrics_filename
539
- ),
540
- format="csv",
541
- storage_options=storage_options,
542
- index=False,
543
- )
535
+ if self.spec.generate_metrics_file:
536
+ write_data(
537
+ data=metrics_df_formatted,
538
+ filename=os.path.join(
539
+ unique_output_dir, self.spec.metrics_filename
540
+ ),
541
+ format="csv",
542
+ storage_options=storage_options,
543
+ index=False,
544
+ )
544
545
  results.set_metrics(metrics_df_formatted)
545
546
  else:
546
- logger.warn(
547
+ logger.warning(
547
548
  f"Attempted to generate the {self.spec.metrics_filename} file with the training metrics, however the training metrics could not be properly generated."
548
549
  )
549
550
 
550
551
  # test_metrics csv report
551
- if self.spec.test_data is not None:
552
+ if self.datasets.test_data is not None:
552
553
  if test_metrics_df is not None:
553
554
  test_metrics_df_formatted = test_metrics_df.reset_index().rename(
554
555
  {"index": "metrics", "Series 1": metrics_col_name}, axis=1
555
556
  )
556
- write_data(
557
- data=test_metrics_df_formatted,
558
- filename=os.path.join(
559
- unique_output_dir, self.spec.test_metrics_filename
560
- ),
561
- format="csv",
562
- storage_options=storage_options,
563
- index=False,
564
- )
557
+ if self.spec.generate_metrics_file:
558
+ write_data(
559
+ data=test_metrics_df_formatted,
560
+ filename=os.path.join(
561
+ unique_output_dir, self.spec.test_metrics_filename
562
+ ),
563
+ format="csv",
564
+ storage_options=storage_options,
565
+ index=False,
566
+ )
565
567
  results.set_test_metrics(test_metrics_df_formatted)
566
568
  else:
567
- logger.warn(
569
+ logger.warning(
568
570
  f"Attempted to generate the {self.spec.test_metrics_filename} file with the test metrics, however the test metrics could not be properly generated."
569
571
  )
570
572
  # explanations csv reports
571
573
  if self.spec.generate_explanations:
572
574
  try:
573
- if self.formatted_global_explanation is not None:
574
- write_data(
575
- data=self.formatted_global_explanation,
576
- filename=os.path.join(
577
- unique_output_dir, self.spec.global_explanation_filename
578
- ),
579
- format="csv",
580
- storage_options=storage_options,
581
- index=True,
582
- )
575
+ if not self.formatted_global_explanation.empty:
576
+ if self.spec.generate_explanation_files:
577
+ write_data(
578
+ data=self.formatted_global_explanation,
579
+ filename=os.path.join(
580
+ unique_output_dir, self.spec.global_explanation_filename
581
+ ),
582
+ format="csv",
583
+ storage_options=storage_options,
584
+ index=True,
585
+ )
583
586
  results.set_global_explanations(self.formatted_global_explanation)
584
587
  else:
585
- logger.warn(
588
+ logger.warning(
586
589
  f"Attempted to generate global explanations for the {self.spec.global_explanation_filename} file, but an issue occured in formatting the explanations."
587
590
  )
588
591
 
589
- if self.formatted_local_explanation is not None:
590
- write_data(
591
- data=self.formatted_local_explanation,
592
- filename=os.path.join(
593
- unique_output_dir, self.spec.local_explanation_filename
594
- ),
595
- format="csv",
596
- storage_options=storage_options,
597
- index=True,
598
- )
592
+ if not self.formatted_local_explanation.empty:
593
+ if self.spec.generate_explanation_files:
594
+ write_data(
595
+ data=self.formatted_local_explanation,
596
+ filename=os.path.join(
597
+ unique_output_dir, self.spec.local_explanation_filename
598
+ ),
599
+ format="csv",
600
+ storage_options=storage_options,
601
+ index=True,
602
+ )
599
603
  results.set_local_explanations(self.formatted_local_explanation)
600
604
  else:
601
- logger.warn(
605
+ logger.warning(
602
606
  f"Attempted to generate local explanations for the {self.spec.local_explanation_filename} file, but an issue occured in formatting the explanations."
603
607
  )
604
608
  except AttributeError as e:
605
- logger.warn(
609
+ logger.warning(
606
610
  "Unable to generate explanations for this model type or for this dataset."
607
611
  )
608
612
  logger.debug(f"Got error: {e.args}")
@@ -632,15 +636,12 @@ class ForecastOperatorBaseModel(ABC):
632
636
  f"The outputs have been successfully generated and placed into the directory: {unique_output_dir}."
633
637
  )
634
638
  if self.errors_dict:
635
- write_data(
636
- data=pd.DataFrame.from_dict(self.errors_dict),
639
+ write_json(
640
+ json_dict=self.errors_dict,
637
641
  filename=os.path.join(
638
642
  unique_output_dir, self.spec.errors_dict_filename
639
643
  ),
640
- format="json",
641
644
  storage_options=storage_options,
642
- index=True,
643
- indent=4,
644
645
  )
645
646
  results.set_errors_dict(self.errors_dict)
646
647
  else:
@@ -743,45 +744,62 @@ class ForecastOperatorBaseModel(ABC):
743
744
  include_horizon=False
744
745
  ).items():
745
746
  if s_id in self.models:
746
- explain_predict_fn = self.get_explain_predict_fn(series_id=s_id)
747
- data_trimmed = data_i.tail(
748
- max(int(len(data_i) * ratio), 5)
749
- ).reset_index(drop=True)
750
- data_trimmed[datetime_col_name] = data_trimmed[datetime_col_name].apply(
751
- lambda x: x.timestamp()
752
- )
753
-
754
- # Explainer fails when boolean columns are passed
755
-
756
- _, data_trimmed_encoded = _label_encode_dataframe(
757
- data_trimmed,
758
- no_encode={datetime_col_name, self.original_target_column},
759
- )
760
-
761
- kernel_explnr = PermutationExplainer(
762
- model=explain_predict_fn, masker=data_trimmed_encoded
763
- )
764
- kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded)
765
- exp_end_time = time.time()
766
- global_ex_time = global_ex_time + exp_end_time - exp_start_time
767
- self.local_explainer(
768
- kernel_explnr, series_id=s_id, datetime_col_name=datetime_col_name
769
- )
770
- local_ex_time = local_ex_time + time.time() - exp_end_time
747
+ try:
748
+ explain_predict_fn = self.get_explain_predict_fn(series_id=s_id)
749
+ data_trimmed = data_i.tail(
750
+ max(int(len(data_i) * ratio), 5)
751
+ ).reset_index(drop=True)
752
+ data_trimmed[datetime_col_name] = data_trimmed[
753
+ datetime_col_name
754
+ ].apply(lambda x: x.timestamp())
755
+
756
+ # Explainer fails when boolean columns are passed
757
+
758
+ _, data_trimmed_encoded = _label_encode_dataframe(
759
+ data_trimmed,
760
+ no_encode={datetime_col_name, self.original_target_column},
761
+ )
771
762
 
772
- if not len(kernel_explnr_vals):
773
- logger.warn(
774
- "No explanations generated. Ensure that additional data has been provided."
763
+ kernel_explnr = PermutationExplainer(
764
+ model=explain_predict_fn, masker=data_trimmed_encoded
775
765
  )
776
- else:
777
- self.global_explanation[s_id] = dict(
778
- zip(
779
- data_trimmed.columns[1:],
780
- np.average(np.absolute(kernel_explnr_vals[:, 1:]), axis=0),
781
- )
766
+ kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded)
767
+ exp_end_time = time.time()
768
+ global_ex_time = global_ex_time + exp_end_time - exp_start_time
769
+ self.local_explainer(
770
+ kernel_explnr,
771
+ series_id=s_id,
772
+ datetime_col_name=datetime_col_name,
782
773
  )
774
+ local_ex_time = local_ex_time + time.time() - exp_end_time
775
+
776
+ if not len(kernel_explnr_vals):
777
+ logger.warning(
778
+ "No explanations generated. Ensure that additional data has been provided."
779
+ )
780
+ else:
781
+ self.global_explanation[s_id] = dict(
782
+ zip(
783
+ data_trimmed.columns[1:],
784
+ np.average(
785
+ np.absolute(kernel_explnr_vals[:, 1:]), axis=0
786
+ ),
787
+ )
788
+ )
789
+ except Exception as e:
790
+ if s_id in self.errors_dict:
791
+ self.errors_dict[s_id]["explainer_error"] = str(e)
792
+ self.errors_dict[s_id]["explainer_error_trace"] = (
793
+ traceback.format_exc()
794
+ )
795
+ else:
796
+ self.errors_dict[s_id] = {
797
+ "model_name": self.spec.model,
798
+ "explainer_error": str(e),
799
+ "explainer_error_trace": traceback.format_exc(),
800
+ }
783
801
  else:
784
- logger.warn(
802
+ logger.warning(
785
803
  f"Skipping explanations for {s_id}, as forecast was not generated."
786
804
  )
787
805
 
@@ -816,6 +834,13 @@ class ForecastOperatorBaseModel(ABC):
816
834
  local_kernel_explnr_df = pd.DataFrame(
817
835
  local_kernel_explnr_vals, columns=data.columns
818
836
  )
837
+
838
+ # Add date column to local explanation DataFrame
839
+ local_kernel_explnr_df[ForecastOutputColumns.DATE] = (
840
+ self.datasets.get_horizon_at_series(
841
+ s_id=series_id
842
+ )[self.spec.datetime_column.name].reset_index(drop=True)
843
+ )
819
844
  self.local_explanation[series_id] = local_kernel_explnr_df
820
845
 
821
846
  def get_explain_predict_fn(self, series_id, fcst_col_name="yhat"):