oracle-ads 2.12.8__py3-none-any.whl → 2.12.10__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- ads/aqua/__init__.py +4 -3
- ads/aqua/app.py +40 -18
- ads/aqua/client/__init__.py +3 -0
- ads/aqua/client/client.py +799 -0
- ads/aqua/common/enums.py +3 -0
- ads/aqua/common/utils.py +62 -2
- ads/aqua/data.py +2 -19
- ads/aqua/evaluation/entities.py +6 -0
- ads/aqua/evaluation/evaluation.py +45 -15
- ads/aqua/extension/aqua_ws_msg_handler.py +14 -7
- ads/aqua/extension/base_handler.py +12 -9
- ads/aqua/extension/deployment_handler.py +8 -4
- ads/aqua/extension/finetune_handler.py +8 -14
- ads/aqua/extension/model_handler.py +30 -6
- ads/aqua/extension/ui_handler.py +13 -1
- ads/aqua/finetuning/constants.py +5 -2
- ads/aqua/finetuning/entities.py +73 -17
- ads/aqua/finetuning/finetuning.py +110 -82
- ads/aqua/model/entities.py +5 -1
- ads/aqua/model/model.py +230 -104
- ads/aqua/modeldeployment/deployment.py +35 -11
- ads/aqua/modeldeployment/entities.py +7 -4
- ads/aqua/ui.py +24 -2
- ads/cli.py +16 -8
- ads/common/auth.py +9 -9
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/{client_v02.py → v02/client.py} +23 -10
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/guardrails/base.py +6 -5
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +46 -20
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +38 -11
- ads/model/__init__.py +11 -13
- ads/model/artifact.py +47 -8
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/generic_model.py +26 -24
- ads/model/model_metadata.py +8 -7
- ads/opctl/config/merger.py +13 -14
- ads/opctl/operator/common/operator_config.py +4 -4
- ads/opctl/operator/lowcode/common/transformations.py +50 -8
- ads/opctl/operator/lowcode/common/utils.py +22 -6
- ads/opctl/operator/lowcode/forecast/__main__.py +10 -0
- ads/opctl/operator/lowcode/forecast/const.py +3 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +19 -13
- ads/opctl/operator/lowcode/forecast/model/automlx.py +129 -36
- ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +58 -17
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +1 -1
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +10 -3
- ads/opctl/operator/lowcode/forecast/model/prophet.py +25 -18
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +3 -2
- ads/opctl/operator/lowcode/forecast/operator_config.py +31 -0
- ads/opctl/operator/lowcode/forecast/schema.yaml +76 -0
- ads/opctl/operator/lowcode/forecast/utils.py +8 -6
- ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +233 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/score.py +238 -0
- ads/telemetry/base.py +18 -11
- ads/telemetry/client.py +33 -13
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10.dist-info}/METADATA +11 -10
- {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10.dist-info}/RECORD +82 -56
- {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10.dist-info}/WHEEL +0 -0
- {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10.dist-info}/entry_points.txt +0 -0
ads/model/model_metadata.py
CHANGED
@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*--
 
 # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -11,20 +10,21 @@ import sys
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field, fields
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
-import ads.dataset.factory as factory
 import fsspec
 import git
 import oci.data_science.models
 import pandas as pd
 import yaml
+from oci.util import to_dict
+
 from ads.common import logger
 from ads.common.error import ChangesNotCommitted
 from ads.common.extended_enum import ExtendedEnumMeta
-from ads.common.serializer import DataClassSerializable
 from ads.common.object_storage_details import ObjectStorageDetails
-from
+from ads.common.serializer import DataClassSerializable
+from ads.dataset import factory
 
 try:
     from yaml import CDumper as dumper
@@ -173,6 +173,7 @@ class Framework(str, metaclass=ExtendedEnumMeta):
     WORD2VEC = "word2vec"
     ENSEMBLE = "ensemble"
     SPARK = "pyspark"
+    EMBEDDING_ONNX = "embedding_onnx"
     OTHER = "other"
 
 
@@ -1398,7 +1399,7 @@ class ModelCustomMetadata(ModelMetadata):
         if (
             not data
             or not isinstance(data, Dict)
-            or
+            or "data" not in data
             or not isinstance(data["data"], List)
         ):
             raise ValueError(
@@ -1550,7 +1551,7 @@ class ModelTaxonomyMetadata(ModelMetadata):
         if (
             not data
             or not isinstance(data, Dict)
-            or
+            or "data" not in data
             or not isinstance(data["data"], List)
         ):
             raise ValueError(
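The metadata change makes the guard check for the "data" key explicitly before indexing into it. A minimal standalone sketch of the tightened condition (the _validate_payload wrapper and its message are hypothetical; the four-part condition is taken from the diff):

    from typing import Dict, List

    def _validate_payload(data) -> None:
        # Same guard as in ModelCustomMetadata/ModelTaxonomyMetadata above:
        # reject empty payloads, non-dicts, dicts missing the "data" key,
        # and dicts whose "data" value is not a list.
        if (
            not data
            or not isinstance(data, Dict)
            or "data" not in data
            or not isinstance(data["data"], List)
        ):
            raise ValueError("Expected a dict of the form {'data': [...]}.")

    _validate_payload({"data": []})   # passes
    _validate_payload({"items": []})  # raises ValueError, not a KeyError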
ads/opctl/config/merger.py
CHANGED
@@ -1,35 +1,33 @@
 #!/usr/bin/env python
-# -*- coding: utf-8; -*-
 
-# Copyright (c) 2022,
+# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 import os
 from string import Template
 from typing import Dict
-import json
 
 import yaml
 
 from ads.common.auth import AuthType, ResourcePrincipal
 from ads.opctl import logger
 from ads.opctl.config.base import ConfigProcessor
-from ads.opctl.config.utils import
-from ads.opctl.utils import is_in_notebook_session, get_service_pack_prefix
+from ads.opctl.config.utils import _DefaultNoneDict, read_from_ini
 from ads.opctl.constants import (
-    DEFAULT_PROFILE,
-    DEFAULT_OCI_CONFIG_FILE,
-    DEFAULT_CONDA_PACK_FOLDER,
-    DEFAULT_ADS_CONFIG_FOLDER,
-    ADS_JOBS_CONFIG_FILE_NAME,
     ADS_CONFIG_FILE_NAME,
-    ADS_ML_PIPELINE_CONFIG_FILE_NAME,
     ADS_DATAFLOW_CONFIG_FILE_NAME,
+    ADS_JOBS_CONFIG_FILE_NAME,
     ADS_LOCAL_BACKEND_CONFIG_FILE_NAME,
+    ADS_ML_PIPELINE_CONFIG_FILE_NAME,
     ADS_MODEL_DEPLOYMENT_CONFIG_FILE_NAME,
-    DEFAULT_NOTEBOOK_SESSION_CONDA_DIR,
     BACKEND_NAME,
+    DEFAULT_ADS_CONFIG_FOLDER,
+    DEFAULT_CONDA_PACK_FOLDER,
+    DEFAULT_NOTEBOOK_SESSION_CONDA_DIR,
+    DEFAULT_OCI_CONFIG_FILE,
+    DEFAULT_PROFILE,
 )
+from ads.opctl.utils import get_service_pack_prefix, is_in_notebook_session
 
 
 class ConfigMerger(ConfigProcessor):
@@ -41,8 +39,9 @@ class ConfigMerger(ConfigProcessor):
     """
 
     def process(self, **kwargs) -> None:
-
-
+        for key, value in self.config.items():
+            if isinstance(value, str):  # Substitute only if the value is a string
+                self.config[key] = Template(value).safe_substitute(os.environ)
 
         if "runtime" not in self.config:
             self.config["runtime"] = {}
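The rewritten ConfigMerger.process substitutes environment variables into each string config value with string.Template.safe_substitute, which, unlike substitute, leaves unresolved placeholders untouched instead of raising KeyError. A standalone sketch of that behavior (the config keys and values here are made up):

    import os
    from string import Template

    os.environ["PROJECT_OCID"] = "ocid1.project.oc1..example"
    config = {
        "project": "$PROJECT_OCID",        # resolved from the environment
        "conda_pack": "$UNDEFINED/packs",  # unknown name is left as-is
        "max_parallel": 4,                 # non-string values are skipped
    }

    for key, value in config.items():
        if isinstance(value, str):  # substitute only string values
            config[key] = Template(value).safe_substitute(os.environ)

    print(config["project"])     # ocid1.project.oc1..example
    print(config["conda_pack"])  # $UNDEFINED/packs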
ads/opctl/operator/common/operator_config.py
CHANGED
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
-# -*- coding: utf-8; -*-
 
-# Copyright (c) 2023 Oracle and/or its affiliates.
+# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 
@@ -11,15 +10,16 @@ from dataclasses import dataclass
 from typing import Any, Dict, List
 
 from ads.common.serializer import DataClassSerializable
-
-from ads.opctl.operator.common.utils import OperatorValidator
 from ads.opctl.operator.common.errors import InvalidParameterError
+from ads.opctl.operator.common.utils import OperatorValidator
+
 
 @dataclass(repr=True)
 class InputData(DataClassSerializable):
     """Class representing operator specification input data details."""
 
     connect_args: Dict = None
+    data: Dict = None
     format: str = None
     columns: List[str] = None
     url: str = None
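The new data field gives InputData a third way to supply input besides url and connect_args: an already-loaded, in-memory dataset. load_data in ads/opctl/operator/lowcode/common/utils.py (changed later in this diff) checks data first and only falls back to the URL when it is None. A sketch with a trimmed stand-in for the dataclass (field set reduced to what the hunk shows):

    from dataclasses import dataclass
    from typing import Dict, List

    @dataclass(repr=True)
    class InputData:
        # Reduced stand-in for the InputData fields shown above.
        connect_args: Dict = None
        data: Dict = None
        format: str = None
        columns: List[str] = None
        url: str = None

    # In-memory data now takes precedence over a URL when both are set.
    spec = InputData(data={"Date": ["2024-01-01"], "Sales": [10.0]})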
ads/opctl/operator/lowcode/common/transformations.py
CHANGED
@@ -1,10 +1,11 @@
 #!/usr/bin/env python
 
-# Copyright (c) 2023,
+# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 from abc import ABC
 
+import numpy as np
 import pandas as pd
 
 from ads.opctl import logger
@@ -14,6 +15,7 @@ from ads.opctl.operator.lowcode.common.errors import (
     InvalidParameterError,
 )
 from ads.opctl.operator.lowcode.common.utils import merge_category_columns
+from ads.opctl.operator.lowcode.forecast.operator_config import ForecastOperatorSpec
 
 
 class Transformations(ABC):
@@ -33,6 +35,7 @@ class Transformations(ABC):
         self.dataset_info = dataset_info
         self.target_category_columns = dataset_info.target_category_columns
         self.target_column_name = dataset_info.target_column
+        self.raw_column_names = None
         self.dt_column_name = (
             dataset_info.datetime_column.name if dataset_info.datetime_column else None
         )
@@ -59,7 +62,8 @@ class Transformations(ABC):
 
         """
         clean_df = self._remove_trailing_whitespace(data)
-
+        if isinstance(self.dataset_info, ForecastOperatorSpec):
+            clean_df = self._clean_column_names(clean_df)
         if self.name == "historical_data":
             self._check_historical_dataset(clean_df)
             clean_df = self._set_series_id_column(clean_df)
@@ -97,8 +101,36 @@ class Transformations(ABC):
     def _remove_trailing_whitespace(self, df):
         return df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)
 
-
-
+    def _clean_column_names(self, df):
+        """
+        Remove all whitespaces from column names in a DataFrame and store the original names.
+
+        Parameters:
+            df (pd.DataFrame): The DataFrame whose column names need to be cleaned.
+
+        Returns:
+            pd.DataFrame: The DataFrame with cleaned column names.
+        """
+
+        self.raw_column_names = {
+            col: col.replace(" ", "") for col in df.columns if " " in col
+        }
+        df.columns = [self.raw_column_names.get(col, col) for col in df.columns]
+
+        if self.target_column_name:
+            self.target_column_name = self.raw_column_names.get(
+                self.target_column_name, self.target_column_name
+            )
+        self.dt_column_name = self.raw_column_names.get(
+            self.dt_column_name, self.dt_column_name
+        )
+
+        if self.target_category_columns:
+            self.target_category_columns = [
+                self.raw_column_names.get(col, col)
+                for col in self.target_category_columns
+            ]
+        return df
 
     def _set_series_id_column(self, df):
         self._target_category_columns_map = {}
@@ -209,23 +241,33 @@ class Transformations(ABC):
         -------
         A new Pandas DataFrame with treated outliears.
         """
-        df
+        return df
+        df["__z_score"] = (
             df[self.target_column_name]
             .groupby(DataColumns.Series)
             .transform(lambda x: (x - x.mean()) / x.std())
         )
-        outliers_mask = df["
+        outliers_mask = df["__z_score"].abs() > 3
+
+        if df[self.target_column_name].dtype == np.int:
+            df[self.target_column_name].astype(np.float)
+
         df.loc[outliers_mask, self.target_column_name] = (
             df[self.target_column_name]
             .groupby(DataColumns.Series)
-            .transform(lambda x:
+            .transform(lambda x: np.median(x))
         )
-
+        df_ret = df.drop("__z_score", axis=1)
+        return df_ret
 
     def _check_historical_dataset(self, df):
         expected_names = [self.target_column_name, self.dt_column_name] + (
             self.target_category_columns if self.target_category_columns else []
        )
+
+        if self.raw_column_names:
+            expected_names.extend(list(self.raw_column_names.values()))
+
         if set(df.columns) != set(expected_names):
             raise DataMismatchError(
                 f"Expected {self.name} to have columns: {expected_names}, but instead found column names: {df.columns}. Is the {self.name} path correct?"
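Two behavioral notes on the hunks above. First, _clean_column_names runs only for forecast specs; it strips spaces from column names, remembers the originals in raw_column_names, and remaps the configured target, datetime, and series columns to the cleaned names. Second, the new early "return df" in _outlier_treatment leaves the z-score/median replacement below it unreachable (that dead block still references np.int and np.float, which were removed in NumPy 1.24). A toy reproduction of just the column-cleaning idea, in plain pandas rather than the operator classes:

    import pandas as pd

    df = pd.DataFrame({"Order Date": ["2024-01-01"], "Total Sales": [10.0]})

    # Map "name with spaces" -> "namewithspaces", as _clean_column_names does,
    # keeping the originals so reports can show the user-facing names.
    raw_column_names = {c: c.replace(" ", "") for c in df.columns if " " in c}
    df.columns = [raw_column_names.get(c, c) for c in df.columns]

    # Remap configured columns the same way the spec fields are remapped.
    target_column = raw_column_names.get("Total Sales", "Total Sales")

    print(df.columns.tolist())  # ['OrderDate', 'TotalSales']
    print(target_column)        # TotalSales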
ads/opctl/operator/lowcode/common/utils.py
CHANGED
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 
-# Copyright (c) 2024 Oracle and/or its affiliates.
+# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 import logging
@@ -12,6 +12,7 @@ from typing import List, Union
 
 import fsspec
 import oracledb
+import json
 import pandas as pd
 
 from ads.common.object_storage_details import ObjectStorageDetails
@@ -40,6 +41,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
     if data_spec is None:
         raise InvalidParameterError("No details provided for this data source.")
     filename = data_spec.url
+    data = data_spec.data
     format = data_spec.format
     columns = data_spec.columns
     connect_args = data_spec.connect_args
@@ -51,9 +53,12 @@ def load_data(data_spec, storage_options=None, **kwargs):
         default_signer() if ObjectStorageDetails.is_oci_path(filename) else {}
     )
     if vault_secret_id is not None and connect_args is None:
-        connect_args =
+        connect_args = {}
 
-    if
+    if data is not None:
+        if format == "spark":
+            data = data.toPandas()
+    elif filename is not None:
         if not format:
             _, format = os.path.splitext(filename)
             format = format[1:]
@@ -98,7 +103,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
         except Exception as e:
             raise Exception(
                 f"Could not retrieve database credentials from vault {vault_secret_id}: {e}"
-            )
+            ) from e
 
         con = oracledb.connect(**connect_args)
         if table_name is not None:
@@ -121,7 +126,8 @@ def load_data(data_spec, storage_options=None, **kwargs):
     return data
 
 
-def write_data(data, filename, format, storage_options, index=False, **kwargs):
+def write_data(data, filename, format, storage_options=None, index=False, **kwargs):
+    disable_print()
     if not format:
         _, format = os.path.splitext(filename)
         format = format[1:]
@@ -130,11 +136,21 @@ def write_data(data, filename, format, storage_options, index=False, **kwargs):
         return call_pandas_fsspec(
             write_fn, filename, index=index, storage_options=storage_options, **kwargs
         )
-
+    enable_print()
     raise InvalidParameterError(
         f"The format {format} is not currently supported for writing data. Please change the format parameter for the data output: {filename} ."
     )
 
 
+def write_simple_json(data, path):
+    if ObjectStorageDetails.is_oci_path(path):
+        storage_options = default_signer()
+    else:
+        storage_options = {}
+    with fsspec.open(path, mode="w", **storage_options) as f:
+        json.dump(data, f, indent=4)
+
+
 def merge_category_columns(data, target_category_columns):
     result = data.apply(
         lambda x: "__".join([str(x[col]) for col in target_category_columns]), axis=1
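The new write_simple_json leans on fsspec to make the destination transparent: a local path opens as a regular file, while an oci:// path picks up the default signer as storage options. A minimal standalone sketch of the same pattern (writing to oci:// additionally requires the ocifs backend; a local path is shown here):

    import json

    import fsspec

    def write_simple_json(data, path, storage_options=None):
        # fsspec dispatches on the path's protocol: plain files, oci://, s3://, ...
        with fsspec.open(path, mode="w", **(storage_options or {})) as f:
            json.dump(data, f, indent=4)

    write_simple_json({"status": "ok", "rows": 3}, "/tmp/result.json")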
ads/opctl/operator/lowcode/forecast/__main__.py
CHANGED
@@ -17,6 +17,7 @@ from ads.opctl.operator.common.utils import _parse_input_args
 
 from .operator_config import ForecastOperatorConfig
 from .model.forecast_datasets import ForecastDatasets
+from .whatifserve import ModelDeploymentManager
 
 
 def operate(operator_config: ForecastOperatorConfig) -> None:
@@ -27,6 +28,15 @@ def operate(operator_config: ForecastOperatorConfig) -> None:
     ForecastOperatorModelFactory.get_model(
         operator_config, datasets
     ).generate_report()
+    # saving to model catalog
+    spec = operator_config.spec
+    if spec.what_if_analysis and datasets.additional_data:
+        mdm = ModelDeploymentManager(spec, datasets.additional_data)
+        mdm.save_to_catalog()
+        if spec.what_if_analysis.model_deployment:
+            mdm.create_deployment()
+        mdm.save_deployment_info()
+
 
 def verify(spec: Dict, **kwargs: Dict) -> bool:
     """Verifies the forecasting operator config."""
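operate() now runs an optional what-if step after report generation: when the spec enables what_if_analysis and additional data is present, the trained models are saved to the model catalog and, if the spec also asks for it, deployed. A restatement of the added control flow as its own function (ModelDeploymentManager's internals live in whatifserve/deployment_manager.py, which this release adds but this excerpt does not show):

    def run_what_if_step(spec, datasets):
        # Mirrors the block added to operate(): save to the model catalog,
        # optionally create a deployment, then persist deployment details.
        if spec.what_if_analysis and datasets.additional_data:
            mdm = ModelDeploymentManager(spec, datasets.additional_data)
            mdm.save_to_catalog()
            if spec.what_if_analysis.model_deployment:
                mdm.create_deployment()
            mdm.save_deployment_info()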
ads/opctl/operator/lowcode/forecast/const.py
CHANGED
@@ -27,10 +27,12 @@ class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
     HIGH_ACCURACY = "HIGH_ACCURACY"
     BALANCED = "BALANCED"
     FAST_APPROXIMATE = "FAST_APPROXIMATE"
+    AUTOMLX = "AUTOMLX"
     ratio = {}
     ratio[HIGH_ACCURACY] = 1  # 100 % data used for generating explanations
     ratio[BALANCED] = 0.5  # 50 % data used for generating explanations
     ratio[FAST_APPROXIMATE] = 0  # constant
+    ratio[AUTOMLX] = 0  # constant
 
 
 class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
@@ -87,3 +89,4 @@ SUMMARY_METRICS_HORIZON_LIMIT = 10
 PROPHET_INTERNAL_DATE_COL = "ds"
 RENDER_LIMIT = 5000
 AUTO_SELECT = "auto-select"
+BACKTEST_REPORT_NAME = "back_test.csv"
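The ratio table maps each SpeedAccuracyMode to the fraction of the data used when generating explanations. The new AUTOMLX entry sits at 0 like FAST_APPROXIMATE because, in that mode, explanation work is delegated to automlx.MLExplainer (see the automlx.py changes below) rather than driven by sampling. A plain-dict sketch of the lookup (the real class is a string enum built with ExtendedEnumMeta):

    ratio = {
        "HIGH_ACCURACY": 1,      # 100% of the data used for explanations
        "BALANCED": 0.5,         # 50% of the data
        "FAST_APPROXIMATE": 0,   # constant-size sample
        "AUTOMLX": 0,            # explanations delegated to AutoMLx
    }
    print(ratio["AUTOMLX"])  # 0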
ads/opctl/operator/lowcode/forecast/model/arima.py
CHANGED
@@ -164,11 +164,11 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
         blocks = [
             rc.Html(
                 m.summary().as_html(),
-                label=s_id,
+                label=s_id if self.target_cat_col else None,
             )
             for i, (s_id, m) in enumerate(self.models.items())
         ]
-        sec5 = rc.Select(blocks=blocks)
+        sec5 = rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
         all_sections = [sec5_text, sec5]
 
         if self.spec.generate_explanations:
@@ -188,6 +188,21 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
                     axis=1,
                 )
             )
+            aggregate_local_explanations = pd.DataFrame()
+            for s_id, local_ex_df in self.local_explanation.items():
+                local_ex_df_copy = local_ex_df.copy()
+                local_ex_df_copy["Series"] = s_id
+                aggregate_local_explanations = pd.concat(
+                    [aggregate_local_explanations, local_ex_df_copy], axis=0
+                )
+            self.formatted_local_explanation = aggregate_local_explanations
+
+            if not self.target_cat_col:
+                self.formatted_global_explanation = self.formatted_global_explanation.rename(
+                    {"Series 1": self.original_target_column},
+                    axis=1,
+                )
+                self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
 
             # Create a markdown section for the global explainability
             global_explanation_section = rc.Block(
@@ -198,26 +213,17 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
                 rc.DataTable(self.formatted_global_explanation, index=True),
             )
 
-            aggregate_local_explanations = pd.DataFrame()
-            for s_id, local_ex_df in self.local_explanation.items():
-                local_ex_df_copy = local_ex_df.copy()
-                local_ex_df_copy["Series"] = s_id
-                aggregate_local_explanations = pd.concat(
-                    [aggregate_local_explanations, local_ex_df_copy], axis=0
-                )
-            self.formatted_local_explanation = aggregate_local_explanations
-
             blocks = [
                 rc.DataTable(
                     local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
-                    label=s_id,
+                    label=s_id if self.target_cat_col else None,
                     index=True,
                 )
                 for s_id, local_ex_df in self.local_explanation.items()
             ]
             local_explanation_section = rc.Block(
                 rc.Heading("Local Explanation of Models", level=2),
-                rc.Select(blocks=blocks),
+                rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
             )
 
             # Append the global explanation text and section to the "all_sections" list
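The report edits above (repeated in automlx.py and prophet.py) are all one idea: when forecasting a single series there is no target_cat_col, so per-series labels are dropped and a lone block is rendered directly instead of inside a selector widget. The pattern, extracted (this assumes the report library used above is imported as rc, as in the operator code; blocks, target_cat_col, and local_explanation stand in for instance state):

    # One block per series; labels only make sense with multiple series.
    blocks = [
        rc.DataTable(
            local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
            label=s_id if target_cat_col else None,
            index=True,
        )
        for s_id, local_ex_df in local_explanation.items()
    ]

    # A single block is shown as-is; a selector is only added for 2+ blocks.
    section = rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]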
ads/opctl/operator/lowcode/forecast/model/automlx.py
CHANGED
@@ -17,6 +17,7 @@ from ads.opctl.operator.lowcode.common.utils import (
 from ads.opctl.operator.lowcode.forecast.const import (
     AUTOMLX_METRIC_MAP,
     ForecastOutputColumns,
+    SpeedAccuracyMode,
     SupportedModels,
 )
 from ads.opctl.operator.lowcode.forecast.utils import _label_encode_dataframe
@@ -81,22 +82,6 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
 
         from automlx import Pipeline, init
 
-        cpu_count = os.cpu_count()
-        try:
-            if cpu_count < 4:
-                engine = "local"
-                engine_opts = None
-            else:
-                engine = "ray"
-                engine_opts = ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
-            init(
-                engine=engine,
-                engine_opts=engine_opts,
-                loglevel=logging.CRITICAL,
-            )
-        except Exception as e:
-            logger.info(f"Error. Has Ray already been initialized? Skipping. {e}")
-
         full_data_dict = self.datasets.get_data_by_series()
 
         self.models = {}
@@ -112,6 +97,26 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
         # Clean up kwargs for pass through
         model_kwargs_cleaned, time_budget = self.set_kwargs()
 
+        cpu_count = os.cpu_count()
+        try:
+            engine_type = model_kwargs_cleaned.pop(
+                "engine", "local" if cpu_count <= 4 else "ray"
+            )
+            engine_opts = (
+                None
+                if engine_type == "local"
+                else ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
+            )
+            init(
+                engine=engine_type,
+                engine_opts=engine_opts,
+                loglevel=logging.CRITICAL,
+            )
+        except Exception as e:
+            logger.info(
+                f"Error initializing automlx. Has Ray already been initialized? Skipping. {e}"
+            )
+
         for s_id, df in full_data_dict.items():
             try:
                 logger.debug(f"Running automlx on series {s_id}")
@@ -223,6 +228,8 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
             selected_models.items(), columns=["series_id", "best_selected_model"]
         )
         selected_df = selected_models_df["best_selected_model"].apply(pd.Series)
+        if not self.target_cat_col:
+            selected_df = selected_df.drop("series_id", axis=1)
         selected_models_section = rc.Block(
             rc.Heading("Selected Models Overview", level=2),
             rc.Text(
@@ -239,27 +246,18 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
             # If the key is present, call the "explain_model" method
             self.explain_model()
 
-
-
+            global_explanation_section = None
+            if self.spec.explanations_accuracy_mode != SpeedAccuracyMode.AUTOMLX:
+                # Convert the global explanation data to a DataFrame
+                global_explanation_df = pd.DataFrame(self.global_explanation)
 
-
-
-
-
-            self.formatted_global_explanation.rename(
+                self.formatted_global_explanation = (
+                    global_explanation_df / global_explanation_df.sum(axis=0) * 100
+                )
+                self.formatted_global_explanation = self.formatted_global_explanation.rename(
                     {self.spec.datetime_column.name: ForecastOutputColumns.DATE},
                     axis=1,
                 )
-            )
-
-            # Create a markdown section for the global explainability
-            global_explanation_section = rc.Block(
-                rc.Heading("Global Explanation of Models", level=2),
-                rc.Text(
-                    "The following tables provide the feature attribution for the global explainability."
-                ),
-                rc.DataTable(self.formatted_global_explanation, index=True),
-            )
 
             aggregate_local_explanations = pd.DataFrame()
             for s_id, local_ex_df in self.local_explanation.items():
@@ -270,22 +268,41 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
                 )
             self.formatted_local_explanation = aggregate_local_explanations
 
+            if not self.target_cat_col:
+                self.formatted_global_explanation = self.formatted_global_explanation.rename(
+                    {"Series 1": self.original_target_column},
+                    axis=1,
+                )
+                self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
+
+            # Create a markdown section for the global explainability
+            global_explanation_section = rc.Block(
+                rc.Heading("Global Explanation of Models", level=2),
+                rc.Text(
+                    "The following tables provide the feature attribution for the global explainability."
+                ),
+                rc.DataTable(self.formatted_global_explanation, index=True),
+            )
+
             blocks = [
                 rc.DataTable(
                     local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
-                    label=s_id,
+                    label=s_id if self.target_cat_col else None,
                    index=True,
                 )
                 for s_id, local_ex_df in self.local_explanation.items()
             ]
             local_explanation_section = rc.Block(
                 rc.Heading("Local Explanation of Models", level=2),
-                rc.Select(blocks=blocks),
+                rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
             )
 
             # Append the global explanation text and section to the "other_sections" list
+            if global_explanation_section:
+                other_sections.append(global_explanation_section)
+
+            # Append the local explanation text and section to the "other_sections" list
             other_sections = other_sections + [
-                global_explanation_section,
                 local_explanation_section,
             ]
         except Exception as e:
@@ -366,3 +383,79 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
         return self.models.get(self.series_id).forecast(
             X=data_temp, periods=data_temp.shape[0]
         )[self.series_id]
+
+    @runtime_dependency(
+        module="automlx",
+        err_msg=(
+            "Please run `python3 -m pip install automlx` to install the required dependencies for model explanation."
+        ),
+    )
+    def explain_model(self):
+        """
+        Generates explanations for the model using the AutoMLx library.
+
+        Parameters
+        ----------
+        None
+
+        Returns
+        -------
+        None
+
+        Notes
+        -----
+        This function works by generating local explanations for each series in the dataset.
+        It uses the ``MLExplainer`` class from the AutoMLx library to generate feature attributions
+        for each series. The feature attributions are then stored in the ``self.local_explanation`` dictionary.
+
+        If the accuracy mode is set to AutoMLX, it uses the AutoMLx library to generate explanations.
+        Otherwise, it falls back to the default explanation generation method.
+        """
+        import automlx
+
+        # Loop through each series in the dataset
+        for s_id, data_i in self.datasets.get_data_by_series(
+            include_horizon=False
+        ).items():
+            try:
+                if self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
+                    # Use the MLExplainer class from AutoMLx to generate explanations
+                    explainer = automlx.MLExplainer(
+                        self.models[s_id],
+                        self.datasets.additional_data.get_data_for_series(series_id=s_id)
+                        .drop(self.spec.datetime_column.name, axis=1)
+                        .head(-self.spec.horizon)
+                        if self.spec.additional_data
+                        else None,
+                        pd.DataFrame(data_i[self.spec.target_column]),
+                        task="forecasting",
+                    )
+
+                    # Generate explanations for the forecast
+                    explanations = explainer.explain_prediction(
+                        X=self.datasets.additional_data.get_data_for_series(series_id=s_id)
+                        .drop(self.spec.datetime_column.name, axis=1)
+                        .tail(self.spec.horizon)
+                        if self.spec.additional_data
+                        else None,
+                        forecast_timepoints=list(range(self.spec.horizon + 1)),
+                    )
+
+                    # Convert the explanations to a DataFrame
+                    explanations_df = pd.concat(
+                        [exp.to_dataframe() for exp in explanations]
+                    )
+                    explanations_df["row"] = explanations_df.groupby("Feature").cumcount()
+                    explanations_df = explanations_df.pivot(
+                        index="row", columns="Feature", values="Attribution"
+                    )
+                    explanations_df = explanations_df.reset_index(drop=True)
+
+                    # Store the explanations in the local_explanation dictionary
+                    self.local_explanation[s_id] = explanations_df
+                else:
+                    # Fall back to the default explanation generation method
+                    super().explain_model()
+            except Exception as e:
+                logger.warning(f"Failed to generate explanations for series {s_id} with error: {e}.")
+                logger.debug(f"Full Traceback: {traceback.format_exc()}")
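In the new AUTOMLX explanation path, explain_prediction returns one explanation object per forecast timepoint, each flattening to long-form (Feature, Attribution) rows. The cumcount/pivot sequence then rebuilds a wide frame with one row per timepoint and one column per feature. A toy reproduction of just that pandas reshape (no AutoMLx needed; the attribution numbers are invented):

    import pandas as pd

    # Long-form attributions, as if each explanation's to_dataframe()
    # had been concatenated: two timepoints x two features.
    long_df = pd.DataFrame(
        {
            "Feature": ["temp", "promo", "temp", "promo"],
            "Attribution": [0.8, 0.2, 0.6, 0.4],
        }
    )

    # Number each feature's occurrences to recover the timepoint index,
    # then pivot to wide form.
    long_df["row"] = long_df.groupby("Feature").cumcount()
    wide = long_df.pivot(index="row", columns="Feature", values="Attribution")
    wide = wide.reset_index(drop=True)

    print(wide)
    # Feature  promo  temp
    # 0          0.2   0.8
    # 1          0.4   0.6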
ads/opctl/operator/lowcode/forecast/model/autots.py
CHANGED
@@ -242,6 +242,7 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
                 self.models.df_wide_numeric, series=s_id
             ),
             self.datasets.list_series_ids(),
+            target_category_column=self.target_cat_col
         )
         section_1 = rc.Block(
             rc.Heading("Forecast Overview", level=2),