oracle-ads 2.12.11__py3-none-any.whl → 2.13.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +7 -1
- ads/aqua/app.py +41 -27
- ads/aqua/client/client.py +48 -11
- ads/aqua/common/entities.py +28 -1
- ads/aqua/common/enums.py +32 -21
- ads/aqua/common/errors.py +3 -4
- ads/aqua/common/utils.py +10 -15
- ads/aqua/config/container_config.py +203 -0
- ads/aqua/config/evaluation/evaluation_service_config.py +5 -181
- ads/aqua/constants.py +1 -1
- ads/aqua/evaluation/constants.py +7 -7
- ads/aqua/evaluation/errors.py +3 -4
- ads/aqua/evaluation/evaluation.py +4 -4
- ads/aqua/extension/base_handler.py +4 -0
- ads/aqua/extension/model_handler.py +41 -27
- ads/aqua/extension/models/ws_models.py +5 -6
- ads/aqua/finetuning/constants.py +3 -3
- ads/aqua/finetuning/finetuning.py +2 -3
- ads/aqua/model/constants.py +7 -7
- ads/aqua/model/entities.py +2 -3
- ads/aqua/model/enums.py +4 -5
- ads/aqua/model/model.py +46 -29
- ads/aqua/modeldeployment/deployment.py +6 -14
- ads/aqua/modeldeployment/entities.py +5 -3
- ads/aqua/server/__init__.py +4 -0
- ads/aqua/server/__main__.py +24 -0
- ads/aqua/server/app.py +47 -0
- ads/aqua/server/aqua_spec.yml +1291 -0
- ads/aqua/ui.py +5 -199
- ads/common/auth.py +50 -28
- ads/common/extended_enum.py +52 -44
- ads/common/utils.py +91 -11
- ads/config.py +3 -0
- ads/llm/__init__.py +12 -8
- ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
- ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +32 -23
- ads/model/artifact_downloader.py +6 -4
- ads/model/common/utils.py +15 -3
- ads/model/datascience_model.py +422 -71
- ads/model/generic_model.py +3 -3
- ads/model/model_metadata.py +70 -24
- ads/model/model_version_set.py +5 -3
- ads/model/service/oci_datascience_model.py +487 -17
- ads/opctl/anomaly_detection.py +11 -0
- ads/opctl/backend/marketplace/helm_helper.py +13 -14
- ads/opctl/cli.py +4 -5
- ads/opctl/cmds.py +28 -32
- ads/opctl/config/merger.py +8 -11
- ads/opctl/config/resolver.py +25 -30
- ads/opctl/forecast.py +11 -0
- ads/opctl/operator/cli.py +9 -9
- ads/opctl/operator/common/backend_factory.py +56 -60
- ads/opctl/operator/common/const.py +5 -5
- ads/opctl/operator/common/utils.py +16 -0
- ads/opctl/operator/lowcode/anomaly/const.py +8 -9
- ads/opctl/operator/lowcode/common/data.py +5 -2
- ads/opctl/operator/lowcode/common/transformations.py +2 -12
- ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +43 -48
- ads/opctl/operator/lowcode/forecast/__main__.py +5 -5
- ads/opctl/operator/lowcode/forecast/const.py +6 -6
- ads/opctl/operator/lowcode/forecast/model/arima.py +6 -3
- ads/opctl/operator/lowcode/forecast/model/automlx.py +61 -31
- ads/opctl/operator/lowcode/forecast/model/base_model.py +66 -40
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +79 -13
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +5 -2
- ads/opctl/operator/lowcode/forecast/model/prophet.py +28 -15
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +13 -15
- ads/opctl/operator/lowcode/forecast/schema.yaml +1 -1
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +7 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/score.py +19 -11
- ads/opctl/operator/lowcode/pii/constant.py +6 -7
- ads/opctl/operator/lowcode/recommender/constant.py +12 -7
- ads/opctl/operator/runtime/marketplace_runtime.py +4 -10
- ads/opctl/operator/runtime/runtime.py +4 -6
- ads/pipeline/ads_pipeline_run.py +13 -25
- ads/pipeline/visualizer/graph_renderer.py +3 -4
- {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info}/METADATA +18 -15
- {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info}/RECORD +82 -74
- {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info}/WHEEL +1 -1
- ads/aqua/config/evaluation/evaluation_service_model_config.py +0 -8
- {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info}/entry_points.txt +0 -0
- {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info/licenses}/LICENSE.txt +0 -0
@@ -9,6 +9,7 @@ import os
|
|
9
9
|
import sys
|
10
10
|
import time
|
11
11
|
import traceback
|
12
|
+
import uuid
|
12
13
|
from string import Template
|
13
14
|
from typing import Any, Dict, List, Tuple
|
14
15
|
|
@@ -17,6 +18,7 @@ import yaml
|
|
17
18
|
from cerberus import Validator
|
18
19
|
|
19
20
|
from ads.opctl import logger, utils
|
21
|
+
from ads.common.oci_logging import OCILog
|
20
22
|
|
21
23
|
CONTAINER_NETWORK = "CONTAINER_NETWORK"
|
22
24
|
|
@@ -190,3 +192,17 @@ def print_traceback():
|
|
190
192
|
if logger.level == logging.DEBUG:
|
191
193
|
ex_type, ex, tb = sys.exc_info()
|
192
194
|
traceback.print_tb(tb)
|
195
|
+
|
196
|
+
|
197
|
+
def create_log_in_log_group(compartment_id, log_group_id, auth, log_name=None):
|
198
|
+
"""
|
199
|
+
Creates a log within a given log group
|
200
|
+
"""
|
201
|
+
if not log_name:
|
202
|
+
log_name = f"log-{int(time.time())}-{uuid.uuid4()}"
|
203
|
+
log = OCILog(display_name=log_name,
|
204
|
+
log_group_id=log_group_id,
|
205
|
+
compartment_id=compartment_id,
|
206
|
+
**auth).create()
|
207
|
+
logger.info(f"Created log with log OCID {log.id}")
|
208
|
+
return log.id
|
@@ -1,14 +1,13 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c) 2023 Oracle and/or its affiliates.
|
3
|
+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
|
-
from ads.common.extended_enum import
|
6
|
+
from ads.common.extended_enum import ExtendedEnum
|
8
7
|
from ads.opctl.operator.lowcode.common.const import DataColumns
|
9
8
|
|
10
9
|
|
11
|
-
class SupportedModels(
|
10
|
+
class SupportedModels(ExtendedEnum):
|
12
11
|
"""Supported anomaly models."""
|
13
12
|
|
14
13
|
AutoTS = "autots"
|
@@ -38,7 +37,7 @@ class SupportedModels(str, metaclass=ExtendedEnumMeta):
|
|
38
37
|
BOCPD = "bocpd"
|
39
38
|
|
40
39
|
|
41
|
-
class NonTimeADSupportedModels(
|
40
|
+
class NonTimeADSupportedModels(ExtendedEnum):
|
42
41
|
"""Supported non time-based anomaly detection models."""
|
43
42
|
|
44
43
|
OneClassSVM = "oneclasssvm"
|
@@ -48,7 +47,7 @@ class NonTimeADSupportedModels(str, metaclass=ExtendedEnumMeta):
|
|
48
47
|
# DBScan = "dbscan"
|
49
48
|
|
50
49
|
|
51
|
-
class TODSSubModels(
|
50
|
+
class TODSSubModels(ExtendedEnum):
|
52
51
|
"""Supported TODS sub models."""
|
53
52
|
|
54
53
|
OCSVM = "ocsvm"
|
@@ -78,7 +77,7 @@ TODS_MODEL_MAP = {
|
|
78
77
|
}
|
79
78
|
|
80
79
|
|
81
|
-
class MerlionADModels(
|
80
|
+
class MerlionADModels(ExtendedEnum):
|
82
81
|
"""Supported Merlion AD sub models."""
|
83
82
|
|
84
83
|
# point anomaly
|
@@ -126,7 +125,7 @@ MERLIONAD_MODEL_MAP = {
|
|
126
125
|
}
|
127
126
|
|
128
127
|
|
129
|
-
class SupportedMetrics(
|
128
|
+
class SupportedMetrics(ExtendedEnum):
|
130
129
|
UNSUPERVISED_UNIFY95 = "unsupervised_unify95"
|
131
130
|
UNSUPERVISED_UNIFY95_LOG_LOSS = "unsupervised_unify95_log_loss"
|
132
131
|
UNSUPERVISED_N1_EXPERTS = "unsupervised_n-1_experts"
|
@@ -158,7 +157,7 @@ class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
|
|
158
157
|
ELAPSED_TIME = "Elapsed Time"
|
159
158
|
|
160
159
|
|
161
|
-
class OutputColumns(
|
160
|
+
class OutputColumns(ExtendedEnum):
|
162
161
|
ANOMALY_COL = "anomaly"
|
163
162
|
SCORE_COL = "score"
|
164
163
|
Series = DataColumns.Series
|
@@ -19,13 +19,16 @@ from .transformations import Transformations
|
|
19
19
|
|
20
20
|
|
21
21
|
class AbstractData(ABC):
|
22
|
-
def __init__(self, spec
|
22
|
+
def __init__(self, spec, name="input_data", data=None):
|
23
23
|
self.Transformations = Transformations
|
24
24
|
self.data = None
|
25
25
|
self._data_dict = dict()
|
26
26
|
self.name = name
|
27
27
|
self.spec = spec
|
28
|
-
|
28
|
+
if data is not None:
|
29
|
+
self.data = data
|
30
|
+
else:
|
31
|
+
self.load_transform_ingest_data(spec)
|
29
32
|
|
30
33
|
def get_raw_data_by_cat(self, category):
|
31
34
|
mapping = self._data_transformer.get_target_category_columns_map()
|
@@ -31,7 +31,6 @@ class Transformations(ABC):
|
|
31
31
|
dataset_info : ForecastOperatorConfig
|
32
32
|
"""
|
33
33
|
self.name = name
|
34
|
-
self.has_artificial_series = False
|
35
34
|
self.dataset_info = dataset_info
|
36
35
|
self.target_category_columns = dataset_info.target_category_columns
|
37
36
|
self.target_column_name = dataset_info.target_column
|
@@ -136,7 +135,6 @@ class Transformations(ABC):
|
|
136
135
|
self._target_category_columns_map = {}
|
137
136
|
if not self.target_category_columns:
|
138
137
|
df[DataColumns.Series] = "Series 1"
|
139
|
-
self.has_artificial_series = True
|
140
138
|
else:
|
141
139
|
df[DataColumns.Series] = merge_category_columns(
|
142
140
|
df, self.target_category_columns
|
@@ -209,7 +207,7 @@ class Transformations(ABC):
|
|
209
207
|
|
210
208
|
def _missing_value_imputation_add(self, df):
|
211
209
|
"""
|
212
|
-
Function fills missing values
|
210
|
+
Function fills missing values with zero
|
213
211
|
|
214
212
|
Parameters
|
215
213
|
----------
|
@@ -219,15 +217,7 @@ class Transformations(ABC):
|
|
219
217
|
-------
|
220
218
|
A new Pandas DataFrame without missing values.
|
221
219
|
"""
|
222
|
-
|
223
|
-
for col in df.columns:
|
224
|
-
# find next int not in list
|
225
|
-
i = 0
|
226
|
-
vals = df[col].unique()
|
227
|
-
while i in vals:
|
228
|
-
i = i + 1
|
229
|
-
df[col] = df[col].fillna(0)
|
230
|
-
return df
|
220
|
+
return df.fillna(0)
|
231
221
|
|
232
222
|
def _outlier_treatment(self, df):
|
233
223
|
"""
|
@@ -1,71 +1,66 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
3
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
6
|
import ast
|
8
7
|
import base64
|
9
|
-
from typing import
|
8
|
+
from typing import TYPE_CHECKING, Dict, List, Optional
|
10
9
|
|
11
10
|
import oci
|
12
11
|
import requests
|
13
|
-
from typing import TYPE_CHECKING
|
14
12
|
|
15
13
|
try:
|
16
14
|
from kubernetes.client import (
|
17
|
-
V1ServiceStatus,
|
18
|
-
V1Service,
|
19
|
-
V1LoadBalancerStatus,
|
20
15
|
V1LoadBalancerIngress,
|
16
|
+
V1LoadBalancerStatus,
|
17
|
+
V1Service,
|
18
|
+
V1ServiceStatus,
|
21
19
|
)
|
22
20
|
except ImportError:
|
23
21
|
if TYPE_CHECKING:
|
24
22
|
from kubernetes.client import (
|
25
|
-
V1ServiceStatus,
|
26
|
-
V1Service,
|
27
|
-
V1LoadBalancerStatus,
|
28
23
|
V1LoadBalancerIngress,
|
24
|
+
V1LoadBalancerStatus,
|
25
|
+
V1Service,
|
26
|
+
V1ServiceStatus,
|
29
27
|
)
|
30
28
|
|
31
|
-
|
32
|
-
|
33
|
-
from ads.opctl.operator.lowcode.feature_store_marketplace.models.apigw_config import (
|
34
|
-
APIGatewayConfig,
|
35
|
-
)
|
29
|
+
import click
|
30
|
+
from oci.resource_manager.models import AssociatedResourceSummary, StackSummary
|
36
31
|
|
37
|
-
from ads.common.oci_client import OCIClientFactory
|
38
|
-
from ads.opctl.operator.lowcode.feature_store_marketplace.const import (
|
39
|
-
LISTING_ID,
|
40
|
-
APIGW_STACK_NAME,
|
41
|
-
STACK_URL,
|
42
|
-
NLB_RULES_ADDRESS,
|
43
|
-
NODES_RULES_ADDRESS,
|
44
|
-
)
|
45
32
|
from ads import logger
|
46
|
-
import
|
33
|
+
from ads.common import auth as authutil
|
34
|
+
from ads.common.oci_client import OCIClientFactory
|
47
35
|
from ads.opctl import logger
|
48
|
-
|
49
36
|
from ads.opctl.backend.marketplace.marketplace_utils import (
|
50
37
|
Color,
|
51
38
|
print_heading,
|
52
39
|
print_ticker,
|
53
40
|
)
|
54
|
-
from ads.opctl.operator.lowcode.feature_store_marketplace.
|
55
|
-
|
41
|
+
from ads.opctl.operator.lowcode.feature_store_marketplace.const import (
|
42
|
+
APIGW_STACK_NAME,
|
43
|
+
LISTING_ID,
|
44
|
+
NLB_RULES_ADDRESS,
|
45
|
+
NODES_RULES_ADDRESS,
|
46
|
+
STACK_URL,
|
47
|
+
)
|
48
|
+
from ads.opctl.operator.lowcode.feature_store_marketplace.models.apigw_config import (
|
49
|
+
APIGatewayConfig,
|
56
50
|
)
|
57
|
-
|
58
51
|
from ads.opctl.operator.lowcode.feature_store_marketplace.models.db_config import (
|
59
52
|
DBConfig,
|
60
53
|
)
|
61
|
-
from ads.
|
54
|
+
from ads.opctl.operator.lowcode.feature_store_marketplace.models.mysql_config import (
|
55
|
+
MySqlConfig,
|
56
|
+
)
|
62
57
|
|
63
58
|
|
64
59
|
def get_db_details() -> DBConfig:
|
65
60
|
jdbc_url = "jdbc:mysql://{}/{}?createDatabaseIfNotExist=true"
|
66
61
|
mysql_db_config = MySqlConfig()
|
67
62
|
print_heading(
|
68
|
-
|
63
|
+
"MySQL database configuration",
|
69
64
|
colors=[Color.BOLD, Color.BLUE],
|
70
65
|
prefix_newline_count=2,
|
71
66
|
)
|
@@ -76,12 +71,12 @@ def get_db_details() -> DBConfig:
|
|
76
71
|
"Is password provided as plain-text or via a Vault secret?\n"
|
77
72
|
"(https://docs.oracle.com/en-us/iaas/Content/KeyManagement/Concepts/keyoverview.htm)",
|
78
73
|
type=click.Choice(MySqlConfig.MySQLAuthType.values()),
|
79
|
-
default=MySqlConfig.MySQLAuthType.BASIC
|
74
|
+
default=MySqlConfig.MySQLAuthType.BASIC,
|
80
75
|
)
|
81
76
|
)
|
82
77
|
if mysql_db_config.auth_type == MySqlConfig.MySQLAuthType.BASIC:
|
83
78
|
basic_auth_config = MySqlConfig.BasicConfig()
|
84
|
-
basic_auth_config.password = click.prompt(
|
79
|
+
basic_auth_config.password = click.prompt("Password", hide_input=True)
|
85
80
|
mysql_db_config.basic_config = basic_auth_config
|
86
81
|
|
87
82
|
elif mysql_db_config.auth_type == MySqlConfig.MySQLAuthType.VAULT:
|
@@ -176,12 +171,12 @@ def detect_or_create_stack(apigw_config: APIGatewayConfig):
|
|
176
171
|
).data
|
177
172
|
|
178
173
|
if len(stacks) >= 1:
|
179
|
-
print(
|
174
|
+
print("Auto-detected feature store stack(s) in tenancy:")
|
180
175
|
for stack in stacks:
|
181
176
|
_print_stack_detail(stack)
|
182
177
|
choices = {"1": "new", "2": "existing"}
|
183
178
|
stack_provision_method = click.prompt(
|
184
|
-
|
179
|
+
"Select stack provisioning method:\n1.Create new stack\n2.Existing stack\n",
|
185
180
|
type=click.Choice(list(choices.keys())),
|
186
181
|
show_choices=False,
|
187
182
|
)
|
@@ -240,20 +235,20 @@ def get_api_gw_details(compartment_id: str) -> APIGatewayConfig:
|
|
240
235
|
|
241
236
|
|
242
237
|
def get_nlb_id_from_service(service: "V1Service", apigw_config: APIGatewayConfig):
|
243
|
-
status:
|
244
|
-
lb_status:
|
245
|
-
lb_ingress:
|
238
|
+
status: V1ServiceStatus = service.status
|
239
|
+
lb_status: V1LoadBalancerStatus = status.load_balancer
|
240
|
+
lb_ingress: V1LoadBalancerIngress = lb_status.ingress[0]
|
246
241
|
resource_client = OCIClientFactory(**authutil.default_signer()).create_client(
|
247
242
|
oci.resource_search.ResourceSearchClient
|
248
243
|
)
|
249
244
|
search_details = oci.resource_search.models.FreeTextSearchDetails()
|
250
245
|
search_details.matching_context_type = "NONE"
|
251
246
|
search_details.text = lb_ingress.ip
|
252
|
-
resources: List[
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
)
|
247
|
+
resources: List[oci.resource_search.models.ResourceSummary] = (
|
248
|
+
resource_client.search_resources(
|
249
|
+
search_details, tenant_id=apigw_config.root_compartment_id
|
250
|
+
).data.items
|
251
|
+
)
|
257
252
|
private_ips = list(filter(lambda obj: obj.resource_type == "PrivateIp", resources))
|
258
253
|
if len(private_ips) != 1:
|
259
254
|
return click.prompt(
|
@@ -264,12 +259,12 @@ def get_nlb_id_from_service(service: "V1Service", apigw_config: APIGatewayConfig
|
|
264
259
|
nlb_client = OCIClientFactory(**authutil.default_signer()).create_client(
|
265
260
|
oci.network_load_balancer.NetworkLoadBalancerClient
|
266
261
|
)
|
267
|
-
nlbs: List[
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
)
|
262
|
+
nlbs: List[oci.network_load_balancer.models.NetworkLoadBalancerSummary] = (
|
263
|
+
nlb_client.list_network_load_balancers(
|
264
|
+
compartment_id=nlb_private_ip.compartment_id,
|
265
|
+
display_name=nlb_private_ip.display_name,
|
266
|
+
).data.items
|
267
|
+
)
|
273
268
|
if len(nlbs) != 1:
|
274
269
|
return click.prompt(
|
275
270
|
f"Please enter OCID of load balancer associated with ip: {lb_ingress.ip}"
|
@@ -1,7 +1,6 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c) 2023,
|
3
|
+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
6
|
import json
|
@@ -15,17 +14,17 @@ from ads.opctl import logger
|
|
15
14
|
from ads.opctl.operator.common.const import ENV_OPERATOR_ARGS
|
16
15
|
from ads.opctl.operator.common.utils import _parse_input_args
|
17
16
|
|
17
|
+
from .model.forecast_datasets import ForecastDatasets, ForecastResults
|
18
18
|
from .operator_config import ForecastOperatorConfig
|
19
|
-
from .model.forecast_datasets import ForecastDatasets
|
20
19
|
from .whatifserve import ModelDeploymentManager
|
21
20
|
|
22
21
|
|
23
|
-
def operate(operator_config: ForecastOperatorConfig) ->
|
22
|
+
def operate(operator_config: ForecastOperatorConfig) -> ForecastResults:
|
24
23
|
"""Runs the forecasting operator."""
|
25
24
|
from .model.factory import ForecastOperatorModelFactory
|
26
25
|
|
27
26
|
datasets = ForecastDatasets(operator_config)
|
28
|
-
ForecastOperatorModelFactory.get_model(
|
27
|
+
results = ForecastOperatorModelFactory.get_model(
|
29
28
|
operator_config, datasets
|
30
29
|
).generate_report()
|
31
30
|
# saving to model catalog
|
@@ -36,6 +35,7 @@ def operate(operator_config: ForecastOperatorConfig) -> None:
|
|
36
35
|
if spec.what_if_analysis.model_deployment:
|
37
36
|
mdm.create_deployment()
|
38
37
|
mdm.save_deployment_info()
|
38
|
+
return results
|
39
39
|
|
40
40
|
|
41
41
|
def verify(spec: Dict, **kwargs: Dict) -> bool:
|
@@ -1,13 +1,13 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
2
|
|
3
|
-
# Copyright (c) 2023,
|
3
|
+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
|
4
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
5
|
|
6
|
-
from ads.common.extended_enum import
|
6
|
+
from ads.common.extended_enum import ExtendedEnum
|
7
7
|
from ads.opctl.operator.lowcode.common.const import DataColumns
|
8
8
|
|
9
9
|
|
10
|
-
class SupportedModels(
|
10
|
+
class SupportedModels(ExtendedEnum):
|
11
11
|
"""Supported forecast models."""
|
12
12
|
|
13
13
|
Prophet = "prophet"
|
@@ -19,7 +19,7 @@ class SupportedModels(str, metaclass=ExtendedEnumMeta):
|
|
19
19
|
# Auto = "auto"
|
20
20
|
|
21
21
|
|
22
|
-
class SpeedAccuracyMode(
|
22
|
+
class SpeedAccuracyMode(ExtendedEnum):
|
23
23
|
"""
|
24
24
|
Enum representing different modes based on time taken and accuracy for explainability.
|
25
25
|
"""
|
@@ -35,7 +35,7 @@ class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
|
|
35
35
|
ratio[AUTOMLX] = 0 # constant
|
36
36
|
|
37
37
|
|
38
|
-
class SupportedMetrics(
|
38
|
+
class SupportedMetrics(ExtendedEnum):
|
39
39
|
"""Supported forecast metrics."""
|
40
40
|
|
41
41
|
MAPE = "MAPE"
|
@@ -62,7 +62,7 @@ class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
|
|
62
62
|
ELAPSED_TIME = "Elapsed Time"
|
63
63
|
|
64
64
|
|
65
|
-
class ForecastOutputColumns(
|
65
|
+
class ForecastOutputColumns(ExtendedEnum):
|
66
66
|
"""The column names for the forecast.csv output file"""
|
67
67
|
|
68
68
|
DATE = "Date"
|
@@ -116,7 +116,10 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
|
|
116
116
|
lower_bound=self.get_horizon(forecast["yhat_lower"]).values,
|
117
117
|
)
|
118
118
|
|
119
|
-
self.models[s_id] =
|
119
|
+
self.models[s_id] = {}
|
120
|
+
self.models[s_id]["model"] = model
|
121
|
+
self.models[s_id]["le"] = self.le[s_id]
|
122
|
+
self.models[s_id]["predict_component_cols"] = X_pred.columns
|
120
123
|
|
121
124
|
params = vars(model).copy()
|
122
125
|
for param in ["arima_res_", "endog_index_"]:
|
@@ -163,7 +166,7 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
|
|
163
166
|
sec5_text = rc.Heading("ARIMA Model Parameters", level=2)
|
164
167
|
blocks = [
|
165
168
|
rc.Html(
|
166
|
-
m.summary().as_html(),
|
169
|
+
m['model'].summary().as_html(),
|
167
170
|
label=s_id if self.target_cat_col else None,
|
168
171
|
)
|
169
172
|
for i, (s_id, m) in enumerate(self.models.items())
|
@@ -251,7 +254,7 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
|
|
251
254
|
def get_explain_predict_fn(self, series_id):
|
252
255
|
def _custom_predict(
|
253
256
|
data,
|
254
|
-
model=self.models[series_id],
|
257
|
+
model=self.models[series_id]["model"],
|
255
258
|
dt_column_name=self.datasets._datetime_column_name,
|
256
259
|
target_col=self.original_target_column,
|
257
260
|
):
|
@@ -1,5 +1,5 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# Copyright (c) 2023,
|
2
|
+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
import logging
|
5
5
|
import os
|
@@ -56,8 +56,8 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
56
56
|
)
|
57
57
|
return model_kwargs_cleaned, time_budget
|
58
58
|
|
59
|
-
def preprocess(self, data): # TODO: re-use self.le for explanations
|
60
|
-
|
59
|
+
def preprocess(self, data, series_id): # TODO: re-use self.le for explanations
|
60
|
+
self.le[series_id], df_encoded = _label_encode_dataframe(
|
61
61
|
data,
|
62
62
|
no_encode={self.spec.datetime_column.name, self.original_target_column},
|
63
63
|
)
|
@@ -66,8 +66,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
66
66
|
@runtime_dependency(
|
67
67
|
module="automlx",
|
68
68
|
err_msg=(
|
69
|
-
"Please run `pip3 install oracle-automlx>=
|
70
|
-
"`pip3 install oracle-automlx[forecasting]>=23.4.1` "
|
69
|
+
"Please run `pip3 install oracle-automlx[forecasting]>=25.1.1` "
|
71
70
|
"to install the required dependencies for automlx."
|
72
71
|
),
|
73
72
|
)
|
@@ -105,7 +104,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
105
104
|
engine_opts = (
|
106
105
|
None
|
107
106
|
if engine_type == "local"
|
108
|
-
else
|
107
|
+
else {"ray_setup": {"_temp_dir": "/tmp/ray-temp"}}
|
109
108
|
)
|
110
109
|
init(
|
111
110
|
engine=engine_type,
|
@@ -125,7 +124,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
125
124
|
self.forecast_output.init_series_output(
|
126
125
|
series_id=s_id, data_at_series=df
|
127
126
|
)
|
128
|
-
data = self.preprocess(df)
|
127
|
+
data = self.preprocess(df, s_id)
|
129
128
|
data_i = self.drop_horizon(data)
|
130
129
|
X_pred = self.get_horizon(data).drop(target, axis=1)
|
131
130
|
|
@@ -157,7 +156,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
157
156
|
target
|
158
157
|
].values
|
159
158
|
|
160
|
-
self.models[s_id] =
|
159
|
+
self.models[s_id] = {}
|
160
|
+
self.models[s_id]["model"] = model
|
161
|
+
self.models[s_id]["le"] = self.le[s_id]
|
161
162
|
|
162
163
|
# In case of Naive model, model.forecast function call does not return confidence intervals.
|
163
164
|
if f"{target}_ci_upper" not in summary_frame:
|
@@ -218,7 +219,8 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
218
219
|
other_sections = []
|
219
220
|
|
220
221
|
if len(self.models) > 0:
|
221
|
-
for s_id,
|
222
|
+
for s_id, artifacts in models.items():
|
223
|
+
m = artifacts["model"]
|
222
224
|
selected_models[s_id] = {
|
223
225
|
"series_id": s_id,
|
224
226
|
"selected_model": m.selected_model_,
|
@@ -247,17 +249,17 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
247
249
|
self.explain_model()
|
248
250
|
|
249
251
|
global_explanation_section = None
|
250
|
-
|
251
|
-
|
252
|
-
global_explanation_df = pd.DataFrame(self.global_explanation)
|
252
|
+
# Convert the global explanation data to a DataFrame
|
253
|
+
global_explanation_df = pd.DataFrame(self.global_explanation)
|
253
254
|
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
255
|
+
self.formatted_global_explanation = (
|
256
|
+
global_explanation_df / global_explanation_df.sum(axis=0) * 100
|
257
|
+
)
|
258
|
+
|
259
|
+
self.formatted_global_explanation.rename(
|
260
|
+
columns={self.spec.datetime_column.name: ForecastOutputColumns.DATE},
|
261
|
+
inplace=True,
|
262
|
+
)
|
261
263
|
|
262
264
|
aggregate_local_explanations = pd.DataFrame()
|
263
265
|
for s_id, local_ex_df in self.local_explanation.items():
|
@@ -269,11 +271,15 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
269
271
|
self.formatted_local_explanation = aggregate_local_explanations
|
270
272
|
|
271
273
|
if not self.target_cat_col:
|
272
|
-
self.formatted_global_explanation =
|
273
|
-
|
274
|
-
|
274
|
+
self.formatted_global_explanation = (
|
275
|
+
self.formatted_global_explanation.rename(
|
276
|
+
{"Series 1": self.original_target_column},
|
277
|
+
axis=1,
|
278
|
+
)
|
279
|
+
)
|
280
|
+
self.formatted_local_explanation.drop(
|
281
|
+
"Series", axis=1, inplace=True
|
275
282
|
)
|
276
|
-
self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
|
277
283
|
|
278
284
|
# Create a markdown section for the global explainability
|
279
285
|
global_explanation_section = rc.Block(
|
@@ -320,7 +326,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
320
326
|
)
|
321
327
|
|
322
328
|
def get_explain_predict_fn(self, series_id):
|
323
|
-
selected_model = self.models[series_id]
|
329
|
+
selected_model = self.models[series_id]["model"]
|
324
330
|
|
325
331
|
# If training date, use method below. If future date, use forecast!
|
326
332
|
def _custom_predict_fn(
|
@@ -338,12 +344,12 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
338
344
|
data[dt_column_name] = seconds_to_datetime(
|
339
345
|
data[dt_column_name], dt_format=self.spec.datetime_column.format
|
340
346
|
)
|
341
|
-
data = self.preprocess(data)
|
347
|
+
data = self.preprocess(data, series_id)
|
342
348
|
horizon_data = horizon_data.drop(target_col, axis=1)
|
343
349
|
horizon_data[dt_column_name] = seconds_to_datetime(
|
344
350
|
horizon_data[dt_column_name], dt_format=self.spec.datetime_column.format
|
345
351
|
)
|
346
|
-
horizon_data = self.preprocess(horizon_data)
|
352
|
+
horizon_data = self.preprocess(horizon_data, series_id)
|
347
353
|
|
348
354
|
rows = []
|
349
355
|
for i in range(data.shape[0]):
|
@@ -421,8 +427,10 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
421
427
|
if self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
|
422
428
|
# Use the MLExplainer class from AutoMLx to generate explanations
|
423
429
|
explainer = automlx.MLExplainer(
|
424
|
-
self.models[s_id],
|
425
|
-
self.datasets.additional_data.get_data_for_series(
|
430
|
+
self.models[s_id]["model"],
|
431
|
+
self.datasets.additional_data.get_data_for_series(
|
432
|
+
series_id=s_id
|
433
|
+
)
|
426
434
|
.drop(self.spec.datetime_column.name, axis=1)
|
427
435
|
.head(-self.spec.horizon)
|
428
436
|
if self.spec.additional_data
|
@@ -433,7 +441,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
433
441
|
|
434
442
|
# Generate explanations for the forecast
|
435
443
|
explanations = explainer.explain_prediction(
|
436
|
-
X=self.datasets.additional_data.get_data_for_series(
|
444
|
+
X=self.datasets.additional_data.get_data_for_series(
|
445
|
+
series_id=s_id
|
446
|
+
)
|
437
447
|
.drop(self.spec.datetime_column.name, axis=1)
|
438
448
|
.tail(self.spec.horizon)
|
439
449
|
if self.spec.additional_data
|
@@ -445,7 +455,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
445
455
|
explanations_df = pd.concat(
|
446
456
|
[exp.to_dataframe() for exp in explanations]
|
447
457
|
)
|
448
|
-
explanations_df["row"] = explanations_df.groupby(
|
458
|
+
explanations_df["row"] = explanations_df.groupby(
|
459
|
+
"Feature"
|
460
|
+
).cumcount()
|
449
461
|
explanations_df = explanations_df.pivot(
|
450
462
|
index="row", columns="Feature", values="Attribution"
|
451
463
|
)
|
@@ -453,9 +465,27 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
453
465
|
|
454
466
|
# Store the explanations in the local_explanation dictionary
|
455
467
|
self.local_explanation[s_id] = explanations_df
|
468
|
+
|
469
|
+
self.global_explanation[s_id] = dict(
|
470
|
+
zip(
|
471
|
+
self.local_explanation[s_id].columns,
|
472
|
+
np.nanmean(np.abs(self.local_explanation[s_id]), axis=0),
|
473
|
+
)
|
474
|
+
)
|
456
475
|
else:
|
457
476
|
# Fall back to the default explanation generation method
|
458
477
|
super().explain_model()
|
459
478
|
except Exception as e:
|
460
|
-
|
479
|
+
if s_id in self.errors_dict:
|
480
|
+
self.errors_dict[s_id]["explainer_error"] = str(e)
|
481
|
+
self.errors_dict[s_id]["explainer_error_trace"] = traceback.format_exc()
|
482
|
+
else:
|
483
|
+
self.errors_dict[s_id] = {
|
484
|
+
"model_name": self.spec.model,
|
485
|
+
"explainer_error": str(e),
|
486
|
+
"explainer_error_trace": traceback.format_exc(),
|
487
|
+
}
|
488
|
+
logger.warning(
|
489
|
+
f"Failed to generate explanations for series {s_id} with error: {e}."
|
490
|
+
)
|
461
491
|
logger.debug(f"Full Traceback: {traceback.format_exc()}")
|