oracle-ads 2.12.9__py3-none-any.whl → 2.12.10rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +4 -4
- ads/aqua/common/enums.py +3 -0
- ads/aqua/common/utils.py +62 -2
- ads/aqua/data.py +2 -19
- ads/aqua/extension/finetune_handler.py +8 -14
- ads/aqua/extension/model_handler.py +19 -2
- ads/aqua/finetuning/constants.py +5 -2
- ads/aqua/finetuning/entities.py +64 -17
- ads/aqua/finetuning/finetuning.py +38 -54
- ads/aqua/model/entities.py +2 -1
- ads/aqua/model/model.py +61 -23
- ads/common/auth.py +9 -9
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/{client_v02.py → v02/client.py} +23 -10
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +12 -11
- ads/model/__init__.py +11 -13
- ads/model/artifact.py +47 -8
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/generic_model.py +26 -24
- ads/model/model_metadata.py +8 -7
- ads/opctl/config/merger.py +13 -14
- ads/opctl/operator/common/operator_config.py +4 -4
- ads/opctl/operator/lowcode/common/transformations.py +12 -5
- ads/opctl/operator/lowcode/common/utils.py +11 -5
- ads/opctl/operator/lowcode/forecast/const.py +2 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +19 -13
- ads/opctl/operator/lowcode/forecast/model/automlx.py +129 -36
- ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +61 -14
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +10 -3
- ads/opctl/operator/lowcode/forecast/model/prophet.py +25 -18
- ads/opctl/operator/lowcode/forecast/schema.yaml +13 -0
- ads/opctl/operator/lowcode/forecast/utils.py +4 -3
- ads/telemetry/base.py +18 -11
- ads/telemetry/client.py +33 -13
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/METADATA +7 -8
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/RECORD +60 -39
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/WHEEL +0 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/entry_points.txt +0 -0
@@ -1,7 +1,6 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8; -*-
|
3
2
|
|
4
|
-
# Copyright (c) 2023 Oracle and/or its affiliates.
|
3
|
+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
6
|
|
@@ -11,15 +10,16 @@ from dataclasses import dataclass
|
|
11
10
|
from typing import Any, Dict, List
|
12
11
|
|
13
12
|
from ads.common.serializer import DataClassSerializable
|
14
|
-
|
15
|
-
from ads.opctl.operator.common.utils import OperatorValidator
|
16
13
|
from ads.opctl.operator.common.errors import InvalidParameterError
|
14
|
+
from ads.opctl.operator.common.utils import OperatorValidator
|
15
|
+
|
17
16
|
|
18
17
|
@dataclass(repr=True)
|
19
18
|
class InputData(DataClassSerializable):
|
20
19
|
"""Class representing operator specification input data details."""
|
21
20
|
|
22
21
|
connect_args: Dict = None
|
22
|
+
data: Dict = None
|
23
23
|
format: str = None
|
24
24
|
columns: List[str] = None
|
25
25
|
url: str = None
|
@@ -1,10 +1,11 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
2
|
|
3
|
-
# Copyright (c) 2023,
|
3
|
+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
|
4
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
5
|
|
6
6
|
from abc import ABC
|
7
7
|
|
8
|
+
import numpy as np
|
8
9
|
import pandas as pd
|
9
10
|
|
10
11
|
from ads.opctl import logger
|
@@ -209,18 +210,24 @@ class Transformations(ABC):
|
|
209
210
|
-------
|
210
211
|
A new Pandas DataFrame with treated outliears.
|
211
212
|
"""
|
212
|
-
df
|
213
|
+
return df
|
214
|
+
df["__z_score"] = (
|
213
215
|
df[self.target_column_name]
|
214
216
|
.groupby(DataColumns.Series)
|
215
217
|
.transform(lambda x: (x - x.mean()) / x.std())
|
216
218
|
)
|
217
|
-
outliers_mask = df["
|
219
|
+
outliers_mask = df["__z_score"].abs() > 3
|
220
|
+
|
221
|
+
if df[self.target_column_name].dtype == np.int:
|
222
|
+
df[self.target_column_name].astype(np.float)
|
223
|
+
|
218
224
|
df.loc[outliers_mask, self.target_column_name] = (
|
219
225
|
df[self.target_column_name]
|
220
226
|
.groupby(DataColumns.Series)
|
221
|
-
.transform(lambda x:
|
227
|
+
.transform(lambda x: np.median(x))
|
222
228
|
)
|
223
|
-
|
229
|
+
df_ret = df.drop("__z_score", axis=1)
|
230
|
+
return df_ret
|
224
231
|
|
225
232
|
def _check_historical_dataset(self, df):
|
226
233
|
expected_names = [self.target_column_name, self.dt_column_name] + (
|
@@ -1,6 +1,6 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
2
|
|
3
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
3
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
4
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
5
|
|
6
6
|
import logging
|
@@ -40,6 +40,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
|
|
40
40
|
if data_spec is None:
|
41
41
|
raise InvalidParameterError("No details provided for this data source.")
|
42
42
|
filename = data_spec.url
|
43
|
+
data = data_spec.data
|
43
44
|
format = data_spec.format
|
44
45
|
columns = data_spec.columns
|
45
46
|
connect_args = data_spec.connect_args
|
@@ -51,9 +52,12 @@ def load_data(data_spec, storage_options=None, **kwargs):
|
|
51
52
|
default_signer() if ObjectStorageDetails.is_oci_path(filename) else {}
|
52
53
|
)
|
53
54
|
if vault_secret_id is not None and connect_args is None:
|
54
|
-
connect_args =
|
55
|
+
connect_args = {}
|
55
56
|
|
56
|
-
if
|
57
|
+
if data is not None:
|
58
|
+
if format == "spark":
|
59
|
+
data = data.toPandas()
|
60
|
+
elif filename is not None:
|
57
61
|
if not format:
|
58
62
|
_, format = os.path.splitext(filename)
|
59
63
|
format = format[1:]
|
@@ -98,7 +102,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
|
|
98
102
|
except Exception as e:
|
99
103
|
raise Exception(
|
100
104
|
f"Could not retrieve database credentials from vault {vault_secret_id}: {e}"
|
101
|
-
)
|
105
|
+
) from e
|
102
106
|
|
103
107
|
con = oracledb.connect(**connect_args)
|
104
108
|
if table_name is not None:
|
@@ -122,6 +126,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
|
|
122
126
|
|
123
127
|
|
124
128
|
def write_data(data, filename, format, storage_options, index=False, **kwargs):
|
129
|
+
disable_print()
|
125
130
|
if not format:
|
126
131
|
_, format = os.path.splitext(filename)
|
127
132
|
format = format[1:]
|
@@ -130,7 +135,8 @@ def write_data(data, filename, format, storage_options, index=False, **kwargs):
|
|
130
135
|
return call_pandas_fsspec(
|
131
136
|
write_fn, filename, index=index, storage_options=storage_options, **kwargs
|
132
137
|
)
|
133
|
-
|
138
|
+
enable_print()
|
139
|
+
raise InvalidParameterError(
|
134
140
|
f"The format {format} is not currently supported for writing data. Please change the format parameter for the data output: {filename} ."
|
135
141
|
)
|
136
142
|
|
@@ -27,10 +27,12 @@ class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
|
|
27
27
|
HIGH_ACCURACY = "HIGH_ACCURACY"
|
28
28
|
BALANCED = "BALANCED"
|
29
29
|
FAST_APPROXIMATE = "FAST_APPROXIMATE"
|
30
|
+
AUTOMLX = "AUTOMLX"
|
30
31
|
ratio = {}
|
31
32
|
ratio[HIGH_ACCURACY] = 1 # 100 % data used for generating explanations
|
32
33
|
ratio[BALANCED] = 0.5 # 50 % data used for generating explanations
|
33
34
|
ratio[FAST_APPROXIMATE] = 0 # constant
|
35
|
+
ratio[AUTOMLX] = 0 # constant
|
34
36
|
|
35
37
|
|
36
38
|
class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
|
@@ -164,11 +164,11 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
|
|
164
164
|
blocks = [
|
165
165
|
rc.Html(
|
166
166
|
m.summary().as_html(),
|
167
|
-
label=s_id,
|
167
|
+
label=s_id if self.target_cat_col else None,
|
168
168
|
)
|
169
169
|
for i, (s_id, m) in enumerate(self.models.items())
|
170
170
|
]
|
171
|
-
sec5 = rc.Select(blocks=blocks)
|
171
|
+
sec5 = rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
|
172
172
|
all_sections = [sec5_text, sec5]
|
173
173
|
|
174
174
|
if self.spec.generate_explanations:
|
@@ -188,6 +188,21 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
|
|
188
188
|
axis=1,
|
189
189
|
)
|
190
190
|
)
|
191
|
+
aggregate_local_explanations = pd.DataFrame()
|
192
|
+
for s_id, local_ex_df in self.local_explanation.items():
|
193
|
+
local_ex_df_copy = local_ex_df.copy()
|
194
|
+
local_ex_df_copy["Series"] = s_id
|
195
|
+
aggregate_local_explanations = pd.concat(
|
196
|
+
[aggregate_local_explanations, local_ex_df_copy], axis=0
|
197
|
+
)
|
198
|
+
self.formatted_local_explanation = aggregate_local_explanations
|
199
|
+
|
200
|
+
if not self.target_cat_col:
|
201
|
+
self.formatted_global_explanation = self.formatted_global_explanation.rename(
|
202
|
+
{"Series 1": self.original_target_column},
|
203
|
+
axis=1,
|
204
|
+
)
|
205
|
+
self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
|
191
206
|
|
192
207
|
# Create a markdown section for the global explainability
|
193
208
|
global_explanation_section = rc.Block(
|
@@ -198,26 +213,17 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
|
|
198
213
|
rc.DataTable(self.formatted_global_explanation, index=True),
|
199
214
|
)
|
200
215
|
|
201
|
-
aggregate_local_explanations = pd.DataFrame()
|
202
|
-
for s_id, local_ex_df in self.local_explanation.items():
|
203
|
-
local_ex_df_copy = local_ex_df.copy()
|
204
|
-
local_ex_df_copy["Series"] = s_id
|
205
|
-
aggregate_local_explanations = pd.concat(
|
206
|
-
[aggregate_local_explanations, local_ex_df_copy], axis=0
|
207
|
-
)
|
208
|
-
self.formatted_local_explanation = aggregate_local_explanations
|
209
|
-
|
210
216
|
blocks = [
|
211
217
|
rc.DataTable(
|
212
218
|
local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
|
213
|
-
label=s_id,
|
219
|
+
label=s_id if self.target_cat_col else None,
|
214
220
|
index=True,
|
215
221
|
)
|
216
222
|
for s_id, local_ex_df in self.local_explanation.items()
|
217
223
|
]
|
218
224
|
local_explanation_section = rc.Block(
|
219
225
|
rc.Heading("Local Explanation of Models", level=2),
|
220
|
-
rc.Select(blocks=blocks),
|
226
|
+
rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
|
221
227
|
)
|
222
228
|
|
223
229
|
# Append the global explanation text and section to the "all_sections" list
|
@@ -17,6 +17,7 @@ from ads.opctl.operator.lowcode.common.utils import (
|
|
17
17
|
from ads.opctl.operator.lowcode.forecast.const import (
|
18
18
|
AUTOMLX_METRIC_MAP,
|
19
19
|
ForecastOutputColumns,
|
20
|
+
SpeedAccuracyMode,
|
20
21
|
SupportedModels,
|
21
22
|
)
|
22
23
|
from ads.opctl.operator.lowcode.forecast.utils import _label_encode_dataframe
|
@@ -81,22 +82,6 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
81
82
|
|
82
83
|
from automlx import Pipeline, init
|
83
84
|
|
84
|
-
cpu_count = os.cpu_count()
|
85
|
-
try:
|
86
|
-
if cpu_count < 4:
|
87
|
-
engine = "local"
|
88
|
-
engine_opts = None
|
89
|
-
else:
|
90
|
-
engine = "ray"
|
91
|
-
engine_opts = ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
|
92
|
-
init(
|
93
|
-
engine=engine,
|
94
|
-
engine_opts=engine_opts,
|
95
|
-
loglevel=logging.CRITICAL,
|
96
|
-
)
|
97
|
-
except Exception as e:
|
98
|
-
logger.info(f"Error. Has Ray already been initialized? Skipping. {e}")
|
99
|
-
|
100
85
|
full_data_dict = self.datasets.get_data_by_series()
|
101
86
|
|
102
87
|
self.models = {}
|
@@ -112,6 +97,26 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
112
97
|
# Clean up kwargs for pass through
|
113
98
|
model_kwargs_cleaned, time_budget = self.set_kwargs()
|
114
99
|
|
100
|
+
cpu_count = os.cpu_count()
|
101
|
+
try:
|
102
|
+
engine_type = model_kwargs_cleaned.pop(
|
103
|
+
"engine", "local" if cpu_count <= 4 else "ray"
|
104
|
+
)
|
105
|
+
engine_opts = (
|
106
|
+
None
|
107
|
+
if engine_type == "local"
|
108
|
+
else ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
|
109
|
+
)
|
110
|
+
init(
|
111
|
+
engine=engine_type,
|
112
|
+
engine_opts=engine_opts,
|
113
|
+
loglevel=logging.CRITICAL,
|
114
|
+
)
|
115
|
+
except Exception as e:
|
116
|
+
logger.info(
|
117
|
+
f"Error initializing automlx. Has Ray already been initialized? Skipping. {e}"
|
118
|
+
)
|
119
|
+
|
115
120
|
for s_id, df in full_data_dict.items():
|
116
121
|
try:
|
117
122
|
logger.debug(f"Running automlx on series {s_id}")
|
@@ -223,6 +228,8 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
223
228
|
selected_models.items(), columns=["series_id", "best_selected_model"]
|
224
229
|
)
|
225
230
|
selected_df = selected_models_df["best_selected_model"].apply(pd.Series)
|
231
|
+
if not self.target_cat_col:
|
232
|
+
selected_df = selected_df.drop("series_id", axis=1)
|
226
233
|
selected_models_section = rc.Block(
|
227
234
|
rc.Heading("Selected Models Overview", level=2),
|
228
235
|
rc.Text(
|
@@ -239,27 +246,18 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
239
246
|
# If the key is present, call the "explain_model" method
|
240
247
|
self.explain_model()
|
241
248
|
|
242
|
-
|
243
|
-
|
249
|
+
global_explanation_section = None
|
250
|
+
if self.spec.explanations_accuracy_mode != SpeedAccuracyMode.AUTOMLX:
|
251
|
+
# Convert the global explanation data to a DataFrame
|
252
|
+
global_explanation_df = pd.DataFrame(self.global_explanation)
|
244
253
|
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
self.formatted_global_explanation.rename(
|
254
|
+
self.formatted_global_explanation = (
|
255
|
+
global_explanation_df / global_explanation_df.sum(axis=0) * 100
|
256
|
+
)
|
257
|
+
self.formatted_global_explanation = self.formatted_global_explanation.rename(
|
250
258
|
{self.spec.datetime_column.name: ForecastOutputColumns.DATE},
|
251
259
|
axis=1,
|
252
260
|
)
|
253
|
-
)
|
254
|
-
|
255
|
-
# Create a markdown section for the global explainability
|
256
|
-
global_explanation_section = rc.Block(
|
257
|
-
rc.Heading("Global Explanation of Models", level=2),
|
258
|
-
rc.Text(
|
259
|
-
"The following tables provide the feature attribution for the global explainability."
|
260
|
-
),
|
261
|
-
rc.DataTable(self.formatted_global_explanation, index=True),
|
262
|
-
)
|
263
261
|
|
264
262
|
aggregate_local_explanations = pd.DataFrame()
|
265
263
|
for s_id, local_ex_df in self.local_explanation.items():
|
@@ -270,22 +268,41 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
270
268
|
)
|
271
269
|
self.formatted_local_explanation = aggregate_local_explanations
|
272
270
|
|
271
|
+
if not self.target_cat_col:
|
272
|
+
self.formatted_global_explanation = self.formatted_global_explanation.rename(
|
273
|
+
{"Series 1": self.original_target_column},
|
274
|
+
axis=1,
|
275
|
+
)
|
276
|
+
self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
|
277
|
+
|
278
|
+
# Create a markdown section for the global explainability
|
279
|
+
global_explanation_section = rc.Block(
|
280
|
+
rc.Heading("Global Explanation of Models", level=2),
|
281
|
+
rc.Text(
|
282
|
+
"The following tables provide the feature attribution for the global explainability."
|
283
|
+
),
|
284
|
+
rc.DataTable(self.formatted_global_explanation, index=True),
|
285
|
+
)
|
286
|
+
|
273
287
|
blocks = [
|
274
288
|
rc.DataTable(
|
275
289
|
local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
|
276
|
-
label=s_id,
|
290
|
+
label=s_id if self.target_cat_col else None,
|
277
291
|
index=True,
|
278
292
|
)
|
279
293
|
for s_id, local_ex_df in self.local_explanation.items()
|
280
294
|
]
|
281
295
|
local_explanation_section = rc.Block(
|
282
296
|
rc.Heading("Local Explanation of Models", level=2),
|
283
|
-
rc.Select(blocks=blocks),
|
297
|
+
rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
|
284
298
|
)
|
285
299
|
|
286
300
|
# Append the global explanation text and section to the "other_sections" list
|
301
|
+
if global_explanation_section:
|
302
|
+
other_sections.append(global_explanation_section)
|
303
|
+
|
304
|
+
# Append the local explanation text and section to the "other_sections" list
|
287
305
|
other_sections = other_sections + [
|
288
|
-
global_explanation_section,
|
289
306
|
local_explanation_section,
|
290
307
|
]
|
291
308
|
except Exception as e:
|
@@ -366,3 +383,79 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
366
383
|
return self.models.get(self.series_id).forecast(
|
367
384
|
X=data_temp, periods=data_temp.shape[0]
|
368
385
|
)[self.series_id]
|
386
|
+
|
387
|
+
@runtime_dependency(
|
388
|
+
module="automlx",
|
389
|
+
err_msg=(
|
390
|
+
"Please run `python3 -m pip install automlx` to install the required dependencies for model explanation."
|
391
|
+
),
|
392
|
+
)
|
393
|
+
def explain_model(self):
|
394
|
+
"""
|
395
|
+
Generates explanations for the model using the AutoMLx library.
|
396
|
+
|
397
|
+
Parameters
|
398
|
+
----------
|
399
|
+
None
|
400
|
+
|
401
|
+
Returns
|
402
|
+
-------
|
403
|
+
None
|
404
|
+
|
405
|
+
Notes
|
406
|
+
-----
|
407
|
+
This function works by generating local explanations for each series in the dataset.
|
408
|
+
It uses the ``MLExplainer`` class from the AutoMLx library to generate feature attributions
|
409
|
+
for each series. The feature attributions are then stored in the ``self.local_explanation`` dictionary.
|
410
|
+
|
411
|
+
If the accuracy mode is set to AutoMLX, it uses the AutoMLx library to generate explanations.
|
412
|
+
Otherwise, it falls back to the default explanation generation method.
|
413
|
+
"""
|
414
|
+
import automlx
|
415
|
+
|
416
|
+
# Loop through each series in the dataset
|
417
|
+
for s_id, data_i in self.datasets.get_data_by_series(
|
418
|
+
include_horizon=False
|
419
|
+
).items():
|
420
|
+
try:
|
421
|
+
if self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
|
422
|
+
# Use the MLExplainer class from AutoMLx to generate explanations
|
423
|
+
explainer = automlx.MLExplainer(
|
424
|
+
self.models[s_id],
|
425
|
+
self.datasets.additional_data.get_data_for_series(series_id=s_id)
|
426
|
+
.drop(self.spec.datetime_column.name, axis=1)
|
427
|
+
.head(-self.spec.horizon)
|
428
|
+
if self.spec.additional_data
|
429
|
+
else None,
|
430
|
+
pd.DataFrame(data_i[self.spec.target_column]),
|
431
|
+
task="forecasting",
|
432
|
+
)
|
433
|
+
|
434
|
+
# Generate explanations for the forecast
|
435
|
+
explanations = explainer.explain_prediction(
|
436
|
+
X=self.datasets.additional_data.get_data_for_series(series_id=s_id)
|
437
|
+
.drop(self.spec.datetime_column.name, axis=1)
|
438
|
+
.tail(self.spec.horizon)
|
439
|
+
if self.spec.additional_data
|
440
|
+
else None,
|
441
|
+
forecast_timepoints=list(range(self.spec.horizon + 1)),
|
442
|
+
)
|
443
|
+
|
444
|
+
# Convert the explanations to a DataFrame
|
445
|
+
explanations_df = pd.concat(
|
446
|
+
[exp.to_dataframe() for exp in explanations]
|
447
|
+
)
|
448
|
+
explanations_df["row"] = explanations_df.groupby("Feature").cumcount()
|
449
|
+
explanations_df = explanations_df.pivot(
|
450
|
+
index="row", columns="Feature", values="Attribution"
|
451
|
+
)
|
452
|
+
explanations_df = explanations_df.reset_index(drop=True)
|
453
|
+
|
454
|
+
# Store the explanations in the local_explanation dictionary
|
455
|
+
self.local_explanation[s_id] = explanations_df
|
456
|
+
else:
|
457
|
+
# Fall back to the default explanation generation method
|
458
|
+
super().explain_model()
|
459
|
+
except Exception as e:
|
460
|
+
logger.warning(f"Failed to generate explanations for series {s_id} with error: {e}.")
|
461
|
+
logger.debug(f"Full Traceback: {traceback.format_exc()}")
|
@@ -242,6 +242,7 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
|
|
242
242
|
self.models.df_wide_numeric, series=s_id
|
243
243
|
),
|
244
244
|
self.datasets.list_series_ids(),
|
245
|
+
target_category_column=self.target_cat_col
|
245
246
|
)
|
246
247
|
section_1 = rc.Block(
|
247
248
|
rc.Heading("Forecast Overview", level=2),
|
@@ -28,6 +28,7 @@ from ads.opctl.operator.lowcode.common.utils import (
|
|
28
28
|
seconds_to_datetime,
|
29
29
|
write_data,
|
30
30
|
)
|
31
|
+
from ads.opctl.operator.lowcode.common.const import DataColumns
|
31
32
|
from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import TestData
|
32
33
|
from ads.opctl.operator.lowcode.forecast.utils import (
|
33
34
|
_build_metrics_df,
|
@@ -43,11 +44,12 @@ from ads.opctl.operator.lowcode.forecast.utils import (
|
|
43
44
|
|
44
45
|
from ..const import (
|
45
46
|
AUTO_SELECT,
|
47
|
+
BACKTEST_REPORT_NAME,
|
46
48
|
SUMMARY_METRICS_HORIZON_LIMIT,
|
47
49
|
SpeedAccuracyMode,
|
48
50
|
SupportedMetrics,
|
49
51
|
SupportedModels,
|
50
|
-
BACKTEST_REPORT_NAME
|
52
|
+
BACKTEST_REPORT_NAME,
|
51
53
|
)
|
52
54
|
from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
|
53
55
|
from .forecast_datasets import ForecastDatasets
|
@@ -69,7 +71,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
69
71
|
self.config: ForecastOperatorConfig = config
|
70
72
|
self.spec: ForecastOperatorSpec = config.spec
|
71
73
|
self.datasets: ForecastDatasets = datasets
|
72
|
-
|
74
|
+
self.target_cat_col = self.spec.target_category_columns
|
73
75
|
self.full_data_dict = datasets.get_data_by_series()
|
74
76
|
|
75
77
|
self.test_eval_metrics = None
|
@@ -124,6 +126,9 @@ class ForecastOperatorBaseModel(ABC):
|
|
124
126
|
|
125
127
|
if self.spec.generate_report or self.spec.generate_metrics:
|
126
128
|
self.eval_metrics = self.generate_train_metrics()
|
129
|
+
if not self.target_cat_col:
|
130
|
+
self.eval_metrics.rename({"Series 1": self.original_target_column},
|
131
|
+
axis=1, inplace=True)
|
127
132
|
|
128
133
|
if self.spec.test_data:
|
129
134
|
try:
|
@@ -134,6 +139,9 @@ class ForecastOperatorBaseModel(ABC):
|
|
134
139
|
) = self._test_evaluate_metrics(
|
135
140
|
elapsed_time=elapsed_time,
|
136
141
|
)
|
142
|
+
if not self.target_cat_col:
|
143
|
+
self.test_eval_metrics.rename({"Series 1": self.original_target_column},
|
144
|
+
axis=1, inplace=True)
|
137
145
|
except Exception:
|
138
146
|
logger.warn("Unable to generate Test Metrics.")
|
139
147
|
logger.debug(f"Full Traceback: {traceback.format_exc()}")
|
@@ -179,7 +187,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
179
187
|
first_5_rows_blocks = [
|
180
188
|
rc.DataTable(
|
181
189
|
df.head(5),
|
182
|
-
label=s_id,
|
190
|
+
label=s_id if self.target_cat_col else None,
|
183
191
|
index=True,
|
184
192
|
)
|
185
193
|
for s_id, df in self.full_data_dict.items()
|
@@ -188,7 +196,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
188
196
|
last_5_rows_blocks = [
|
189
197
|
rc.DataTable(
|
190
198
|
df.tail(5),
|
191
|
-
label=s_id,
|
199
|
+
label=s_id if self.target_cat_col else None,
|
192
200
|
index=True,
|
193
201
|
)
|
194
202
|
for s_id, df in self.full_data_dict.items()
|
@@ -197,7 +205,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
197
205
|
data_summary_blocks = [
|
198
206
|
rc.DataTable(
|
199
207
|
df.describe(),
|
200
|
-
label=s_id,
|
208
|
+
label=s_id if self.target_cat_col else None,
|
201
209
|
index=True,
|
202
210
|
)
|
203
211
|
for s_id, df in self.full_data_dict.items()
|
@@ -215,17 +223,17 @@ class ForecastOperatorBaseModel(ABC):
|
|
215
223
|
rc.Block(
|
216
224
|
first_10_title,
|
217
225
|
# series_subtext,
|
218
|
-
rc.Select(blocks=first_5_rows_blocks),
|
226
|
+
rc.Select(blocks=first_5_rows_blocks) if self.target_cat_col else first_5_rows_blocks[0],
|
219
227
|
),
|
220
228
|
rc.Block(
|
221
229
|
last_10_title,
|
222
230
|
# series_subtext,
|
223
|
-
rc.Select(blocks=last_5_rows_blocks),
|
231
|
+
rc.Select(blocks=last_5_rows_blocks) if self.target_cat_col else last_5_rows_blocks[0],
|
224
232
|
),
|
225
233
|
rc.Block(
|
226
234
|
summary_title,
|
227
235
|
# series_subtext,
|
228
|
-
rc.Select(blocks=data_summary_blocks),
|
236
|
+
rc.Select(blocks=data_summary_blocks) if self.target_cat_col else data_summary_blocks[0],
|
229
237
|
),
|
230
238
|
rc.Separator(),
|
231
239
|
)
|
@@ -259,7 +267,11 @@ class ForecastOperatorBaseModel(ABC):
|
|
259
267
|
output_dir = self.spec.output_directory.url
|
260
268
|
file_path = f"{output_dir}/{BACKTEST_REPORT_NAME}"
|
261
269
|
if self.spec.model == AUTO_SELECT:
|
262
|
-
backtest_sections.append(
|
270
|
+
backtest_sections.append(
|
271
|
+
rc.Heading(
|
272
|
+
"Auto-Select Backtesting and Performance Metrics", level=2
|
273
|
+
)
|
274
|
+
)
|
263
275
|
if not os.path.exists(file_path):
|
264
276
|
failure_msg = rc.Text(
|
265
277
|
"auto-select could not be executed. Please check the "
|
@@ -268,15 +280,23 @@ class ForecastOperatorBaseModel(ABC):
|
|
268
280
|
backtest_sections.append(failure_msg)
|
269
281
|
else:
|
270
282
|
backtest_stats = pd.read_csv(file_path)
|
271
|
-
model_metric_map = backtest_stats.drop(
|
272
|
-
|
283
|
+
model_metric_map = backtest_stats.drop(
|
284
|
+
columns=["metric", "backtest"]
|
285
|
+
)
|
286
|
+
average_dict = {
|
287
|
+
k: round(v, 4)
|
288
|
+
for k, v in model_metric_map.mean().to_dict().items()
|
289
|
+
}
|
273
290
|
best_model = min(average_dict, key=average_dict.get)
|
274
291
|
summary_text = rc.Text(
|
275
292
|
f"Overall, the average {self.spec.metric} scores for the models are {average_dict}, with"
|
276
|
-
f" {best_model} being identified as the top-performing model during backtesting."
|
293
|
+
f" {best_model} being identified as the top-performing model during backtesting."
|
294
|
+
)
|
277
295
|
backtest_table = rc.DataTable(backtest_stats, index=True)
|
278
296
|
liner_plot = get_auto_select_plot(backtest_stats)
|
279
|
-
backtest_sections.extend(
|
297
|
+
backtest_sections.extend(
|
298
|
+
[backtest_table, summary_text, liner_plot]
|
299
|
+
)
|
280
300
|
|
281
301
|
forecast_plots = []
|
282
302
|
if len(self.forecast_output.list_series_ids()) > 0:
|
@@ -288,6 +308,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
288
308
|
horizon=self.spec.horizon,
|
289
309
|
test_data=test_data,
|
290
310
|
ci_interval_width=self.spec.confidence_interval_width,
|
311
|
+
target_category_column=self.target_cat_col
|
291
312
|
)
|
292
313
|
if (
|
293
314
|
series_name is not None
|
@@ -301,7 +322,14 @@ class ForecastOperatorBaseModel(ABC):
|
|
301
322
|
forecast_plots = [forecast_text, forecast_sec]
|
302
323
|
|
303
324
|
yaml_appendix_title = rc.Heading("Reference: YAML File", level=2)
|
304
|
-
|
325
|
+
config_dict = self.config.to_dict()
|
326
|
+
# pop the data incase it isn't json serializable
|
327
|
+
config_dict["spec"]["historical_data"].pop("data")
|
328
|
+
if config_dict["spec"].get("additional_data"):
|
329
|
+
config_dict["spec"]["additional_data"].pop("data")
|
330
|
+
if config_dict["spec"].get("test_data"):
|
331
|
+
config_dict["spec"]["test_data"].pop("data")
|
332
|
+
yaml_appendix = rc.Yaml(config_dict)
|
305
333
|
report_sections = (
|
306
334
|
[summary]
|
307
335
|
+ backtest_sections
|
@@ -463,6 +491,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
463
491
|
f2.write(f1.read())
|
464
492
|
|
465
493
|
# forecast csv report
|
494
|
+
result_df = result_df if self.target_cat_col else result_df.drop(DataColumns.Series, axis=1)
|
466
495
|
write_data(
|
467
496
|
data=result_df,
|
468
497
|
filename=os.path.join(unique_output_dir, self.spec.forecast_filename),
|
@@ -637,6 +666,13 @@ class ForecastOperatorBaseModel(ABC):
|
|
637
666
|
storage_options=storage_options,
|
638
667
|
)
|
639
668
|
|
669
|
+
def _validate_automlx_explanation_mode(self):
|
670
|
+
if self.spec.model != SupportedModels.AutoMLX and self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
|
671
|
+
raise ValueError(
|
672
|
+
"AUTOMLX explanation accuracy mode is only supported for AutoMLX models. "
|
673
|
+
"Please select mode other than AUTOMLX from the available explanations_accuracy_mode options"
|
674
|
+
)
|
675
|
+
|
640
676
|
@runtime_dependency(
|
641
677
|
module="shap",
|
642
678
|
err_msg=(
|
@@ -665,6 +701,9 @@ class ForecastOperatorBaseModel(ABC):
|
|
665
701
|
)
|
666
702
|
ratio = SpeedAccuracyMode.ratio[self.spec.explanations_accuracy_mode]
|
667
703
|
|
704
|
+
# validate the automlx mode is use for automlx model
|
705
|
+
self._validate_automlx_explanation_mode()
|
706
|
+
|
668
707
|
for s_id, data_i in self.datasets.get_data_by_series(
|
669
708
|
include_horizon=False
|
670
709
|
).items():
|
@@ -699,6 +738,14 @@ class ForecastOperatorBaseModel(ABC):
|
|
699
738
|
logger.warn(
|
700
739
|
"No explanations generated. Ensure that additional data has been provided."
|
701
740
|
)
|
741
|
+
elif (
|
742
|
+
self.spec.model == SupportedModels.AutoMLX
|
743
|
+
and self.spec.explanations_accuracy_mode
|
744
|
+
== SpeedAccuracyMode.AUTOMLX
|
745
|
+
):
|
746
|
+
logger.warning(
|
747
|
+
"Global explanations not available for AutoMLX models with inherent explainability"
|
748
|
+
)
|
702
749
|
else:
|
703
750
|
self.global_explanation[s_id] = dict(
|
704
751
|
zip(
|