oracle-ads 2.12.11__py3-none-any.whl → 2.13.1rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/app.py +23 -10
- ads/aqua/common/enums.py +19 -14
- ads/aqua/common/errors.py +3 -4
- ads/aqua/common/utils.py +2 -2
- ads/aqua/constants.py +1 -0
- ads/aqua/evaluation/constants.py +7 -7
- ads/aqua/evaluation/errors.py +3 -4
- ads/aqua/extension/model_handler.py +23 -0
- ads/aqua/extension/models/ws_models.py +5 -6
- ads/aqua/finetuning/constants.py +3 -3
- ads/aqua/model/constants.py +7 -7
- ads/aqua/model/enums.py +4 -5
- ads/aqua/model/model.py +22 -0
- ads/aqua/modeldeployment/entities.py +3 -1
- ads/common/auth.py +33 -20
- ads/common/extended_enum.py +52 -44
- ads/llm/__init__.py +11 -8
- ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
- ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
- ads/model/artifact_downloader.py +3 -4
- ads/model/datascience_model.py +84 -64
- ads/model/generic_model.py +3 -3
- ads/model/model_metadata.py +17 -11
- ads/model/service/oci_datascience_model.py +12 -14
- ads/opctl/anomaly_detection.py +11 -0
- ads/opctl/backend/marketplace/helm_helper.py +13 -14
- ads/opctl/cli.py +4 -5
- ads/opctl/cmds.py +28 -32
- ads/opctl/config/merger.py +8 -11
- ads/opctl/config/resolver.py +25 -30
- ads/opctl/forecast.py +11 -0
- ads/opctl/operator/cli.py +9 -9
- ads/opctl/operator/common/backend_factory.py +56 -60
- ads/opctl/operator/common/const.py +5 -5
- ads/opctl/operator/lowcode/anomaly/const.py +8 -9
- ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +43 -48
- ads/opctl/operator/lowcode/forecast/__main__.py +5 -5
- ads/opctl/operator/lowcode/forecast/const.py +6 -6
- ads/opctl/operator/lowcode/forecast/model/arima.py +6 -3
- ads/opctl/operator/lowcode/forecast/model/automlx.py +53 -31
- ads/opctl/operator/lowcode/forecast/model/base_model.py +57 -30
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +60 -2
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +5 -2
- ads/opctl/operator/lowcode/forecast/model/prophet.py +28 -15
- ads/opctl/operator/lowcode/forecast/whatifserve/score.py +19 -11
- ads/opctl/operator/lowcode/pii/constant.py +6 -7
- ads/opctl/operator/lowcode/recommender/constant.py +12 -7
- ads/opctl/operator/runtime/marketplace_runtime.py +4 -10
- ads/opctl/operator/runtime/runtime.py +4 -6
- ads/pipeline/ads_pipeline_run.py +13 -25
- ads/pipeline/visualizer/graph_renderer.py +3 -4
- {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/METADATA +6 -6
- {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/RECORD +56 -52
- {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/WHEEL +0 -0
- {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/entry_points.txt +0 -0
@@ -1,14 +1,13 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c) 2023 Oracle and/or its affiliates.
|
3
|
+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
|
-
from ads.common.extended_enum import
|
6
|
+
from ads.common.extended_enum import ExtendedEnum
|
8
7
|
from ads.opctl.operator.lowcode.common.const import DataColumns
|
9
8
|
|
10
9
|
|
11
|
-
class SupportedModels(
|
10
|
+
class SupportedModels(ExtendedEnum):
|
12
11
|
"""Supported anomaly models."""
|
13
12
|
|
14
13
|
AutoTS = "autots"
|
@@ -38,7 +37,7 @@ class SupportedModels(str, metaclass=ExtendedEnumMeta):
|
|
38
37
|
BOCPD = "bocpd"
|
39
38
|
|
40
39
|
|
41
|
-
class NonTimeADSupportedModels(
|
40
|
+
class NonTimeADSupportedModels(ExtendedEnum):
|
42
41
|
"""Supported non time-based anomaly detection models."""
|
43
42
|
|
44
43
|
OneClassSVM = "oneclasssvm"
|
@@ -48,7 +47,7 @@ class NonTimeADSupportedModels(str, metaclass=ExtendedEnumMeta):
|
|
48
47
|
# DBScan = "dbscan"
|
49
48
|
|
50
49
|
|
51
|
-
class TODSSubModels(
|
50
|
+
class TODSSubModels(ExtendedEnum):
|
52
51
|
"""Supported TODS sub models."""
|
53
52
|
|
54
53
|
OCSVM = "ocsvm"
|
@@ -78,7 +77,7 @@ TODS_MODEL_MAP = {
|
|
78
77
|
}
|
79
78
|
|
80
79
|
|
81
|
-
class MerlionADModels(
|
80
|
+
class MerlionADModels(ExtendedEnum):
|
82
81
|
"""Supported Merlion AD sub models."""
|
83
82
|
|
84
83
|
# point anomaly
|
@@ -126,7 +125,7 @@ MERLIONAD_MODEL_MAP = {
|
|
126
125
|
}
|
127
126
|
|
128
127
|
|
129
|
-
class SupportedMetrics(
|
128
|
+
class SupportedMetrics(ExtendedEnum):
|
130
129
|
UNSUPERVISED_UNIFY95 = "unsupervised_unify95"
|
131
130
|
UNSUPERVISED_UNIFY95_LOG_LOSS = "unsupervised_unify95_log_loss"
|
132
131
|
UNSUPERVISED_N1_EXPERTS = "unsupervised_n-1_experts"
|
@@ -158,7 +157,7 @@ class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
|
|
158
157
|
ELAPSED_TIME = "Elapsed Time"
|
159
158
|
|
160
159
|
|
161
|
-
class OutputColumns(
|
160
|
+
class OutputColumns(ExtendedEnum):
|
162
161
|
ANOMALY_COL = "anomaly"
|
163
162
|
SCORE_COL = "score"
|
164
163
|
Series = DataColumns.Series
|
@@ -1,71 +1,66 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
3
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
6
|
import ast
|
8
7
|
import base64
|
9
|
-
from typing import
|
8
|
+
from typing import TYPE_CHECKING, Dict, List, Optional
|
10
9
|
|
11
10
|
import oci
|
12
11
|
import requests
|
13
|
-
from typing import TYPE_CHECKING
|
14
12
|
|
15
13
|
try:
|
16
14
|
from kubernetes.client import (
|
17
|
-
V1ServiceStatus,
|
18
|
-
V1Service,
|
19
|
-
V1LoadBalancerStatus,
|
20
15
|
V1LoadBalancerIngress,
|
16
|
+
V1LoadBalancerStatus,
|
17
|
+
V1Service,
|
18
|
+
V1ServiceStatus,
|
21
19
|
)
|
22
20
|
except ImportError:
|
23
21
|
if TYPE_CHECKING:
|
24
22
|
from kubernetes.client import (
|
25
|
-
V1ServiceStatus,
|
26
|
-
V1Service,
|
27
|
-
V1LoadBalancerStatus,
|
28
23
|
V1LoadBalancerIngress,
|
24
|
+
V1LoadBalancerStatus,
|
25
|
+
V1Service,
|
26
|
+
V1ServiceStatus,
|
29
27
|
)
|
30
28
|
|
31
|
-
|
32
|
-
|
33
|
-
from ads.opctl.operator.lowcode.feature_store_marketplace.models.apigw_config import (
|
34
|
-
APIGatewayConfig,
|
35
|
-
)
|
29
|
+
import click
|
30
|
+
from oci.resource_manager.models import AssociatedResourceSummary, StackSummary
|
36
31
|
|
37
|
-
from ads.common.oci_client import OCIClientFactory
|
38
|
-
from ads.opctl.operator.lowcode.feature_store_marketplace.const import (
|
39
|
-
LISTING_ID,
|
40
|
-
APIGW_STACK_NAME,
|
41
|
-
STACK_URL,
|
42
|
-
NLB_RULES_ADDRESS,
|
43
|
-
NODES_RULES_ADDRESS,
|
44
|
-
)
|
45
32
|
from ads import logger
|
46
|
-
import
|
33
|
+
from ads.common import auth as authutil
|
34
|
+
from ads.common.oci_client import OCIClientFactory
|
47
35
|
from ads.opctl import logger
|
48
|
-
|
49
36
|
from ads.opctl.backend.marketplace.marketplace_utils import (
|
50
37
|
Color,
|
51
38
|
print_heading,
|
52
39
|
print_ticker,
|
53
40
|
)
|
54
|
-
from ads.opctl.operator.lowcode.feature_store_marketplace.
|
55
|
-
|
41
|
+
from ads.opctl.operator.lowcode.feature_store_marketplace.const import (
|
42
|
+
APIGW_STACK_NAME,
|
43
|
+
LISTING_ID,
|
44
|
+
NLB_RULES_ADDRESS,
|
45
|
+
NODES_RULES_ADDRESS,
|
46
|
+
STACK_URL,
|
47
|
+
)
|
48
|
+
from ads.opctl.operator.lowcode.feature_store_marketplace.models.apigw_config import (
|
49
|
+
APIGatewayConfig,
|
56
50
|
)
|
57
|
-
|
58
51
|
from ads.opctl.operator.lowcode.feature_store_marketplace.models.db_config import (
|
59
52
|
DBConfig,
|
60
53
|
)
|
61
|
-
from ads.
|
54
|
+
from ads.opctl.operator.lowcode.feature_store_marketplace.models.mysql_config import (
|
55
|
+
MySqlConfig,
|
56
|
+
)
|
62
57
|
|
63
58
|
|
64
59
|
def get_db_details() -> DBConfig:
|
65
60
|
jdbc_url = "jdbc:mysql://{}/{}?createDatabaseIfNotExist=true"
|
66
61
|
mysql_db_config = MySqlConfig()
|
67
62
|
print_heading(
|
68
|
-
|
63
|
+
"MySQL database configuration",
|
69
64
|
colors=[Color.BOLD, Color.BLUE],
|
70
65
|
prefix_newline_count=2,
|
71
66
|
)
|
@@ -76,12 +71,12 @@ def get_db_details() -> DBConfig:
|
|
76
71
|
"Is password provided as plain-text or via a Vault secret?\n"
|
77
72
|
"(https://docs.oracle.com/en-us/iaas/Content/KeyManagement/Concepts/keyoverview.htm)",
|
78
73
|
type=click.Choice(MySqlConfig.MySQLAuthType.values()),
|
79
|
-
default=MySqlConfig.MySQLAuthType.BASIC
|
74
|
+
default=MySqlConfig.MySQLAuthType.BASIC,
|
80
75
|
)
|
81
76
|
)
|
82
77
|
if mysql_db_config.auth_type == MySqlConfig.MySQLAuthType.BASIC:
|
83
78
|
basic_auth_config = MySqlConfig.BasicConfig()
|
84
|
-
basic_auth_config.password = click.prompt(
|
79
|
+
basic_auth_config.password = click.prompt("Password", hide_input=True)
|
85
80
|
mysql_db_config.basic_config = basic_auth_config
|
86
81
|
|
87
82
|
elif mysql_db_config.auth_type == MySqlConfig.MySQLAuthType.VAULT:
|
@@ -176,12 +171,12 @@ def detect_or_create_stack(apigw_config: APIGatewayConfig):
|
|
176
171
|
).data
|
177
172
|
|
178
173
|
if len(stacks) >= 1:
|
179
|
-
print(
|
174
|
+
print("Auto-detected feature store stack(s) in tenancy:")
|
180
175
|
for stack in stacks:
|
181
176
|
_print_stack_detail(stack)
|
182
177
|
choices = {"1": "new", "2": "existing"}
|
183
178
|
stack_provision_method = click.prompt(
|
184
|
-
|
179
|
+
"Select stack provisioning method:\n1.Create new stack\n2.Existing stack\n",
|
185
180
|
type=click.Choice(list(choices.keys())),
|
186
181
|
show_choices=False,
|
187
182
|
)
|
@@ -240,20 +235,20 @@ def get_api_gw_details(compartment_id: str) -> APIGatewayConfig:
|
|
240
235
|
|
241
236
|
|
242
237
|
def get_nlb_id_from_service(service: "V1Service", apigw_config: APIGatewayConfig):
|
243
|
-
status:
|
244
|
-
lb_status:
|
245
|
-
lb_ingress:
|
238
|
+
status: V1ServiceStatus = service.status
|
239
|
+
lb_status: V1LoadBalancerStatus = status.load_balancer
|
240
|
+
lb_ingress: V1LoadBalancerIngress = lb_status.ingress[0]
|
246
241
|
resource_client = OCIClientFactory(**authutil.default_signer()).create_client(
|
247
242
|
oci.resource_search.ResourceSearchClient
|
248
243
|
)
|
249
244
|
search_details = oci.resource_search.models.FreeTextSearchDetails()
|
250
245
|
search_details.matching_context_type = "NONE"
|
251
246
|
search_details.text = lb_ingress.ip
|
252
|
-
resources: List[
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
)
|
247
|
+
resources: List[oci.resource_search.models.ResourceSummary] = (
|
248
|
+
resource_client.search_resources(
|
249
|
+
search_details, tenant_id=apigw_config.root_compartment_id
|
250
|
+
).data.items
|
251
|
+
)
|
257
252
|
private_ips = list(filter(lambda obj: obj.resource_type == "PrivateIp", resources))
|
258
253
|
if len(private_ips) != 1:
|
259
254
|
return click.prompt(
|
@@ -264,12 +259,12 @@ def get_nlb_id_from_service(service: "V1Service", apigw_config: APIGatewayConfig
|
|
264
259
|
nlb_client = OCIClientFactory(**authutil.default_signer()).create_client(
|
265
260
|
oci.network_load_balancer.NetworkLoadBalancerClient
|
266
261
|
)
|
267
|
-
nlbs: List[
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
)
|
262
|
+
nlbs: List[oci.network_load_balancer.models.NetworkLoadBalancerSummary] = (
|
263
|
+
nlb_client.list_network_load_balancers(
|
264
|
+
compartment_id=nlb_private_ip.compartment_id,
|
265
|
+
display_name=nlb_private_ip.display_name,
|
266
|
+
).data.items
|
267
|
+
)
|
273
268
|
if len(nlbs) != 1:
|
274
269
|
return click.prompt(
|
275
270
|
f"Please enter OCID of load balancer associated with ip: {lb_ingress.ip}"
|
@@ -1,7 +1,6 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c) 2023,
|
3
|
+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
6
|
import json
|
@@ -15,17 +14,17 @@ from ads.opctl import logger
|
|
15
14
|
from ads.opctl.operator.common.const import ENV_OPERATOR_ARGS
|
16
15
|
from ads.opctl.operator.common.utils import _parse_input_args
|
17
16
|
|
17
|
+
from .model.forecast_datasets import ForecastDatasets, ForecastResults
|
18
18
|
from .operator_config import ForecastOperatorConfig
|
19
|
-
from .model.forecast_datasets import ForecastDatasets
|
20
19
|
from .whatifserve import ModelDeploymentManager
|
21
20
|
|
22
21
|
|
23
|
-
def operate(operator_config: ForecastOperatorConfig) ->
|
22
|
+
def operate(operator_config: ForecastOperatorConfig) -> ForecastResults:
|
24
23
|
"""Runs the forecasting operator."""
|
25
24
|
from .model.factory import ForecastOperatorModelFactory
|
26
25
|
|
27
26
|
datasets = ForecastDatasets(operator_config)
|
28
|
-
ForecastOperatorModelFactory.get_model(
|
27
|
+
results = ForecastOperatorModelFactory.get_model(
|
29
28
|
operator_config, datasets
|
30
29
|
).generate_report()
|
31
30
|
# saving to model catalog
|
@@ -36,6 +35,7 @@ def operate(operator_config: ForecastOperatorConfig) -> None:
|
|
36
35
|
if spec.what_if_analysis.model_deployment:
|
37
36
|
mdm.create_deployment()
|
38
37
|
mdm.save_deployment_info()
|
38
|
+
return results
|
39
39
|
|
40
40
|
|
41
41
|
def verify(spec: Dict, **kwargs: Dict) -> bool:
|
@@ -1,13 +1,13 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
2
|
|
3
|
-
# Copyright (c) 2023,
|
3
|
+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
|
4
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
5
|
|
6
|
-
from ads.common.extended_enum import
|
6
|
+
from ads.common.extended_enum import ExtendedEnum
|
7
7
|
from ads.opctl.operator.lowcode.common.const import DataColumns
|
8
8
|
|
9
9
|
|
10
|
-
class SupportedModels(
|
10
|
+
class SupportedModels(ExtendedEnum):
|
11
11
|
"""Supported forecast models."""
|
12
12
|
|
13
13
|
Prophet = "prophet"
|
@@ -19,7 +19,7 @@ class SupportedModels(str, metaclass=ExtendedEnumMeta):
|
|
19
19
|
# Auto = "auto"
|
20
20
|
|
21
21
|
|
22
|
-
class SpeedAccuracyMode(
|
22
|
+
class SpeedAccuracyMode(ExtendedEnum):
|
23
23
|
"""
|
24
24
|
Enum representing different modes based on time taken and accuracy for explainability.
|
25
25
|
"""
|
@@ -35,7 +35,7 @@ class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
|
|
35
35
|
ratio[AUTOMLX] = 0 # constant
|
36
36
|
|
37
37
|
|
38
|
-
class SupportedMetrics(
|
38
|
+
class SupportedMetrics(ExtendedEnum):
|
39
39
|
"""Supported forecast metrics."""
|
40
40
|
|
41
41
|
MAPE = "MAPE"
|
@@ -62,7 +62,7 @@ class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
|
|
62
62
|
ELAPSED_TIME = "Elapsed Time"
|
63
63
|
|
64
64
|
|
65
|
-
class ForecastOutputColumns(
|
65
|
+
class ForecastOutputColumns(ExtendedEnum):
|
66
66
|
"""The column names for the forecast.csv output file"""
|
67
67
|
|
68
68
|
DATE = "Date"
|
@@ -116,7 +116,10 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
|
|
116
116
|
lower_bound=self.get_horizon(forecast["yhat_lower"]).values,
|
117
117
|
)
|
118
118
|
|
119
|
-
self.models[s_id] =
|
119
|
+
self.models[s_id] = {}
|
120
|
+
self.models[s_id]["model"] = model
|
121
|
+
self.models[s_id]["le"] = self.le[s_id]
|
122
|
+
self.models[s_id]["predict_component_cols"] = X_pred.columns
|
120
123
|
|
121
124
|
params = vars(model).copy()
|
122
125
|
for param in ["arima_res_", "endog_index_"]:
|
@@ -163,7 +166,7 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
|
|
163
166
|
sec5_text = rc.Heading("ARIMA Model Parameters", level=2)
|
164
167
|
blocks = [
|
165
168
|
rc.Html(
|
166
|
-
m.summary().as_html(),
|
169
|
+
m['model'].summary().as_html(),
|
167
170
|
label=s_id if self.target_cat_col else None,
|
168
171
|
)
|
169
172
|
for i, (s_id, m) in enumerate(self.models.items())
|
@@ -251,7 +254,7 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
|
|
251
254
|
def get_explain_predict_fn(self, series_id):
|
252
255
|
def _custom_predict(
|
253
256
|
data,
|
254
|
-
model=self.models[series_id],
|
257
|
+
model=self.models[series_id]["model"],
|
255
258
|
dt_column_name=self.datasets._datetime_column_name,
|
256
259
|
target_col=self.original_target_column,
|
257
260
|
):
|
@@ -1,5 +1,5 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# Copyright (c) 2023,
|
2
|
+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
import logging
|
5
5
|
import os
|
@@ -56,8 +56,8 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
56
56
|
)
|
57
57
|
return model_kwargs_cleaned, time_budget
|
58
58
|
|
59
|
-
def preprocess(self, data): # TODO: re-use self.le for explanations
|
60
|
-
|
59
|
+
def preprocess(self, data, series_id): # TODO: re-use self.le for explanations
|
60
|
+
self.le[series_id], df_encoded = _label_encode_dataframe(
|
61
61
|
data,
|
62
62
|
no_encode={self.spec.datetime_column.name, self.original_target_column},
|
63
63
|
)
|
@@ -66,8 +66,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
66
66
|
@runtime_dependency(
|
67
67
|
module="automlx",
|
68
68
|
err_msg=(
|
69
|
-
"Please run `pip3 install oracle-automlx>=
|
70
|
-
"`pip3 install oracle-automlx[forecasting]>=23.4.1` "
|
69
|
+
"Please run `pip3 install oracle-automlx[forecasting]>=25.1.1` "
|
71
70
|
"to install the required dependencies for automlx."
|
72
71
|
),
|
73
72
|
)
|
@@ -105,7 +104,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
105
104
|
engine_opts = (
|
106
105
|
None
|
107
106
|
if engine_type == "local"
|
108
|
-
else
|
107
|
+
else {"ray_setup": {"_temp_dir": "/tmp/ray-temp"}}
|
109
108
|
)
|
110
109
|
init(
|
111
110
|
engine=engine_type,
|
@@ -125,7 +124,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
125
124
|
self.forecast_output.init_series_output(
|
126
125
|
series_id=s_id, data_at_series=df
|
127
126
|
)
|
128
|
-
data = self.preprocess(df)
|
127
|
+
data = self.preprocess(df, s_id)
|
129
128
|
data_i = self.drop_horizon(data)
|
130
129
|
X_pred = self.get_horizon(data).drop(target, axis=1)
|
131
130
|
|
@@ -157,7 +156,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
157
156
|
target
|
158
157
|
].values
|
159
158
|
|
160
|
-
self.models[s_id] =
|
159
|
+
self.models[s_id] = {}
|
160
|
+
self.models[s_id]["model"] = model
|
161
|
+
self.models[s_id]["le"] = self.le[s_id]
|
161
162
|
|
162
163
|
# In case of Naive model, model.forecast function call does not return confidence intervals.
|
163
164
|
if f"{target}_ci_upper" not in summary_frame:
|
@@ -218,7 +219,8 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
218
219
|
other_sections = []
|
219
220
|
|
220
221
|
if len(self.models) > 0:
|
221
|
-
for s_id,
|
222
|
+
for s_id, artifacts in models.items():
|
223
|
+
m = artifacts["model"]
|
222
224
|
selected_models[s_id] = {
|
223
225
|
"series_id": s_id,
|
224
226
|
"selected_model": m.selected_model_,
|
@@ -247,17 +249,18 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
247
249
|
self.explain_model()
|
248
250
|
|
249
251
|
global_explanation_section = None
|
250
|
-
if self.spec.explanations_accuracy_mode != SpeedAccuracyMode.AUTOMLX:
|
251
|
-
# Convert the global explanation data to a DataFrame
|
252
|
-
global_explanation_df = pd.DataFrame(self.global_explanation)
|
253
252
|
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
253
|
+
# Convert the global explanation data to a DataFrame
|
254
|
+
global_explanation_df = pd.DataFrame(self.global_explanation)
|
255
|
+
|
256
|
+
self.formatted_global_explanation = (
|
257
|
+
global_explanation_df / global_explanation_df.sum(axis=0) * 100
|
258
|
+
)
|
259
|
+
|
260
|
+
self.formatted_global_explanation.rename(
|
261
|
+
columns={self.spec.datetime_column.name: ForecastOutputColumns.DATE},
|
262
|
+
inplace=True,
|
263
|
+
)
|
261
264
|
|
262
265
|
aggregate_local_explanations = pd.DataFrame()
|
263
266
|
for s_id, local_ex_df in self.local_explanation.items():
|
@@ -269,11 +272,15 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
269
272
|
self.formatted_local_explanation = aggregate_local_explanations
|
270
273
|
|
271
274
|
if not self.target_cat_col:
|
272
|
-
self.formatted_global_explanation =
|
273
|
-
|
274
|
-
|
275
|
+
self.formatted_global_explanation = (
|
276
|
+
self.formatted_global_explanation.rename(
|
277
|
+
{"Series 1": self.original_target_column},
|
278
|
+
axis=1,
|
279
|
+
)
|
280
|
+
)
|
281
|
+
self.formatted_local_explanation.drop(
|
282
|
+
"Series", axis=1, inplace=True
|
275
283
|
)
|
276
|
-
self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
|
277
284
|
|
278
285
|
# Create a markdown section for the global explainability
|
279
286
|
global_explanation_section = rc.Block(
|
@@ -320,7 +327,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
320
327
|
)
|
321
328
|
|
322
329
|
def get_explain_predict_fn(self, series_id):
|
323
|
-
selected_model = self.models[series_id]
|
330
|
+
selected_model = self.models[series_id]["model"]
|
324
331
|
|
325
332
|
# If training date, use method below. If future date, use forecast!
|
326
333
|
def _custom_predict_fn(
|
@@ -338,12 +345,12 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
338
345
|
data[dt_column_name] = seconds_to_datetime(
|
339
346
|
data[dt_column_name], dt_format=self.spec.datetime_column.format
|
340
347
|
)
|
341
|
-
data = self.preprocess(data)
|
348
|
+
data = self.preprocess(data, series_id)
|
342
349
|
horizon_data = horizon_data.drop(target_col, axis=1)
|
343
350
|
horizon_data[dt_column_name] = seconds_to_datetime(
|
344
351
|
horizon_data[dt_column_name], dt_format=self.spec.datetime_column.format
|
345
352
|
)
|
346
|
-
horizon_data = self.preprocess(horizon_data)
|
353
|
+
horizon_data = self.preprocess(horizon_data, series_id)
|
347
354
|
|
348
355
|
rows = []
|
349
356
|
for i in range(data.shape[0]):
|
@@ -421,8 +428,10 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
421
428
|
if self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
|
422
429
|
# Use the MLExplainer class from AutoMLx to generate explanations
|
423
430
|
explainer = automlx.MLExplainer(
|
424
|
-
self.models[s_id],
|
425
|
-
self.datasets.additional_data.get_data_for_series(
|
431
|
+
self.models[s_id]["model"],
|
432
|
+
self.datasets.additional_data.get_data_for_series(
|
433
|
+
series_id=s_id
|
434
|
+
)
|
426
435
|
.drop(self.spec.datetime_column.name, axis=1)
|
427
436
|
.head(-self.spec.horizon)
|
428
437
|
if self.spec.additional_data
|
@@ -433,7 +442,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
433
442
|
|
434
443
|
# Generate explanations for the forecast
|
435
444
|
explanations = explainer.explain_prediction(
|
436
|
-
X=self.datasets.additional_data.get_data_for_series(
|
445
|
+
X=self.datasets.additional_data.get_data_for_series(
|
446
|
+
series_id=s_id
|
447
|
+
)
|
437
448
|
.drop(self.spec.datetime_column.name, axis=1)
|
438
449
|
.tail(self.spec.horizon)
|
439
450
|
if self.spec.additional_data
|
@@ -445,7 +456,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
445
456
|
explanations_df = pd.concat(
|
446
457
|
[exp.to_dataframe() for exp in explanations]
|
447
458
|
)
|
448
|
-
explanations_df["row"] = explanations_df.groupby(
|
459
|
+
explanations_df["row"] = explanations_df.groupby(
|
460
|
+
"Feature"
|
461
|
+
).cumcount()
|
449
462
|
explanations_df = explanations_df.pivot(
|
450
463
|
index="row", columns="Feature", values="Attribution"
|
451
464
|
)
|
@@ -453,9 +466,18 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
453
466
|
|
454
467
|
# Store the explanations in the local_explanation dictionary
|
455
468
|
self.local_explanation[s_id] = explanations_df
|
469
|
+
|
470
|
+
self.global_explanation[s_id] = dict(
|
471
|
+
zip(
|
472
|
+
self.local_explanation[s_id].columns,
|
473
|
+
np.nanmean((self.local_explanation[s_id]), axis=0),
|
474
|
+
)
|
475
|
+
)
|
456
476
|
else:
|
457
477
|
# Fall back to the default explanation generation method
|
458
478
|
super().explain_model()
|
459
479
|
except Exception as e:
|
460
|
-
logger.warning(
|
480
|
+
logger.warning(
|
481
|
+
f"Failed to generate explanations for series {s_id} with error: {e}."
|
482
|
+
)
|
461
483
|
logger.debug(f"Full Traceback: {traceback.format_exc()}")
|