oracle-ads 2.11.6__py3-none-any.whl → 2.11.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +24 -14
- ads/aqua/base.py +0 -2
- ads/aqua/cli.py +50 -2
- ads/aqua/decorator.py +8 -0
- ads/aqua/deployment.py +37 -34
- ads/aqua/evaluation.py +106 -49
- ads/aqua/extension/base_handler.py +18 -10
- ads/aqua/extension/common_handler.py +21 -2
- ads/aqua/extension/deployment_handler.py +1 -4
- ads/aqua/extension/evaluation_handler.py +1 -2
- ads/aqua/extension/finetune_handler.py +0 -1
- ads/aqua/extension/ui_handler.py +1 -12
- ads/aqua/extension/utils.py +4 -4
- ads/aqua/finetune.py +24 -11
- ads/aqua/model.py +2 -4
- ads/aqua/utils.py +39 -23
- ads/catalog/model.py +3 -3
- ads/catalog/notebook.py +3 -3
- ads/catalog/project.py +2 -2
- ads/catalog/summary.py +2 -4
- ads/cli.py +21 -2
- ads/common/serializer.py +5 -4
- ads/common/utils.py +6 -2
- ads/config.py +1 -0
- ads/data_labeling/metadata.py +2 -2
- ads/dataset/dataset.py +3 -5
- ads/dataset/factory.py +2 -3
- ads/dataset/label_encoder.py +1 -1
- ads/dataset/sampled_dataset.py +3 -5
- ads/jobs/ads_job.py +26 -2
- ads/jobs/builders/infrastructure/dsc_job.py +20 -7
- ads/llm/serializers/runnable_parallel.py +7 -1
- ads/model/model_artifact_boilerplate/artifact_introspection_test/model_artifact_validate.py +1 -1
- ads/opctl/operator/lowcode/anomaly/README.md +1 -1
- ads/opctl/operator/lowcode/anomaly/environment.yaml +1 -1
- ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +8 -15
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +16 -10
- ads/opctl/operator/lowcode/anomaly/model/autots.py +9 -10
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +36 -39
- ads/opctl/operator/lowcode/anomaly/model/tods.py +4 -4
- ads/opctl/operator/lowcode/anomaly/operator_config.py +18 -1
- ads/opctl/operator/lowcode/anomaly/schema.yaml +16 -4
- ads/opctl/operator/lowcode/common/data.py +16 -2
- ads/opctl/operator/lowcode/common/transformations.py +48 -14
- ads/opctl/operator/lowcode/forecast/README.md +1 -1
- ads/opctl/operator/lowcode/forecast/environment.yaml +5 -4
- ads/opctl/operator/lowcode/forecast/model/arima.py +36 -29
- ads/opctl/operator/lowcode/forecast/model/automlx.py +91 -90
- ads/opctl/operator/lowcode/forecast/model/autots.py +200 -166
- ads/opctl/operator/lowcode/forecast/model/base_model.py +144 -140
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +86 -80
- ads/opctl/operator/lowcode/forecast/model/prophet.py +68 -63
- ads/opctl/operator/lowcode/forecast/operator_config.py +18 -2
- ads/opctl/operator/lowcode/forecast/schema.yaml +20 -4
- ads/opctl/operator/lowcode/forecast/utils.py +8 -4
- ads/opctl/operator/lowcode/pii/README.md +1 -1
- ads/opctl/operator/lowcode/pii/environment.yaml +1 -1
- ads/opctl/operator/lowcode/pii/model/report.py +71 -70
- ads/pipeline/ads_pipeline_step.py +11 -12
- {oracle_ads-2.11.6.dist-info → oracle_ads-2.11.8.dist-info}/METADATA +8 -7
- {oracle_ads-2.11.6.dist-info → oracle_ads-2.11.8.dist-info}/RECORD +64 -64
- {oracle_ads-2.11.6.dist-info → oracle_ads-2.11.8.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.11.6.dist-info → oracle_ads-2.11.8.dist-info}/WHEEL +0 -0
- {oracle_ads-2.11.6.dist-info → oracle_ads-2.11.8.dist-info}/entry_points.txt +0 -0
@@ -1,7 +1,7 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
2
|
# -*- coding: utf-8 -*--
|
3
3
|
|
4
|
-
# Copyright (c) 2023 Oracle and/or its affiliates.
|
4
|
+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
|
5
5
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
6
|
|
7
7
|
import json
|
@@ -88,7 +88,9 @@ class ForecastOperatorBaseModel(ABC):
|
|
88
88
|
self.formatted_local_explanation = None
|
89
89
|
|
90
90
|
self.forecast_col_name = "yhat"
|
91
|
-
self.perform_tuning = self.spec.tuning != None
|
91
|
+
self.perform_tuning = (self.spec.tuning != None) and (
|
92
|
+
self.spec.tuning.n_trials != None
|
93
|
+
)
|
92
94
|
|
93
95
|
def generate_report(self):
|
94
96
|
"""Generates the forecasting report."""
|
@@ -100,7 +102,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
100
102
|
warnings.simplefilter(action="ignore", category=UserWarning)
|
101
103
|
warnings.simplefilter(action="ignore", category=RuntimeWarning)
|
102
104
|
warnings.simplefilter(action="ignore", category=ConvergenceWarning)
|
103
|
-
import
|
105
|
+
import report_creator as rc
|
104
106
|
|
105
107
|
# load models if given
|
106
108
|
if self.spec.previous_output_dir is not None:
|
@@ -140,69 +142,58 @@ class ForecastOperatorBaseModel(ABC):
|
|
140
142
|
other_sections,
|
141
143
|
) = self._generate_report()
|
142
144
|
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
model_description,
|
153
|
-
dp.Text(
|
154
|
-
"Based on your dataset, you could have also selected "
|
155
|
-
f"any of the models: `{'`, `'.join(SupportedModels.keys())}`."
|
145
|
+
header_section = rc.Block(
|
146
|
+
rc.Heading("Forecast Report", level=1),
|
147
|
+
rc.Text(
|
148
|
+
f"You selected the {self.spec.model} model.\n{model_description}\nBased on your dataset, you could have also selected any of the models: {SupportedModels.keys()}."
|
149
|
+
),
|
150
|
+
rc.Group(
|
151
|
+
rc.Metric(
|
152
|
+
heading="Analysis was completed in ",
|
153
|
+
value=human_time_friendly(elapsed_time),
|
156
154
|
),
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
),
|
162
|
-
dp.BigNumber(
|
163
|
-
heading="Starting time index",
|
164
|
-
value=self.datasets.get_earliest_timestamp().strftime(
|
165
|
-
"%B %d, %Y"
|
166
|
-
),
|
155
|
+
rc.Metric(
|
156
|
+
heading="Starting time index",
|
157
|
+
value=self.datasets.get_earliest_timestamp().strftime(
|
158
|
+
"%B %d, %Y"
|
167
159
|
),
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
),
|
174
|
-
dp.BigNumber(
|
175
|
-
heading="Num series",
|
176
|
-
value=len(self.datasets.list_series_ids()),
|
160
|
+
),
|
161
|
+
rc.Metric(
|
162
|
+
heading="Ending time index",
|
163
|
+
value=self.datasets.get_latest_timestamp().strftime(
|
164
|
+
"%B %d, %Y"
|
177
165
|
),
|
178
|
-
columns=4,
|
179
166
|
),
|
180
|
-
|
167
|
+
rc.Metric(
|
168
|
+
heading="Num series",
|
169
|
+
value=len(self.datasets.list_series_ids()),
|
170
|
+
),
|
171
|
+
),
|
181
172
|
)
|
182
173
|
|
183
|
-
|
184
|
-
|
185
|
-
df.head(
|
186
|
-
caption="Start",
|
174
|
+
first_5_rows_blocks = [
|
175
|
+
rc.DataTable(
|
176
|
+
df.head(5),
|
187
177
|
label=s_id,
|
178
|
+
index=True,
|
188
179
|
)
|
189
180
|
for s_id, df in self.full_data_dict.items()
|
190
181
|
]
|
191
182
|
|
192
|
-
|
193
|
-
|
194
|
-
df.tail(
|
195
|
-
caption="End",
|
183
|
+
last_5_rows_blocks = [
|
184
|
+
rc.DataTable(
|
185
|
+
df.tail(5),
|
196
186
|
label=s_id,
|
187
|
+
index=True,
|
197
188
|
)
|
198
189
|
for s_id, df in self.full_data_dict.items()
|
199
190
|
]
|
200
191
|
|
201
192
|
data_summary_blocks = [
|
202
|
-
|
193
|
+
rc.DataTable(
|
203
194
|
df.describe(),
|
204
|
-
caption="Summary Statistics",
|
205
195
|
label=s_id,
|
196
|
+
index=True,
|
206
197
|
)
|
207
198
|
for s_id, df in self.full_data_dict.items()
|
208
199
|
]
|
@@ -210,44 +201,33 @@ class ForecastOperatorBaseModel(ABC):
|
|
210
201
|
series_name = merged_category_column_name(
|
211
202
|
self.spec.target_category_columns
|
212
203
|
)
|
213
|
-
series_subtext =
|
214
|
-
first_10_title =
|
215
|
-
last_10_title =
|
216
|
-
summary_title =
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
first_10_title,
|
237
|
-
first_10_rows_blocks[0],
|
238
|
-
last_10_title,
|
239
|
-
last_10_rows_blocks[0],
|
240
|
-
summary_title,
|
241
|
-
data_summary_blocks[0],
|
242
|
-
dp.Text("----"),
|
243
|
-
]
|
244
|
-
)
|
204
|
+
# series_subtext = rc.Text(f"Indexed by {series_name}")
|
205
|
+
first_10_title = rc.Heading("First 5 Rows of Data", level=3)
|
206
|
+
last_10_title = rc.Heading("Last 5 Rows of Data", level=3)
|
207
|
+
summary_title = rc.Heading("Data Summary Statistics", level=3)
|
208
|
+
|
209
|
+
data_summary_sec = rc.Block(
|
210
|
+
rc.Block(
|
211
|
+
first_10_title,
|
212
|
+
# series_subtext,
|
213
|
+
rc.Select(blocks=first_5_rows_blocks),
|
214
|
+
),
|
215
|
+
rc.Block(
|
216
|
+
last_10_title,
|
217
|
+
# series_subtext,
|
218
|
+
rc.Select(blocks=last_5_rows_blocks),
|
219
|
+
),
|
220
|
+
rc.Block(
|
221
|
+
summary_title,
|
222
|
+
# series_subtext,
|
223
|
+
rc.Select(blocks=data_summary_blocks),
|
224
|
+
),
|
225
|
+
rc.Separator(),
|
226
|
+
)
|
245
227
|
|
246
|
-
summary =
|
247
|
-
|
248
|
-
|
249
|
-
data_summary_sec,
|
250
|
-
]
|
228
|
+
summary = rc.Block(
|
229
|
+
header_section,
|
230
|
+
data_summary_sec,
|
251
231
|
)
|
252
232
|
|
253
233
|
test_metrics_sections = []
|
@@ -255,38 +235,47 @@ class ForecastOperatorBaseModel(ABC):
|
|
255
235
|
self.test_eval_metrics is not None
|
256
236
|
and not self.test_eval_metrics.empty
|
257
237
|
):
|
258
|
-
sec7_text =
|
259
|
-
sec7 =
|
238
|
+
sec7_text = rc.Heading("Test Data Evaluation Metrics", level=2)
|
239
|
+
sec7 = rc.DataTable(self.test_eval_metrics, index=True)
|
260
240
|
test_metrics_sections = test_metrics_sections + [sec7_text, sec7]
|
261
241
|
|
262
242
|
if summary_metrics is not None and not summary_metrics.empty:
|
263
|
-
sec8_text =
|
264
|
-
sec8 =
|
243
|
+
sec8_text = rc.Heading("Test Data Summary Metrics", level=2)
|
244
|
+
sec8 = rc.DataTable(summary_metrics, index=True)
|
265
245
|
test_metrics_sections = test_metrics_sections + [sec8_text, sec8]
|
266
246
|
|
267
247
|
train_metrics_sections = []
|
268
248
|
if self.eval_metrics is not None and not self.eval_metrics.empty:
|
269
|
-
sec9_text =
|
270
|
-
sec9 =
|
249
|
+
sec9_text = rc.Heading("Training Data Metrics", level=2)
|
250
|
+
sec9 = rc.DataTable(self.eval_metrics, index=True)
|
271
251
|
train_metrics_sections = [sec9_text, sec9]
|
272
252
|
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
253
|
+
forecast_plots = []
|
254
|
+
if len(self.forecast_output.list_series_ids()) > 0:
|
255
|
+
forecast_text = rc.Heading(
|
256
|
+
"Forecasted Data Overlaying Historical", level=2
|
257
|
+
)
|
258
|
+
forecast_sec = get_forecast_plots(
|
259
|
+
self.forecast_output,
|
260
|
+
horizon=self.spec.horizon,
|
261
|
+
test_data=test_data,
|
262
|
+
ci_interval_width=self.spec.confidence_interval_width,
|
263
|
+
)
|
264
|
+
if (
|
265
|
+
series_name is not None
|
266
|
+
and len(self.datasets.list_series_ids()) > 1
|
267
|
+
):
|
268
|
+
forecast_plots = [
|
269
|
+
forecast_text,
|
270
|
+
forecast_sec,
|
271
|
+
] # series_subtext,
|
272
|
+
else:
|
273
|
+
forecast_plots = [forecast_text, forecast_sec]
|
274
|
+
|
275
|
+
yaml_appendix_title = rc.Heading("Reference: YAML File", level=2)
|
276
|
+
yaml_appendix = rc.Yaml(self.config.to_dict())
|
287
277
|
report_sections = (
|
288
|
-
[
|
289
|
-
+ [summary]
|
278
|
+
[summary]
|
290
279
|
+ forecast_plots
|
291
280
|
+ other_sections
|
292
281
|
+ test_metrics_sections
|
@@ -418,7 +407,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
418
407
|
test_metrics_df: pd.DataFrame,
|
419
408
|
):
|
420
409
|
"""Saves resulting reports to the given folder."""
|
421
|
-
import
|
410
|
+
import report_creator as rc
|
422
411
|
|
423
412
|
unique_output_dir = find_output_dirname(self.spec.output_directory)
|
424
413
|
|
@@ -427,13 +416,13 @@ class ForecastOperatorBaseModel(ABC):
|
|
427
416
|
else:
|
428
417
|
storage_options = dict()
|
429
418
|
|
430
|
-
#
|
419
|
+
# report-creator html report
|
431
420
|
if self.spec.generate_report:
|
432
|
-
# datapane html report
|
433
421
|
with tempfile.TemporaryDirectory() as temp_dir:
|
434
422
|
report_local_path = os.path.join(temp_dir, "___report.html")
|
435
423
|
disable_print()
|
436
|
-
|
424
|
+
with rc.ReportCreator("My Report") as report:
|
425
|
+
report.save(rc.Block(*report_sections), report_local_path)
|
437
426
|
enable_print()
|
438
427
|
|
439
428
|
report_path = os.path.join(unique_output_dir, self.spec.report_filename)
|
@@ -557,13 +546,14 @@ class ForecastOperatorBaseModel(ABC):
|
|
557
546
|
)
|
558
547
|
if self.errors_dict:
|
559
548
|
write_data(
|
560
|
-
data=pd.DataFrame(self.errors_dict
|
549
|
+
data=pd.DataFrame.from_dict(self.errors_dict),
|
561
550
|
filename=os.path.join(
|
562
551
|
unique_output_dir, self.spec.errors_dict_filename
|
563
552
|
),
|
564
|
-
format="
|
553
|
+
format="json",
|
565
554
|
storage_options=storage_options,
|
566
555
|
index=True,
|
556
|
+
indent=4,
|
567
557
|
)
|
568
558
|
else:
|
569
559
|
logger.info(f"All modeling completed successfully.")
|
@@ -650,38 +640,47 @@ class ForecastOperatorBaseModel(ABC):
|
|
650
640
|
for s_id, data_i in self.datasets.get_data_by_series(
|
651
641
|
include_horizon=False
|
652
642
|
).items():
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
kernel_explnr = PermutationExplainer(
|
663
|
-
model=explain_predict_fn, masker=data_trimmed
|
664
|
-
)
|
665
|
-
kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed)
|
643
|
+
if s_id in self.models:
|
644
|
+
explain_predict_fn = self.get_explain_predict_fn(series_id=s_id)
|
645
|
+
data_trimmed = data_i.tail(
|
646
|
+
max(int(len(data_i) * ratio), 5)
|
647
|
+
).reset_index(drop=True)
|
648
|
+
data_trimmed[datetime_col_name] = data_trimmed[datetime_col_name].apply(
|
649
|
+
lambda x: x.timestamp()
|
650
|
+
)
|
666
651
|
|
667
|
-
|
668
|
-
global_ex_time = global_ex_time + exp_end_time - exp_start_time
|
652
|
+
# Explainer fails when boolean columns are passed
|
669
653
|
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
654
|
+
_, data_trimmed_encoded = _label_encode_dataframe(
|
655
|
+
data_trimmed,
|
656
|
+
no_encode={datetime_col_name, self.original_target_column},
|
657
|
+
)
|
674
658
|
|
675
|
-
|
676
|
-
|
677
|
-
f"No explanations generated. Ensure that additional data has been provided."
|
659
|
+
kernel_explnr = PermutationExplainer(
|
660
|
+
model=explain_predict_fn, masker=data_trimmed_encoded
|
678
661
|
)
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
662
|
+
kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded)
|
663
|
+
exp_end_time = time.time()
|
664
|
+
global_ex_time = global_ex_time + exp_end_time - exp_start_time
|
665
|
+
self.local_explainer(
|
666
|
+
kernel_explnr, series_id=s_id, datetime_col_name=datetime_col_name
|
667
|
+
)
|
668
|
+
local_ex_time = local_ex_time + time.time() - exp_end_time
|
669
|
+
|
670
|
+
if not len(kernel_explnr_vals):
|
671
|
+
logger.warn(
|
672
|
+
f"No explanations generated. Ensure that additional data has been provided."
|
673
|
+
)
|
674
|
+
else:
|
675
|
+
self.global_explanation[s_id] = dict(
|
676
|
+
zip(
|
677
|
+
data_trimmed.columns[1:],
|
678
|
+
np.average(np.absolute(kernel_explnr_vals[:, 1:]), axis=0),
|
679
|
+
)
|
684
680
|
)
|
681
|
+
else:
|
682
|
+
logger.warn(
|
683
|
+
f"Skipping explanations for {s_id}, as forecast was not generated."
|
685
684
|
)
|
686
685
|
|
687
686
|
logger.info(
|
@@ -700,9 +699,14 @@ class ForecastOperatorBaseModel(ABC):
|
|
700
699
|
kernel_explainer: The kernel explainer object to use for generating explanations.
|
701
700
|
"""
|
702
701
|
data = self.datasets.get_horizon_at_series(s_id=series_id)
|
703
|
-
|
702
|
+
# columns that were dropped in train_model in arima, should be dropped here as well
|
704
703
|
data[datetime_col_name] = datetime_to_seconds(data[datetime_col_name])
|
705
704
|
data = data.reset_index(drop=True)
|
705
|
+
|
706
|
+
# Explainer fails when boolean columns are passed
|
707
|
+
_, data = _label_encode_dataframe(
|
708
|
+
data, no_encode={datetime_col_name, self.original_target_column}
|
709
|
+
)
|
706
710
|
# Generate local SHAP values using the kernel explainer
|
707
711
|
local_kernel_explnr_vals = kernel_explainer.shap_values(data)
|
708
712
|
|
@@ -1,7 +1,7 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
2
|
# -*- coding: utf-8 -*--
|
3
3
|
|
4
|
-
# Copyright (c) 2023 Oracle and/or its affiliates.
|
4
|
+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
|
5
5
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
6
|
|
7
7
|
import numpy as np
|
@@ -41,20 +41,20 @@ from .forecast_datasets import ForecastDatasets, ForecastOutput
|
|
41
41
|
import traceback
|
42
42
|
|
43
43
|
|
44
|
-
def _get_np_metrics_dict(selected_metric):
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
44
|
+
# def _get_np_metrics_dict(selected_metric):
|
45
|
+
# metric_translation = {
|
46
|
+
# "mape": MeanAbsolutePercentageError,
|
47
|
+
# "smape": SymmetricMeanAbsolutePercentageError,
|
48
|
+
# "mae": MeanAbsoluteError,
|
49
|
+
# "r2": R2Score,
|
50
|
+
# "rmse": MeanSquaredError,
|
51
|
+
# }
|
52
|
+
# if selected_metric not in metric_translation.keys():
|
53
|
+
# logger.warn(
|
54
|
+
# f"Could not find the metric: {selected_metric} in torchmetrics. Defaulting to MAE and RMSE"
|
55
|
+
# )
|
56
|
+
# return {"MAE": MeanAbsoluteError(), "RMSE": MeanSquaredError()}
|
57
|
+
# return {selected_metric: metric_translation[selected_metric]()}
|
58
58
|
|
59
59
|
|
60
60
|
@runtime_dependency(
|
@@ -70,7 +70,7 @@ def _fit_model(data, params, additional_regressors, select_metric):
|
|
70
70
|
disable_print()
|
71
71
|
|
72
72
|
m = NeuralProphet(**params)
|
73
|
-
m.metrics = _get_np_metrics_dict(select_metric)
|
73
|
+
# m.metrics = _get_np_metrics_dict(select_metric)
|
74
74
|
for add_reg in additional_regressors:
|
75
75
|
m = m.add_future_regressor(name=add_reg)
|
76
76
|
m.fit(df=data)
|
@@ -120,11 +120,11 @@ class NeuralProphetOperatorModel(ForecastOperatorBaseModel):
|
|
120
120
|
data = self.preprocess(df, s_id)
|
121
121
|
data_i = self.drop_horizon(data)
|
122
122
|
|
123
|
-
if self.loaded_models is not None:
|
123
|
+
if self.loaded_models is not None and s_id in self.loaded_models:
|
124
124
|
model = self.loaded_models[s_id]
|
125
125
|
accepted_regressors_config = model.config_regressors or dict()
|
126
126
|
self.accepted_regressors[s_id] = list(accepted_regressors_config.keys())
|
127
|
-
if self.loaded_trainers is not None:
|
127
|
+
if self.loaded_trainers is not None and s_id in self.loaded_trainers:
|
128
128
|
model.trainer = self.loaded_trainers[s_id]
|
129
129
|
else:
|
130
130
|
if self.perform_tuning:
|
@@ -135,7 +135,8 @@ class NeuralProphetOperatorModel(ForecastOperatorBaseModel):
|
|
135
135
|
data=data_i,
|
136
136
|
params=model_kwargs,
|
137
137
|
additional_regressors=self.additional_regressors,
|
138
|
-
select_metric=
|
138
|
+
select_metric=None,
|
139
|
+
# select_metric=self.spec.metric,
|
139
140
|
)
|
140
141
|
|
141
142
|
logger.debug(
|
@@ -209,6 +210,7 @@ class NeuralProphetOperatorModel(ForecastOperatorBaseModel):
|
|
209
210
|
logger.debug("===========Done===========")
|
210
211
|
except Exception as e:
|
211
212
|
self.errors_dict[s_id] = {"model_name": self.spec.model, "error": str(e)}
|
213
|
+
raise e
|
212
214
|
|
213
215
|
def _build_model(self) -> pd.DataFrame:
|
214
216
|
full_data_dict = self.datasets.get_data_by_series()
|
@@ -309,92 +311,96 @@ class NeuralProphetOperatorModel(ForecastOperatorBaseModel):
|
|
309
311
|
return selected_params
|
310
312
|
|
311
313
|
def _generate_report(self):
|
312
|
-
import
|
314
|
+
import report_creator as rc
|
313
315
|
|
316
|
+
series_ids = self.models.keys()
|
314
317
|
all_sections = []
|
318
|
+
if len(series_ids) > 0:
|
319
|
+
try:
|
320
|
+
sec1 = _select_plot_list(
|
321
|
+
lambda s_id: self.models[s_id].plot(self.outputs[s_id]),
|
322
|
+
series_ids=series_ids,
|
323
|
+
)
|
324
|
+
section_1 = rc.Block(
|
325
|
+
rc.Heading("Forecast Overview", level=2),
|
326
|
+
rc.Text(
|
327
|
+
"These plots show your forecast in the context of historical data."
|
328
|
+
),
|
329
|
+
sec1,
|
330
|
+
)
|
331
|
+
all_sections = all_sections + [section_1]
|
332
|
+
except Exception as e:
|
333
|
+
logger.debug(f"Failed to plot with exception: {e.args}")
|
315
334
|
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
logger.debug(f"Failed to plot with exception: {e.args}")
|
328
|
-
|
329
|
-
try:
|
330
|
-
sec2_text = dp.Text(f"## Forecast Broken Down by Trend Component")
|
331
|
-
sec2 = _select_plot_list(
|
332
|
-
lambda s_id: self.models[s_id].plot_components(self.outputs[s_id]),
|
333
|
-
series_ids=self.datasets.list_series_ids(),
|
334
|
-
)
|
335
|
-
all_sections = all_sections + [sec2_text, sec2]
|
336
|
-
except Exception as e:
|
337
|
-
logger.debug(f"Failed to plot with exception: {e.args}")
|
335
|
+
try:
|
336
|
+
sec2 = _select_plot_list(
|
337
|
+
lambda s_id: self.models[s_id].plot_components(self.outputs[s_id]),
|
338
|
+
series_ids=series_ids,
|
339
|
+
)
|
340
|
+
section_2 = rc.Block(
|
341
|
+
rc.Heading("Forecast Broken Down by Trend Component", level=2), sec2
|
342
|
+
)
|
343
|
+
all_sections = all_sections + [section_2]
|
344
|
+
except Exception as e:
|
345
|
+
logger.debug(f"Failed to plot with exception: {e.args}")
|
338
346
|
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
series_ids=self.datasets.list_series_ids(),
|
344
|
-
)
|
345
|
-
all_sections = all_sections + [sec3_text, sec3]
|
346
|
-
except Exception as e:
|
347
|
-
logger.debug(f"Failed to plot with exception: {e.args}")
|
348
|
-
|
349
|
-
sec5_text = dp.Text(f"## Neural Prophet Model Parameters")
|
350
|
-
model_states = []
|
351
|
-
for i, (s_id, m) in enumerate(self.models.items()):
|
352
|
-
model_states.append(
|
353
|
-
pd.Series(
|
354
|
-
m.state_dict(),
|
355
|
-
index=m.state_dict().keys(),
|
356
|
-
name=s_id,
|
347
|
+
try:
|
348
|
+
sec3 = _select_plot_list(
|
349
|
+
lambda s_id: self.models[s_id].plot_parameters(),
|
350
|
+
series_ids=series_ids,
|
357
351
|
)
|
358
|
-
|
359
|
-
|
360
|
-
|
352
|
+
section_3 = rc.Block(
|
353
|
+
rc.Heading("Forecast Parameter Plots", level=2), sec3
|
354
|
+
)
|
355
|
+
all_sections = all_sections + [section_3]
|
356
|
+
except Exception as e:
|
357
|
+
logger.debug(f"Failed to plot with exception: {e.args}")
|
358
|
+
|
359
|
+
sec5_text = rc.Heading("Neural Prophet Model Parameters", level=2)
|
360
|
+
model_states = []
|
361
|
+
for i, (s_id, m) in enumerate(self.models.items()):
|
362
|
+
model_states.append(
|
363
|
+
pd.Series(
|
364
|
+
m.state_dict(),
|
365
|
+
index=m.state_dict().keys(),
|
366
|
+
name=s_id,
|
367
|
+
)
|
368
|
+
)
|
369
|
+
all_model_states = pd.concat(model_states, axis=1)
|
370
|
+
sec5 = rc.DataTable(all_model_states, index=True)
|
361
371
|
|
362
|
-
|
372
|
+
all_sections = all_sections + [sec5_text, sec5]
|
363
373
|
|
364
374
|
if self.spec.generate_explanations:
|
365
375
|
try:
|
366
376
|
# If the key is present, call the "explain_model" method
|
367
377
|
self.explain_model()
|
368
378
|
|
369
|
-
# Create a markdown text block for the global explanation section
|
370
|
-
global_explanation_text = dp.Text(
|
371
|
-
f"## Global Explanation of Models \n "
|
372
|
-
"The following tables provide the feature attribution for the global explainability."
|
373
|
-
)
|
374
|
-
|
375
379
|
# Create a markdown section for the global explainability
|
376
|
-
global_explanation_section =
|
377
|
-
"
|
378
|
-
|
380
|
+
global_explanation_section = rc.Block(
|
381
|
+
rc.Heading("Global Explainability", level=2),
|
382
|
+
rc.Text(
|
383
|
+
"The following tables provide the feature attribution for the global explainability."
|
384
|
+
),
|
385
|
+
rc.DataTable(self.formatted_global_explanation, index=True),
|
379
386
|
)
|
380
387
|
|
381
|
-
local_explanation_text = dp.Text(f"## Local Explanation of Models \n ")
|
382
388
|
blocks = [
|
383
|
-
|
389
|
+
rc.DataTable(
|
384
390
|
local_ex_df.drop("Series", axis=1),
|
385
391
|
label=s_id,
|
392
|
+
index=True,
|
386
393
|
)
|
387
394
|
for s_id, local_ex_df in self.local_explanation.items()
|
388
395
|
]
|
389
|
-
local_explanation_section = (
|
390
|
-
|
396
|
+
local_explanation_section = rc.Block(
|
397
|
+
rc.Heading("Local Explanation of Models", level=2),
|
398
|
+
rc.Select(blocks=blocks),
|
391
399
|
)
|
392
400
|
|
393
401
|
# Append the global explanation text and section to the "all_sections" list
|
394
402
|
all_sections = all_sections + [
|
395
|
-
global_explanation_text,
|
396
403
|
global_explanation_section,
|
397
|
-
local_explanation_text,
|
398
404
|
local_explanation_section,
|
399
405
|
]
|
400
406
|
except Exception as e:
|
@@ -402,7 +408,7 @@ class NeuralProphetOperatorModel(ForecastOperatorBaseModel):
|
|
402
408
|
logger.warn(f"Failed to generate Explanations with error: {e}.")
|
403
409
|
logger.debug(f"Full Traceback: {traceback.format_exc()}")
|
404
410
|
|
405
|
-
model_description =
|
411
|
+
model_description = rc.Text(
|
406
412
|
"NeuralProphet is an easy to learn framework for interpretable time "
|
407
413
|
"series forecasting. NeuralProphet is built on PyTorch and combines "
|
408
414
|
"Neural Network and traditional time-series algorithms, inspired by "
|