oracle-ads 2.13.2__py3-none-any.whl → 2.13.2rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +3 -3
- ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +1 -1
- ads/opctl/operator/lowcode/anomaly/utils.py +1 -1
- ads/opctl/operator/lowcode/common/transformations.py +5 -1
- ads/opctl/operator/lowcode/common/utils.py +7 -2
- ads/opctl/operator/lowcode/forecast/model/arima.py +15 -10
- ads/opctl/operator/lowcode/forecast/model/automlx.py +31 -9
- ads/opctl/operator/lowcode/forecast/model/autots.py +7 -5
- ads/opctl/operator/lowcode/forecast/model/base_model.py +127 -101
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +14 -6
- ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +2 -2
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +46 -32
- ads/opctl/operator/lowcode/forecast/model/prophet.py +82 -29
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +136 -54
- ads/opctl/operator/lowcode/forecast/operator_config.py +29 -3
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +103 -58
- {oracle_ads-2.13.2.dist-info → oracle_ads-2.13.2rc1.dist-info}/METADATA +1 -1
- {oracle_ads-2.13.2.dist-info → oracle_ads-2.13.2rc1.dist-info}/RECORD +21 -21
- {oracle_ads-2.13.2.dist-info → oracle_ads-2.13.2rc1.dist-info}/WHEEL +0 -0
- {oracle_ads-2.13.2.dist-info → oracle_ads-2.13.2rc1.dist-info}/entry_points.txt +0 -0
- {oracle_ads-2.13.2.dist-info → oracle_ads-2.13.2rc1.dist-info}/licenses/LICENSE.txt +0 -0
@@ -71,7 +71,7 @@ class AnomalyOperatorBaseModel(ABC):
|
|
71
71
|
try:
|
72
72
|
anomaly_output = self._build_model()
|
73
73
|
except Exception as e:
|
74
|
-
logger.
|
74
|
+
logger.warning(f"Found exception: {e}")
|
75
75
|
if self.spec.datetime_column:
|
76
76
|
anomaly_output = self._fallback_build_model()
|
77
77
|
raise e
|
@@ -347,7 +347,7 @@ class AnomalyOperatorBaseModel(ABC):
|
|
347
347
|
storage_options=storage_options,
|
348
348
|
)
|
349
349
|
|
350
|
-
logger.
|
350
|
+
logger.warning(
|
351
351
|
f"The report has been successfully "
|
352
352
|
f"generated and placed to the: {unique_output_dir}."
|
353
353
|
)
|
@@ -356,7 +356,7 @@ class AnomalyOperatorBaseModel(ABC):
|
|
356
356
|
"""
|
357
357
|
Fallback method for the sub model _build_model method.
|
358
358
|
"""
|
359
|
-
logger.
|
359
|
+
logger.warning(
|
360
360
|
f"The build_model method has failed for the model: {self.spec.model}. "
|
361
361
|
"A fallback model will be built."
|
362
362
|
)
|
@@ -95,7 +95,7 @@ class RandomCutForestOperatorModel(AnomalyOperatorBaseModel):
|
|
95
95
|
|
96
96
|
anomaly_output.add_output(target, anomaly, score)
|
97
97
|
except Exception as e:
|
98
|
-
logger.
|
98
|
+
logger.warning(f"Encountered Error: {e}. Skipping series {target}.")
|
99
99
|
|
100
100
|
return anomaly_output
|
101
101
|
|
@@ -44,7 +44,7 @@ def _build_metrics_df(y_true, y_pred, column_name):
|
|
44
44
|
# Throws exception if y_true has only one class
|
45
45
|
metrics[SupportedMetrics.ROC_AUC] = roc_auc_score(y_true, y_pred)
|
46
46
|
except Exception as e:
|
47
|
-
logger.
|
47
|
+
logger.warning(f"An exception occurred: {e}")
|
48
48
|
metrics[SupportedMetrics.ROC_AUC] = None
|
49
49
|
precision, recall, thresholds = precision_recall_curve(y_true, y_pred)
|
50
50
|
metrics[SupportedMetrics.PRC_AUC] = auc(recall, precision)
|
@@ -98,7 +98,11 @@ class Transformations(ABC):
|
|
98
98
|
return clean_df
|
99
99
|
|
100
100
|
def _remove_trailing_whitespace(self, df):
|
101
|
-
return df.apply(
|
101
|
+
return df.apply(
|
102
|
+
lambda x: x.str.strip()
|
103
|
+
if hasattr(x, "dtype") and x.dtype == "object"
|
104
|
+
else x
|
105
|
+
)
|
102
106
|
|
103
107
|
def _clean_column_names(self, df):
|
104
108
|
"""
|
@@ -3,6 +3,7 @@
|
|
3
3
|
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
4
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
5
|
|
6
|
+
import json
|
6
7
|
import logging
|
7
8
|
import os
|
8
9
|
import shutil
|
@@ -12,7 +13,6 @@ from typing import List, Union
|
|
12
13
|
|
13
14
|
import fsspec
|
14
15
|
import oracledb
|
15
|
-
import json
|
16
16
|
import pandas as pd
|
17
17
|
|
18
18
|
from ads.common.object_storage_details import ObjectStorageDetails
|
@@ -142,6 +142,11 @@ def write_data(data, filename, format, storage_options=None, index=False, **kwar
|
|
142
142
|
)
|
143
143
|
|
144
144
|
|
145
|
+
def write_json(json_dict, filename, storage_options=None):
|
146
|
+
with fsspec.open(filename, mode="w", **storage_options) as f:
|
147
|
+
f.write(json.dumps(json_dict))
|
148
|
+
|
149
|
+
|
145
150
|
def write_simple_json(data, path):
|
146
151
|
if ObjectStorageDetails.is_oci_path(path):
|
147
152
|
storage_options = default_signer()
|
@@ -265,7 +270,7 @@ def find_output_dirname(output_dir: OutputDirectory):
|
|
265
270
|
while os.path.exists(unique_output_dir):
|
266
271
|
unique_output_dir = f"{output_dir}_{counter}"
|
267
272
|
counter += 1
|
268
|
-
logger.
|
273
|
+
logger.warning(
|
269
274
|
f"Since the output directory was not specified, the output will be saved to {unique_output_dir} directory."
|
270
275
|
)
|
271
276
|
return unique_output_dir
|
@@ -1,6 +1,6 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
2
|
|
3
|
-
# Copyright (c) 2023,
|
3
|
+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
|
4
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
5
|
|
6
6
|
import logging
|
@@ -132,13 +132,14 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
|
|
132
132
|
|
133
133
|
logger.debug("===========Done===========")
|
134
134
|
except Exception as e:
|
135
|
-
|
135
|
+
new_error = {
|
136
136
|
"model_name": self.spec.model,
|
137
137
|
"error": str(e),
|
138
138
|
"error_trace": traceback.format_exc(),
|
139
139
|
}
|
140
|
-
|
141
|
-
logger.
|
140
|
+
self.errors_dict[s_id] = new_error
|
141
|
+
logger.warning(f"Encountered Error: {e}. Skipping.")
|
142
|
+
logger.warning(traceback.format_exc())
|
142
143
|
|
143
144
|
def _build_model(self) -> pd.DataFrame:
|
144
145
|
full_data_dict = self.datasets.get_data_by_series()
|
@@ -166,7 +167,7 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
|
|
166
167
|
sec5_text = rc.Heading("ARIMA Model Parameters", level=2)
|
167
168
|
blocks = [
|
168
169
|
rc.Html(
|
169
|
-
m[
|
170
|
+
m["model"].summary().as_html(),
|
170
171
|
label=s_id if self.target_cat_col else None,
|
171
172
|
)
|
172
173
|
for i, (s_id, m) in enumerate(self.models.items())
|
@@ -201,11 +202,15 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
|
|
201
202
|
self.formatted_local_explanation = aggregate_local_explanations
|
202
203
|
|
203
204
|
if not self.target_cat_col:
|
204
|
-
self.formatted_global_explanation =
|
205
|
-
|
206
|
-
|
205
|
+
self.formatted_global_explanation = (
|
206
|
+
self.formatted_global_explanation.rename(
|
207
|
+
{"Series 1": self.original_target_column},
|
208
|
+
axis=1,
|
209
|
+
)
|
210
|
+
)
|
211
|
+
self.formatted_local_explanation.drop(
|
212
|
+
"Series", axis=1, inplace=True
|
207
213
|
)
|
208
|
-
self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
|
209
214
|
|
210
215
|
# Create a markdown section for the global explainability
|
211
216
|
global_explanation_section = rc.Block(
|
@@ -235,7 +240,7 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
|
|
235
240
|
local_explanation_section,
|
236
241
|
]
|
237
242
|
except Exception as e:
|
238
|
-
logger.
|
243
|
+
logger.warning(f"Failed to generate Explanations with error: {e}.")
|
239
244
|
logger.debug(f"Full Traceback: {traceback.format_exc()}")
|
240
245
|
|
241
246
|
model_description = rc.Text(
|
@@ -184,13 +184,18 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
184
184
|
"selected_model_params": model.selected_model_params_,
|
185
185
|
}
|
186
186
|
except Exception as e:
|
187
|
-
|
187
|
+
new_error = {
|
188
188
|
"model_name": self.spec.model,
|
189
189
|
"error": str(e),
|
190
190
|
"error_trace": traceback.format_exc(),
|
191
191
|
}
|
192
|
-
|
193
|
-
|
192
|
+
if s_id in self.errors_dict:
|
193
|
+
self.errors_dict[s_id]["model_fitting"] = new_error
|
194
|
+
else:
|
195
|
+
self.errors_dict[s_id] = {"model_fitting": new_error}
|
196
|
+
logger.warning(f"Encountered Error: {e}. Skipping.")
|
197
|
+
logger.warning(f"self.errors_dict[s_id]: {self.errors_dict[s_id]}")
|
198
|
+
logger.warning(traceback.format_exc())
|
194
199
|
|
195
200
|
logger.debug("===========Forecast Generated===========")
|
196
201
|
|
@@ -257,7 +262,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
257
262
|
)
|
258
263
|
|
259
264
|
self.formatted_global_explanation.rename(
|
260
|
-
columns={
|
265
|
+
columns={
|
266
|
+
self.spec.datetime_column.name: ForecastOutputColumns.DATE
|
267
|
+
},
|
261
268
|
inplace=True,
|
262
269
|
)
|
263
270
|
|
@@ -312,7 +319,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
312
319
|
local_explanation_section,
|
313
320
|
]
|
314
321
|
except Exception as e:
|
315
|
-
logger.
|
322
|
+
logger.warning(f"Failed to generate Explanations with error: {e}.")
|
316
323
|
logger.debug(f"Full Traceback: {traceback.format_exc()}")
|
317
324
|
|
318
325
|
model_description = rc.Text(
|
@@ -462,14 +469,27 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
462
469
|
index="row", columns="Feature", values="Attribution"
|
463
470
|
)
|
464
471
|
explanations_df = explanations_df.reset_index(drop=True)
|
465
|
-
|
472
|
+
explanations_df[ForecastOutputColumns.DATE] = (
|
473
|
+
self.datasets.get_horizon_at_series(
|
474
|
+
s_id=s_id
|
475
|
+
)[self.spec.datetime_column.name].reset_index(drop=True)
|
476
|
+
)
|
466
477
|
# Store the explanations in the local_explanation dictionary
|
467
478
|
self.local_explanation[s_id] = explanations_df
|
468
479
|
|
469
480
|
self.global_explanation[s_id] = dict(
|
470
481
|
zip(
|
471
|
-
self.local_explanation[s_id]
|
472
|
-
|
482
|
+
self.local_explanation[s_id]
|
483
|
+
.drop(ForecastOutputColumns.DATE, axis=1)
|
484
|
+
.columns,
|
485
|
+
np.nanmean(
|
486
|
+
np.abs(
|
487
|
+
self.local_explanation[s_id].drop(
|
488
|
+
ForecastOutputColumns.DATE, axis=1
|
489
|
+
)
|
490
|
+
),
|
491
|
+
axis=0,
|
492
|
+
),
|
473
493
|
)
|
474
494
|
)
|
475
495
|
else:
|
@@ -478,7 +498,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
478
498
|
except Exception as e:
|
479
499
|
if s_id in self.errors_dict:
|
480
500
|
self.errors_dict[s_id]["explainer_error"] = str(e)
|
481
|
-
self.errors_dict[s_id]["explainer_error_trace"] =
|
501
|
+
self.errors_dict[s_id]["explainer_error_trace"] = (
|
502
|
+
traceback.format_exc()
|
503
|
+
)
|
482
504
|
else:
|
483
505
|
self.errors_dict[s_id] = {
|
484
506
|
"model_name": self.spec.model,
|
@@ -211,8 +211,8 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
|
|
211
211
|
"error": str(e),
|
212
212
|
"error_trace": traceback.format_exc(),
|
213
213
|
}
|
214
|
-
logger.
|
215
|
-
logger.
|
214
|
+
logger.warning(f"Encountered Error: {e}. Skipping.")
|
215
|
+
logger.warning(traceback.format_exc())
|
216
216
|
|
217
217
|
logger.debug("===========Done===========")
|
218
218
|
|
@@ -242,7 +242,7 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
|
|
242
242
|
self.models.df_wide_numeric, series=s_id
|
243
243
|
),
|
244
244
|
self.datasets.list_series_ids(),
|
245
|
-
target_category_column=self.target_cat_col
|
245
|
+
target_category_column=self.target_cat_col,
|
246
246
|
)
|
247
247
|
section_1 = rc.Block(
|
248
248
|
rc.Heading("Forecast Overview", level=2),
|
@@ -260,7 +260,9 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
|
|
260
260
|
)
|
261
261
|
|
262
262
|
except KeyError:
|
263
|
-
logger.
|
263
|
+
logger.warning(
|
264
|
+
"Issue generating Model Parameters Table Section. Skipping"
|
265
|
+
)
|
264
266
|
sec2 = rc.Text("Error generating model parameters.")
|
265
267
|
|
266
268
|
section_2 = rc.Block(sec2_text, sec2)
|
@@ -268,7 +270,7 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
|
|
268
270
|
all_sections = [section_1, section_2]
|
269
271
|
|
270
272
|
if self.spec.generate_explanations:
|
271
|
-
logger.
|
273
|
+
logger.warning("Explanations not yet supported for the AutoTS Module")
|
272
274
|
|
273
275
|
# Model Description
|
274
276
|
model_description = rc.Text(
|
@@ -28,8 +28,8 @@ from ads.opctl.operator.lowcode.common.utils import (
|
|
28
28
|
merged_category_column_name,
|
29
29
|
seconds_to_datetime,
|
30
30
|
write_data,
|
31
|
+
write_json,
|
31
32
|
)
|
32
|
-
from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import TestData
|
33
33
|
from ads.opctl.operator.lowcode.forecast.utils import (
|
34
34
|
_build_metrics_df,
|
35
35
|
_build_metrics_per_horizon,
|
@@ -46,6 +46,7 @@ from ..const import (
|
|
46
46
|
AUTO_SELECT,
|
47
47
|
BACKTEST_REPORT_NAME,
|
48
48
|
SUMMARY_METRICS_HORIZON_LIMIT,
|
49
|
+
ForecastOutputColumns,
|
49
50
|
SpeedAccuracyMode,
|
50
51
|
SupportedMetrics,
|
51
52
|
SupportedModels,
|
@@ -132,11 +133,10 @@ class ForecastOperatorBaseModel(ABC):
|
|
132
133
|
|
133
134
|
if self.datasets.test_data is not None:
|
134
135
|
try:
|
135
|
-
(
|
136
|
-
self.
|
137
|
-
|
138
|
-
|
139
|
-
elapsed_time=elapsed_time,
|
136
|
+
(self.test_eval_metrics, summary_metrics) = (
|
137
|
+
self._test_evaluate_metrics(
|
138
|
+
elapsed_time=elapsed_time,
|
139
|
+
)
|
140
140
|
)
|
141
141
|
if not self.target_cat_col:
|
142
142
|
self.test_eval_metrics.rename(
|
@@ -145,7 +145,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
145
145
|
inplace=True,
|
146
146
|
)
|
147
147
|
except Exception:
|
148
|
-
logger.
|
148
|
+
logger.warning("Unable to generate Test Metrics.")
|
149
149
|
logger.debug(f"Full Traceback: {traceback.format_exc()}")
|
150
150
|
report_sections = []
|
151
151
|
|
@@ -155,9 +155,8 @@ class ForecastOperatorBaseModel(ABC):
|
|
155
155
|
model_description,
|
156
156
|
other_sections,
|
157
157
|
) = self._generate_report()
|
158
|
-
|
159
158
|
header_section = rc.Block(
|
160
|
-
rc.Heading(
|
159
|
+
rc.Heading(self.spec.report_title, level=1),
|
161
160
|
rc.Text(
|
162
161
|
f"You selected the {self.spec.model} model.\nBased on your dataset, you could have also selected any of the models: {SupportedModels.keys()}."
|
163
162
|
),
|
@@ -369,7 +368,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
369
368
|
-self.spec.horizon :
|
370
369
|
]
|
371
370
|
except KeyError as ke:
|
372
|
-
logger.
|
371
|
+
logger.warning(
|
373
372
|
f"Error Generating Metrics: Unable to find {s_id} in the test data. Error: {ke.args}"
|
374
373
|
)
|
375
374
|
y_pred = self.forecast_output.get_forecast(s_id)["forecast_value"].values[
|
@@ -478,10 +477,11 @@ class ForecastOperatorBaseModel(ABC):
|
|
478
477
|
unique_output_dir = self.spec.output_directory.url
|
479
478
|
results = ForecastResults()
|
480
479
|
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
480
|
+
storage_options = (
|
481
|
+
default_signer()
|
482
|
+
if ObjectStorageDetails.is_oci_path(unique_output_dir)
|
483
|
+
else {}
|
484
|
+
)
|
485
485
|
|
486
486
|
# report-creator html report
|
487
487
|
if self.spec.generate_report:
|
@@ -512,12 +512,13 @@ class ForecastOperatorBaseModel(ABC):
|
|
512
512
|
if self.target_cat_col
|
513
513
|
else result_df.drop(DataColumns.Series, axis=1)
|
514
514
|
)
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
515
|
+
if self.spec.generate_forecast_file:
|
516
|
+
write_data(
|
517
|
+
data=result_df,
|
518
|
+
filename=os.path.join(unique_output_dir, self.spec.forecast_filename),
|
519
|
+
format="csv",
|
520
|
+
storage_options=storage_options,
|
521
|
+
)
|
521
522
|
results.set_forecast(result_df)
|
522
523
|
|
523
524
|
# metrics csv report
|
@@ -531,18 +532,19 @@ class ForecastOperatorBaseModel(ABC):
|
|
531
532
|
metrics_df_formatted = metrics_df.reset_index().rename(
|
532
533
|
{"index": "metrics", "Series 1": metrics_col_name}, axis=1
|
533
534
|
)
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
535
|
+
if self.spec.generate_metrics_file:
|
536
|
+
write_data(
|
537
|
+
data=metrics_df_formatted,
|
538
|
+
filename=os.path.join(
|
539
|
+
unique_output_dir, self.spec.metrics_filename
|
540
|
+
),
|
541
|
+
format="csv",
|
542
|
+
storage_options=storage_options,
|
543
|
+
index=False,
|
544
|
+
)
|
543
545
|
results.set_metrics(metrics_df_formatted)
|
544
546
|
else:
|
545
|
-
logger.
|
547
|
+
logger.warning(
|
546
548
|
f"Attempted to generate the {self.spec.metrics_filename} file with the training metrics, however the training metrics could not be properly generated."
|
547
549
|
)
|
548
550
|
|
@@ -552,56 +554,59 @@ class ForecastOperatorBaseModel(ABC):
|
|
552
554
|
test_metrics_df_formatted = test_metrics_df.reset_index().rename(
|
553
555
|
{"index": "metrics", "Series 1": metrics_col_name}, axis=1
|
554
556
|
)
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
557
|
+
if self.spec.generate_metrics_file:
|
558
|
+
write_data(
|
559
|
+
data=test_metrics_df_formatted,
|
560
|
+
filename=os.path.join(
|
561
|
+
unique_output_dir, self.spec.test_metrics_filename
|
562
|
+
),
|
563
|
+
format="csv",
|
564
|
+
storage_options=storage_options,
|
565
|
+
index=False,
|
566
|
+
)
|
564
567
|
results.set_test_metrics(test_metrics_df_formatted)
|
565
568
|
else:
|
566
|
-
logger.
|
569
|
+
logger.warning(
|
567
570
|
f"Attempted to generate the {self.spec.test_metrics_filename} file with the test metrics, however the test metrics could not be properly generated."
|
568
571
|
)
|
569
572
|
# explanations csv reports
|
570
573
|
if self.spec.generate_explanations:
|
571
574
|
try:
|
572
575
|
if not self.formatted_global_explanation.empty:
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
576
|
+
if self.spec.generate_explanation_files:
|
577
|
+
write_data(
|
578
|
+
data=self.formatted_global_explanation,
|
579
|
+
filename=os.path.join(
|
580
|
+
unique_output_dir, self.spec.global_explanation_filename
|
581
|
+
),
|
582
|
+
format="csv",
|
583
|
+
storage_options=storage_options,
|
584
|
+
index=True,
|
585
|
+
)
|
582
586
|
results.set_global_explanations(self.formatted_global_explanation)
|
583
587
|
else:
|
584
|
-
logger.
|
588
|
+
logger.warning(
|
585
589
|
f"Attempted to generate global explanations for the {self.spec.global_explanation_filename} file, but an issue occured in formatting the explanations."
|
586
590
|
)
|
587
591
|
|
588
592
|
if not self.formatted_local_explanation.empty:
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
593
|
+
if self.spec.generate_explanation_files:
|
594
|
+
write_data(
|
595
|
+
data=self.formatted_local_explanation,
|
596
|
+
filename=os.path.join(
|
597
|
+
unique_output_dir, self.spec.local_explanation_filename
|
598
|
+
),
|
599
|
+
format="csv",
|
600
|
+
storage_options=storage_options,
|
601
|
+
index=True,
|
602
|
+
)
|
598
603
|
results.set_local_explanations(self.formatted_local_explanation)
|
599
604
|
else:
|
600
|
-
logger.
|
605
|
+
logger.warning(
|
601
606
|
f"Attempted to generate local explanations for the {self.spec.local_explanation_filename} file, but an issue occured in formatting the explanations."
|
602
607
|
)
|
603
608
|
except AttributeError as e:
|
604
|
-
logger.
|
609
|
+
logger.warning(
|
605
610
|
"Unable to generate explanations for this model type or for this dataset."
|
606
611
|
)
|
607
612
|
logger.debug(f"Got error: {e.args}")
|
@@ -631,15 +636,12 @@ class ForecastOperatorBaseModel(ABC):
|
|
631
636
|
f"The outputs have been successfully generated and placed into the directory: {unique_output_dir}."
|
632
637
|
)
|
633
638
|
if self.errors_dict:
|
634
|
-
|
635
|
-
|
639
|
+
write_json(
|
640
|
+
json_dict=self.errors_dict,
|
636
641
|
filename=os.path.join(
|
637
642
|
unique_output_dir, self.spec.errors_dict_filename
|
638
643
|
),
|
639
|
-
format="json",
|
640
644
|
storage_options=storage_options,
|
641
|
-
index=True,
|
642
|
-
indent=4,
|
643
645
|
)
|
644
646
|
results.set_errors_dict(self.errors_dict)
|
645
647
|
else:
|
@@ -742,45 +744,62 @@ class ForecastOperatorBaseModel(ABC):
|
|
742
744
|
include_horizon=False
|
743
745
|
).items():
|
744
746
|
if s_id in self.models:
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
758
|
-
|
759
|
-
|
760
|
-
kernel_explnr = PermutationExplainer(
|
761
|
-
model=explain_predict_fn, masker=data_trimmed_encoded
|
762
|
-
)
|
763
|
-
kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded)
|
764
|
-
exp_end_time = time.time()
|
765
|
-
global_ex_time = global_ex_time + exp_end_time - exp_start_time
|
766
|
-
self.local_explainer(
|
767
|
-
kernel_explnr, series_id=s_id, datetime_col_name=datetime_col_name
|
768
|
-
)
|
769
|
-
local_ex_time = local_ex_time + time.time() - exp_end_time
|
747
|
+
try:
|
748
|
+
explain_predict_fn = self.get_explain_predict_fn(series_id=s_id)
|
749
|
+
data_trimmed = data_i.tail(
|
750
|
+
max(int(len(data_i) * ratio), 5)
|
751
|
+
).reset_index(drop=True)
|
752
|
+
data_trimmed[datetime_col_name] = data_trimmed[
|
753
|
+
datetime_col_name
|
754
|
+
].apply(lambda x: x.timestamp())
|
755
|
+
|
756
|
+
# Explainer fails when boolean columns are passed
|
757
|
+
|
758
|
+
_, data_trimmed_encoded = _label_encode_dataframe(
|
759
|
+
data_trimmed,
|
760
|
+
no_encode={datetime_col_name, self.original_target_column},
|
761
|
+
)
|
770
762
|
|
771
|
-
|
772
|
-
|
773
|
-
"No explanations generated. Ensure that additional data has been provided."
|
763
|
+
kernel_explnr = PermutationExplainer(
|
764
|
+
model=explain_predict_fn, masker=data_trimmed_encoded
|
774
765
|
)
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
766
|
+
kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded)
|
767
|
+
exp_end_time = time.time()
|
768
|
+
global_ex_time = global_ex_time + exp_end_time - exp_start_time
|
769
|
+
self.local_explainer(
|
770
|
+
kernel_explnr,
|
771
|
+
series_id=s_id,
|
772
|
+
datetime_col_name=datetime_col_name,
|
781
773
|
)
|
774
|
+
local_ex_time = local_ex_time + time.time() - exp_end_time
|
775
|
+
|
776
|
+
if not len(kernel_explnr_vals):
|
777
|
+
logger.warning(
|
778
|
+
"No explanations generated. Ensure that additional data has been provided."
|
779
|
+
)
|
780
|
+
else:
|
781
|
+
self.global_explanation[s_id] = dict(
|
782
|
+
zip(
|
783
|
+
data_trimmed.columns[1:],
|
784
|
+
np.average(
|
785
|
+
np.absolute(kernel_explnr_vals[:, 1:]), axis=0
|
786
|
+
),
|
787
|
+
)
|
788
|
+
)
|
789
|
+
except Exception as e:
|
790
|
+
if s_id in self.errors_dict:
|
791
|
+
self.errors_dict[s_id]["explainer_error"] = str(e)
|
792
|
+
self.errors_dict[s_id]["explainer_error_trace"] = (
|
793
|
+
traceback.format_exc()
|
794
|
+
)
|
795
|
+
else:
|
796
|
+
self.errors_dict[s_id] = {
|
797
|
+
"model_name": self.spec.model,
|
798
|
+
"explainer_error": str(e),
|
799
|
+
"explainer_error_trace": traceback.format_exc(),
|
800
|
+
}
|
782
801
|
else:
|
783
|
-
logger.
|
802
|
+
logger.warning(
|
784
803
|
f"Skipping explanations for {s_id}, as forecast was not generated."
|
785
804
|
)
|
786
805
|
|
@@ -815,6 +834,13 @@ class ForecastOperatorBaseModel(ABC):
|
|
815
834
|
local_kernel_explnr_df = pd.DataFrame(
|
816
835
|
local_kernel_explnr_vals, columns=data.columns
|
817
836
|
)
|
837
|
+
|
838
|
+
# Add date column to local explanation DataFrame
|
839
|
+
local_kernel_explnr_df[ForecastOutputColumns.DATE] = (
|
840
|
+
self.datasets.get_horizon_at_series(
|
841
|
+
s_id=series_id
|
842
|
+
)[self.spec.datetime_column.name].reset_index(drop=True)
|
843
|
+
)
|
818
844
|
self.local_explanation[series_id] = local_kernel_explnr_df
|
819
845
|
|
820
846
|
def get_explain_predict_fn(self, series_id, fcst_col_name="yhat"):
|