oracle-ads 2.13.1rc0__py3-none-any.whl → 2.13.2rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +7 -1
- ads/aqua/app.py +24 -23
- ads/aqua/client/client.py +48 -11
- ads/aqua/common/entities.py +28 -1
- ads/aqua/common/enums.py +13 -7
- ads/aqua/common/utils.py +8 -13
- ads/aqua/config/container_config.py +203 -0
- ads/aqua/config/evaluation/evaluation_service_config.py +5 -181
- ads/aqua/constants.py +0 -1
- ads/aqua/evaluation/evaluation.py +4 -4
- ads/aqua/extension/base_handler.py +4 -0
- ads/aqua/extension/model_handler.py +19 -28
- ads/aqua/finetuning/finetuning.py +2 -3
- ads/aqua/model/entities.py +2 -3
- ads/aqua/model/model.py +25 -30
- ads/aqua/modeldeployment/deployment.py +6 -14
- ads/aqua/modeldeployment/entities.py +2 -2
- ads/aqua/server/__init__.py +4 -0
- ads/aqua/server/__main__.py +24 -0
- ads/aqua/server/app.py +47 -0
- ads/aqua/server/aqua_spec.yml +1291 -0
- ads/aqua/ui.py +5 -199
- ads/common/auth.py +20 -11
- ads/common/utils.py +91 -11
- ads/config.py +3 -0
- ads/llm/__init__.py +1 -0
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +32 -23
- ads/model/artifact_downloader.py +4 -1
- ads/model/common/utils.py +15 -3
- ads/model/datascience_model.py +339 -8
- ads/model/model_metadata.py +54 -14
- ads/model/model_version_set.py +5 -3
- ads/model/service/oci_datascience_model.py +477 -5
- ads/opctl/operator/common/utils.py +16 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +3 -3
- ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +1 -1
- ads/opctl/operator/lowcode/anomaly/utils.py +1 -1
- ads/opctl/operator/lowcode/common/data.py +5 -2
- ads/opctl/operator/lowcode/common/transformations.py +7 -13
- ads/opctl/operator/lowcode/common/utils.py +7 -2
- ads/opctl/operator/lowcode/forecast/model/arima.py +15 -10
- ads/opctl/operator/lowcode/forecast/model/automlx.py +39 -9
- ads/opctl/operator/lowcode/forecast/model/autots.py +7 -5
- ads/opctl/operator/lowcode/forecast/model/base_model.py +135 -110
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +30 -14
- ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +2 -2
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +46 -32
- ads/opctl/operator/lowcode/forecast/model/prophet.py +82 -29
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +142 -62
- ads/opctl/operator/lowcode/forecast/operator_config.py +29 -3
- ads/opctl/operator/lowcode/forecast/schema.yaml +1 -1
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +108 -56
- {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2rc1.dist-info}/METADATA +15 -12
- {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2rc1.dist-info}/RECORD +57 -53
- {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2rc1.dist-info}/WHEEL +1 -1
- ads/aqua/config/evaluation/evaluation_service_model_config.py +0 -8
- {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2rc1.dist-info}/entry_points.txt +0 -0
- {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2rc1.dist-info/licenses}/LICENSE.txt +0 -0
@@ -3,6 +3,7 @@
|
|
3
3
|
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
4
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
5
|
|
6
|
+
import json
|
6
7
|
import logging
|
7
8
|
import os
|
8
9
|
import shutil
|
@@ -12,7 +13,6 @@ from typing import List, Union
|
|
12
13
|
|
13
14
|
import fsspec
|
14
15
|
import oracledb
|
15
|
-
import json
|
16
16
|
import pandas as pd
|
17
17
|
|
18
18
|
from ads.common.object_storage_details import ObjectStorageDetails
|
@@ -142,6 +142,11 @@ def write_data(data, filename, format, storage_options=None, index=False, **kwar
|
|
142
142
|
)
|
143
143
|
|
144
144
|
|
145
|
+
def write_json(json_dict, filename, storage_options=None):
|
146
|
+
with fsspec.open(filename, mode="w", **storage_options) as f:
|
147
|
+
f.write(json.dumps(json_dict))
|
148
|
+
|
149
|
+
|
145
150
|
def write_simple_json(data, path):
|
146
151
|
if ObjectStorageDetails.is_oci_path(path):
|
147
152
|
storage_options = default_signer()
|
@@ -265,7 +270,7 @@ def find_output_dirname(output_dir: OutputDirectory):
|
|
265
270
|
while os.path.exists(unique_output_dir):
|
266
271
|
unique_output_dir = f"{output_dir}_{counter}"
|
267
272
|
counter += 1
|
268
|
-
logger.
|
273
|
+
logger.warning(
|
269
274
|
f"Since the output directory was not specified, the output will be saved to {unique_output_dir} directory."
|
270
275
|
)
|
271
276
|
return unique_output_dir
|
@@ -1,6 +1,6 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
2
|
|
3
|
-
# Copyright (c) 2023,
|
3
|
+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
|
4
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
5
|
|
6
6
|
import logging
|
@@ -132,13 +132,14 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
|
|
132
132
|
|
133
133
|
logger.debug("===========Done===========")
|
134
134
|
except Exception as e:
|
135
|
-
|
135
|
+
new_error = {
|
136
136
|
"model_name": self.spec.model,
|
137
137
|
"error": str(e),
|
138
138
|
"error_trace": traceback.format_exc(),
|
139
139
|
}
|
140
|
-
|
141
|
-
logger.
|
140
|
+
self.errors_dict[s_id] = new_error
|
141
|
+
logger.warning(f"Encountered Error: {e}. Skipping.")
|
142
|
+
logger.warning(traceback.format_exc())
|
142
143
|
|
143
144
|
def _build_model(self) -> pd.DataFrame:
|
144
145
|
full_data_dict = self.datasets.get_data_by_series()
|
@@ -166,7 +167,7 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
|
|
166
167
|
sec5_text = rc.Heading("ARIMA Model Parameters", level=2)
|
167
168
|
blocks = [
|
168
169
|
rc.Html(
|
169
|
-
m[
|
170
|
+
m["model"].summary().as_html(),
|
170
171
|
label=s_id if self.target_cat_col else None,
|
171
172
|
)
|
172
173
|
for i, (s_id, m) in enumerate(self.models.items())
|
@@ -201,11 +202,15 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
|
|
201
202
|
self.formatted_local_explanation = aggregate_local_explanations
|
202
203
|
|
203
204
|
if not self.target_cat_col:
|
204
|
-
self.formatted_global_explanation =
|
205
|
-
|
206
|
-
|
205
|
+
self.formatted_global_explanation = (
|
206
|
+
self.formatted_global_explanation.rename(
|
207
|
+
{"Series 1": self.original_target_column},
|
208
|
+
axis=1,
|
209
|
+
)
|
210
|
+
)
|
211
|
+
self.formatted_local_explanation.drop(
|
212
|
+
"Series", axis=1, inplace=True
|
207
213
|
)
|
208
|
-
self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
|
209
214
|
|
210
215
|
# Create a markdown section for the global explainability
|
211
216
|
global_explanation_section = rc.Block(
|
@@ -235,7 +240,7 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
|
|
235
240
|
local_explanation_section,
|
236
241
|
]
|
237
242
|
except Exception as e:
|
238
|
-
logger.
|
243
|
+
logger.warning(f"Failed to generate Explanations with error: {e}.")
|
239
244
|
logger.debug(f"Full Traceback: {traceback.format_exc()}")
|
240
245
|
|
241
246
|
model_description = rc.Text(
|
@@ -184,13 +184,18 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
184
184
|
"selected_model_params": model.selected_model_params_,
|
185
185
|
}
|
186
186
|
except Exception as e:
|
187
|
-
|
187
|
+
new_error = {
|
188
188
|
"model_name": self.spec.model,
|
189
189
|
"error": str(e),
|
190
190
|
"error_trace": traceback.format_exc(),
|
191
191
|
}
|
192
|
-
|
193
|
-
|
192
|
+
if s_id in self.errors_dict:
|
193
|
+
self.errors_dict[s_id]["model_fitting"] = new_error
|
194
|
+
else:
|
195
|
+
self.errors_dict[s_id] = {"model_fitting": new_error}
|
196
|
+
logger.warning(f"Encountered Error: {e}. Skipping.")
|
197
|
+
logger.warning(f"self.errors_dict[s_id]: {self.errors_dict[s_id]}")
|
198
|
+
logger.warning(traceback.format_exc())
|
194
199
|
|
195
200
|
logger.debug("===========Forecast Generated===========")
|
196
201
|
|
@@ -249,7 +254,6 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
249
254
|
self.explain_model()
|
250
255
|
|
251
256
|
global_explanation_section = None
|
252
|
-
|
253
257
|
# Convert the global explanation data to a DataFrame
|
254
258
|
global_explanation_df = pd.DataFrame(self.global_explanation)
|
255
259
|
|
@@ -258,7 +262,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
258
262
|
)
|
259
263
|
|
260
264
|
self.formatted_global_explanation.rename(
|
261
|
-
columns={
|
265
|
+
columns={
|
266
|
+
self.spec.datetime_column.name: ForecastOutputColumns.DATE
|
267
|
+
},
|
262
268
|
inplace=True,
|
263
269
|
)
|
264
270
|
|
@@ -313,7 +319,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
313
319
|
local_explanation_section,
|
314
320
|
]
|
315
321
|
except Exception as e:
|
316
|
-
logger.
|
322
|
+
logger.warning(f"Failed to generate Explanations with error: {e}.")
|
317
323
|
logger.debug(f"Full Traceback: {traceback.format_exc()}")
|
318
324
|
|
319
325
|
model_description = rc.Text(
|
@@ -463,20 +469,44 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
463
469
|
index="row", columns="Feature", values="Attribution"
|
464
470
|
)
|
465
471
|
explanations_df = explanations_df.reset_index(drop=True)
|
466
|
-
|
472
|
+
explanations_df[ForecastOutputColumns.DATE] = (
|
473
|
+
self.datasets.get_horizon_at_series(
|
474
|
+
s_id=s_id
|
475
|
+
)[self.spec.datetime_column.name].reset_index(drop=True)
|
476
|
+
)
|
467
477
|
# Store the explanations in the local_explanation dictionary
|
468
478
|
self.local_explanation[s_id] = explanations_df
|
469
479
|
|
470
480
|
self.global_explanation[s_id] = dict(
|
471
481
|
zip(
|
472
|
-
self.local_explanation[s_id]
|
473
|
-
|
482
|
+
self.local_explanation[s_id]
|
483
|
+
.drop(ForecastOutputColumns.DATE, axis=1)
|
484
|
+
.columns,
|
485
|
+
np.nanmean(
|
486
|
+
np.abs(
|
487
|
+
self.local_explanation[s_id].drop(
|
488
|
+
ForecastOutputColumns.DATE, axis=1
|
489
|
+
)
|
490
|
+
),
|
491
|
+
axis=0,
|
492
|
+
),
|
474
493
|
)
|
475
494
|
)
|
476
495
|
else:
|
477
496
|
# Fall back to the default explanation generation method
|
478
497
|
super().explain_model()
|
479
498
|
except Exception as e:
|
499
|
+
if s_id in self.errors_dict:
|
500
|
+
self.errors_dict[s_id]["explainer_error"] = str(e)
|
501
|
+
self.errors_dict[s_id]["explainer_error_trace"] = (
|
502
|
+
traceback.format_exc()
|
503
|
+
)
|
504
|
+
else:
|
505
|
+
self.errors_dict[s_id] = {
|
506
|
+
"model_name": self.spec.model,
|
507
|
+
"explainer_error": str(e),
|
508
|
+
"explainer_error_trace": traceback.format_exc(),
|
509
|
+
}
|
480
510
|
logger.warning(
|
481
511
|
f"Failed to generate explanations for series {s_id} with error: {e}."
|
482
512
|
)
|
@@ -211,8 +211,8 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
|
|
211
211
|
"error": str(e),
|
212
212
|
"error_trace": traceback.format_exc(),
|
213
213
|
}
|
214
|
-
logger.
|
215
|
-
logger.
|
214
|
+
logger.warning(f"Encountered Error: {e}. Skipping.")
|
215
|
+
logger.warning(traceback.format_exc())
|
216
216
|
|
217
217
|
logger.debug("===========Done===========")
|
218
218
|
|
@@ -242,7 +242,7 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
|
|
242
242
|
self.models.df_wide_numeric, series=s_id
|
243
243
|
),
|
244
244
|
self.datasets.list_series_ids(),
|
245
|
-
target_category_column=self.target_cat_col
|
245
|
+
target_category_column=self.target_cat_col,
|
246
246
|
)
|
247
247
|
section_1 = rc.Block(
|
248
248
|
rc.Heading("Forecast Overview", level=2),
|
@@ -260,7 +260,9 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
|
|
260
260
|
)
|
261
261
|
|
262
262
|
except KeyError:
|
263
|
-
logger.
|
263
|
+
logger.warning(
|
264
|
+
"Issue generating Model Parameters Table Section. Skipping"
|
265
|
+
)
|
264
266
|
sec2 = rc.Text("Error generating model parameters.")
|
265
267
|
|
266
268
|
section_2 = rc.Block(sec2_text, sec2)
|
@@ -268,7 +270,7 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
|
|
268
270
|
all_sections = [section_1, section_2]
|
269
271
|
|
270
272
|
if self.spec.generate_explanations:
|
271
|
-
logger.
|
273
|
+
logger.warning("Explanations not yet supported for the AutoTS Module")
|
272
274
|
|
273
275
|
# Model Description
|
274
276
|
model_description = rc.Text(
|
@@ -28,8 +28,8 @@ from ads.opctl.operator.lowcode.common.utils import (
|
|
28
28
|
merged_category_column_name,
|
29
29
|
seconds_to_datetime,
|
30
30
|
write_data,
|
31
|
+
write_json,
|
31
32
|
)
|
32
|
-
from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import TestData
|
33
33
|
from ads.opctl.operator.lowcode.forecast.utils import (
|
34
34
|
_build_metrics_df,
|
35
35
|
_build_metrics_per_horizon,
|
@@ -46,6 +46,7 @@ from ..const import (
|
|
46
46
|
AUTO_SELECT,
|
47
47
|
BACKTEST_REPORT_NAME,
|
48
48
|
SUMMARY_METRICS_HORIZON_LIMIT,
|
49
|
+
ForecastOutputColumns,
|
49
50
|
SpeedAccuracyMode,
|
50
51
|
SupportedMetrics,
|
51
52
|
SupportedModels,
|
@@ -120,7 +121,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
120
121
|
|
121
122
|
# Generate metrics
|
122
123
|
summary_metrics = None
|
123
|
-
test_data =
|
124
|
+
test_data = self.datasets.test_data
|
124
125
|
self.eval_metrics = None
|
125
126
|
|
126
127
|
if self.spec.generate_report or self.spec.generate_metrics:
|
@@ -130,14 +131,12 @@ class ForecastOperatorBaseModel(ABC):
|
|
130
131
|
{"Series 1": self.original_target_column}, axis=1, inplace=True
|
131
132
|
)
|
132
133
|
|
133
|
-
if self.
|
134
|
+
if self.datasets.test_data is not None:
|
134
135
|
try:
|
135
|
-
(
|
136
|
-
self.
|
137
|
-
|
138
|
-
|
139
|
-
) = self._test_evaluate_metrics(
|
140
|
-
elapsed_time=elapsed_time,
|
136
|
+
(self.test_eval_metrics, summary_metrics) = (
|
137
|
+
self._test_evaluate_metrics(
|
138
|
+
elapsed_time=elapsed_time,
|
139
|
+
)
|
141
140
|
)
|
142
141
|
if not self.target_cat_col:
|
143
142
|
self.test_eval_metrics.rename(
|
@@ -146,7 +145,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
146
145
|
inplace=True,
|
147
146
|
)
|
148
147
|
except Exception:
|
149
|
-
logger.
|
148
|
+
logger.warning("Unable to generate Test Metrics.")
|
150
149
|
logger.debug(f"Full Traceback: {traceback.format_exc()}")
|
151
150
|
report_sections = []
|
152
151
|
|
@@ -156,9 +155,8 @@ class ForecastOperatorBaseModel(ABC):
|
|
156
155
|
model_description,
|
157
156
|
other_sections,
|
158
157
|
) = self._generate_report()
|
159
|
-
|
160
158
|
header_section = rc.Block(
|
161
|
-
rc.Heading(
|
159
|
+
rc.Heading(self.spec.report_title, level=1),
|
162
160
|
rc.Text(
|
163
161
|
f"You selected the {self.spec.model} model.\nBased on your dataset, you could have also selected any of the models: {SupportedModels.keys()}."
|
164
162
|
),
|
@@ -361,7 +359,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
361
359
|
def _test_evaluate_metrics(self, elapsed_time=0):
|
362
360
|
total_metrics = pd.DataFrame()
|
363
361
|
summary_metrics = pd.DataFrame()
|
364
|
-
data =
|
362
|
+
data = self.datasets.test_data
|
365
363
|
|
366
364
|
# Generate y_pred and y_true for each series
|
367
365
|
for s_id in self.forecast_output.list_series_ids():
|
@@ -370,7 +368,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
370
368
|
-self.spec.horizon :
|
371
369
|
]
|
372
370
|
except KeyError as ke:
|
373
|
-
logger.
|
371
|
+
logger.warning(
|
374
372
|
f"Error Generating Metrics: Unable to find {s_id} in the test data. Error: {ke.args}"
|
375
373
|
)
|
376
374
|
y_pred = self.forecast_output.get_forecast(s_id)["forecast_value"].values[
|
@@ -398,7 +396,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
398
396
|
total_metrics = pd.concat([total_metrics, metrics_df], axis=1)
|
399
397
|
|
400
398
|
if total_metrics.empty:
|
401
|
-
return total_metrics, summary_metrics
|
399
|
+
return total_metrics, summary_metrics
|
402
400
|
|
403
401
|
summary_metrics = pd.DataFrame(
|
404
402
|
{
|
@@ -464,7 +462,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
464
462
|
]
|
465
463
|
summary_metrics = summary_metrics[new_column_order]
|
466
464
|
|
467
|
-
return total_metrics, summary_metrics
|
465
|
+
return total_metrics, summary_metrics
|
468
466
|
|
469
467
|
def _save_report(
|
470
468
|
self,
|
@@ -479,10 +477,11 @@ class ForecastOperatorBaseModel(ABC):
|
|
479
477
|
unique_output_dir = self.spec.output_directory.url
|
480
478
|
results = ForecastResults()
|
481
479
|
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
480
|
+
storage_options = (
|
481
|
+
default_signer()
|
482
|
+
if ObjectStorageDetails.is_oci_path(unique_output_dir)
|
483
|
+
else {}
|
484
|
+
)
|
486
485
|
|
487
486
|
# report-creator html report
|
488
487
|
if self.spec.generate_report:
|
@@ -513,12 +512,13 @@ class ForecastOperatorBaseModel(ABC):
|
|
513
512
|
if self.target_cat_col
|
514
513
|
else result_df.drop(DataColumns.Series, axis=1)
|
515
514
|
)
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
515
|
+
if self.spec.generate_forecast_file:
|
516
|
+
write_data(
|
517
|
+
data=result_df,
|
518
|
+
filename=os.path.join(unique_output_dir, self.spec.forecast_filename),
|
519
|
+
format="csv",
|
520
|
+
storage_options=storage_options,
|
521
|
+
)
|
522
522
|
results.set_forecast(result_df)
|
523
523
|
|
524
524
|
# metrics csv report
|
@@ -532,77 +532,81 @@ class ForecastOperatorBaseModel(ABC):
|
|
532
532
|
metrics_df_formatted = metrics_df.reset_index().rename(
|
533
533
|
{"index": "metrics", "Series 1": metrics_col_name}, axis=1
|
534
534
|
)
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
535
|
+
if self.spec.generate_metrics_file:
|
536
|
+
write_data(
|
537
|
+
data=metrics_df_formatted,
|
538
|
+
filename=os.path.join(
|
539
|
+
unique_output_dir, self.spec.metrics_filename
|
540
|
+
),
|
541
|
+
format="csv",
|
542
|
+
storage_options=storage_options,
|
543
|
+
index=False,
|
544
|
+
)
|
544
545
|
results.set_metrics(metrics_df_formatted)
|
545
546
|
else:
|
546
|
-
logger.
|
547
|
+
logger.warning(
|
547
548
|
f"Attempted to generate the {self.spec.metrics_filename} file with the training metrics, however the training metrics could not be properly generated."
|
548
549
|
)
|
549
550
|
|
550
551
|
# test_metrics csv report
|
551
|
-
if self.
|
552
|
+
if self.datasets.test_data is not None:
|
552
553
|
if test_metrics_df is not None:
|
553
554
|
test_metrics_df_formatted = test_metrics_df.reset_index().rename(
|
554
555
|
{"index": "metrics", "Series 1": metrics_col_name}, axis=1
|
555
556
|
)
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
557
|
+
if self.spec.generate_metrics_file:
|
558
|
+
write_data(
|
559
|
+
data=test_metrics_df_formatted,
|
560
|
+
filename=os.path.join(
|
561
|
+
unique_output_dir, self.spec.test_metrics_filename
|
562
|
+
),
|
563
|
+
format="csv",
|
564
|
+
storage_options=storage_options,
|
565
|
+
index=False,
|
566
|
+
)
|
565
567
|
results.set_test_metrics(test_metrics_df_formatted)
|
566
568
|
else:
|
567
|
-
logger.
|
569
|
+
logger.warning(
|
568
570
|
f"Attempted to generate the {self.spec.test_metrics_filename} file with the test metrics, however the test metrics could not be properly generated."
|
569
571
|
)
|
570
572
|
# explanations csv reports
|
571
573
|
if self.spec.generate_explanations:
|
572
574
|
try:
|
573
|
-
if self.formatted_global_explanation
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
575
|
+
if not self.formatted_global_explanation.empty:
|
576
|
+
if self.spec.generate_explanation_files:
|
577
|
+
write_data(
|
578
|
+
data=self.formatted_global_explanation,
|
579
|
+
filename=os.path.join(
|
580
|
+
unique_output_dir, self.spec.global_explanation_filename
|
581
|
+
),
|
582
|
+
format="csv",
|
583
|
+
storage_options=storage_options,
|
584
|
+
index=True,
|
585
|
+
)
|
583
586
|
results.set_global_explanations(self.formatted_global_explanation)
|
584
587
|
else:
|
585
|
-
logger.
|
588
|
+
logger.warning(
|
586
589
|
f"Attempted to generate global explanations for the {self.spec.global_explanation_filename} file, but an issue occured in formatting the explanations."
|
587
590
|
)
|
588
591
|
|
589
|
-
if self.formatted_local_explanation
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
592
|
+
if not self.formatted_local_explanation.empty:
|
593
|
+
if self.spec.generate_explanation_files:
|
594
|
+
write_data(
|
595
|
+
data=self.formatted_local_explanation,
|
596
|
+
filename=os.path.join(
|
597
|
+
unique_output_dir, self.spec.local_explanation_filename
|
598
|
+
),
|
599
|
+
format="csv",
|
600
|
+
storage_options=storage_options,
|
601
|
+
index=True,
|
602
|
+
)
|
599
603
|
results.set_local_explanations(self.formatted_local_explanation)
|
600
604
|
else:
|
601
|
-
logger.
|
605
|
+
logger.warning(
|
602
606
|
f"Attempted to generate local explanations for the {self.spec.local_explanation_filename} file, but an issue occured in formatting the explanations."
|
603
607
|
)
|
604
608
|
except AttributeError as e:
|
605
|
-
logger.
|
609
|
+
logger.warning(
|
606
610
|
"Unable to generate explanations for this model type or for this dataset."
|
607
611
|
)
|
608
612
|
logger.debug(f"Got error: {e.args}")
|
@@ -632,15 +636,12 @@ class ForecastOperatorBaseModel(ABC):
|
|
632
636
|
f"The outputs have been successfully generated and placed into the directory: {unique_output_dir}."
|
633
637
|
)
|
634
638
|
if self.errors_dict:
|
635
|
-
|
636
|
-
|
639
|
+
write_json(
|
640
|
+
json_dict=self.errors_dict,
|
637
641
|
filename=os.path.join(
|
638
642
|
unique_output_dir, self.spec.errors_dict_filename
|
639
643
|
),
|
640
|
-
format="json",
|
641
644
|
storage_options=storage_options,
|
642
|
-
index=True,
|
643
|
-
indent=4,
|
644
645
|
)
|
645
646
|
results.set_errors_dict(self.errors_dict)
|
646
647
|
else:
|
@@ -743,45 +744,62 @@ class ForecastOperatorBaseModel(ABC):
|
|
743
744
|
include_horizon=False
|
744
745
|
).items():
|
745
746
|
if s_id in self.models:
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
758
|
-
|
759
|
-
|
760
|
-
|
761
|
-
kernel_explnr = PermutationExplainer(
|
762
|
-
model=explain_predict_fn, masker=data_trimmed_encoded
|
763
|
-
)
|
764
|
-
kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded)
|
765
|
-
exp_end_time = time.time()
|
766
|
-
global_ex_time = global_ex_time + exp_end_time - exp_start_time
|
767
|
-
self.local_explainer(
|
768
|
-
kernel_explnr, series_id=s_id, datetime_col_name=datetime_col_name
|
769
|
-
)
|
770
|
-
local_ex_time = local_ex_time + time.time() - exp_end_time
|
747
|
+
try:
|
748
|
+
explain_predict_fn = self.get_explain_predict_fn(series_id=s_id)
|
749
|
+
data_trimmed = data_i.tail(
|
750
|
+
max(int(len(data_i) * ratio), 5)
|
751
|
+
).reset_index(drop=True)
|
752
|
+
data_trimmed[datetime_col_name] = data_trimmed[
|
753
|
+
datetime_col_name
|
754
|
+
].apply(lambda x: x.timestamp())
|
755
|
+
|
756
|
+
# Explainer fails when boolean columns are passed
|
757
|
+
|
758
|
+
_, data_trimmed_encoded = _label_encode_dataframe(
|
759
|
+
data_trimmed,
|
760
|
+
no_encode={datetime_col_name, self.original_target_column},
|
761
|
+
)
|
771
762
|
|
772
|
-
|
773
|
-
|
774
|
-
"No explanations generated. Ensure that additional data has been provided."
|
763
|
+
kernel_explnr = PermutationExplainer(
|
764
|
+
model=explain_predict_fn, masker=data_trimmed_encoded
|
775
765
|
)
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
766
|
+
kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded)
|
767
|
+
exp_end_time = time.time()
|
768
|
+
global_ex_time = global_ex_time + exp_end_time - exp_start_time
|
769
|
+
self.local_explainer(
|
770
|
+
kernel_explnr,
|
771
|
+
series_id=s_id,
|
772
|
+
datetime_col_name=datetime_col_name,
|
782
773
|
)
|
774
|
+
local_ex_time = local_ex_time + time.time() - exp_end_time
|
775
|
+
|
776
|
+
if not len(kernel_explnr_vals):
|
777
|
+
logger.warning(
|
778
|
+
"No explanations generated. Ensure that additional data has been provided."
|
779
|
+
)
|
780
|
+
else:
|
781
|
+
self.global_explanation[s_id] = dict(
|
782
|
+
zip(
|
783
|
+
data_trimmed.columns[1:],
|
784
|
+
np.average(
|
785
|
+
np.absolute(kernel_explnr_vals[:, 1:]), axis=0
|
786
|
+
),
|
787
|
+
)
|
788
|
+
)
|
789
|
+
except Exception as e:
|
790
|
+
if s_id in self.errors_dict:
|
791
|
+
self.errors_dict[s_id]["explainer_error"] = str(e)
|
792
|
+
self.errors_dict[s_id]["explainer_error_trace"] = (
|
793
|
+
traceback.format_exc()
|
794
|
+
)
|
795
|
+
else:
|
796
|
+
self.errors_dict[s_id] = {
|
797
|
+
"model_name": self.spec.model,
|
798
|
+
"explainer_error": str(e),
|
799
|
+
"explainer_error_trace": traceback.format_exc(),
|
800
|
+
}
|
783
801
|
else:
|
784
|
-
logger.
|
802
|
+
logger.warning(
|
785
803
|
f"Skipping explanations for {s_id}, as forecast was not generated."
|
786
804
|
)
|
787
805
|
|
@@ -816,6 +834,13 @@ class ForecastOperatorBaseModel(ABC):
|
|
816
834
|
local_kernel_explnr_df = pd.DataFrame(
|
817
835
|
local_kernel_explnr_vals, columns=data.columns
|
818
836
|
)
|
837
|
+
|
838
|
+
# Add date column to local explanation DataFrame
|
839
|
+
local_kernel_explnr_df[ForecastOutputColumns.DATE] = (
|
840
|
+
self.datasets.get_horizon_at_series(
|
841
|
+
s_id=series_id
|
842
|
+
)[self.spec.datetime_column.name].reset_index(drop=True)
|
843
|
+
)
|
819
844
|
self.local_explanation[series_id] = local_kernel_explnr_df
|
820
845
|
|
821
846
|
def get_explain_predict_fn(self, series_id, fcst_col_name="yhat"):
|