oracle-ads 2.12.11__py3-none-any.whl → 2.13.1rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. ads/aqua/app.py +23 -10
  2. ads/aqua/common/enums.py +19 -14
  3. ads/aqua/common/errors.py +3 -4
  4. ads/aqua/common/utils.py +2 -2
  5. ads/aqua/constants.py +1 -0
  6. ads/aqua/evaluation/constants.py +7 -7
  7. ads/aqua/evaluation/errors.py +3 -4
  8. ads/aqua/extension/model_handler.py +23 -0
  9. ads/aqua/extension/models/ws_models.py +5 -6
  10. ads/aqua/finetuning/constants.py +3 -3
  11. ads/aqua/model/constants.py +7 -7
  12. ads/aqua/model/enums.py +4 -5
  13. ads/aqua/model/model.py +22 -0
  14. ads/aqua/modeldeployment/entities.py +3 -1
  15. ads/common/auth.py +33 -20
  16. ads/common/extended_enum.py +52 -44
  17. ads/llm/__init__.py +11 -8
  18. ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
  19. ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
  20. ads/model/artifact_downloader.py +3 -4
  21. ads/model/datascience_model.py +84 -64
  22. ads/model/generic_model.py +3 -3
  23. ads/model/model_metadata.py +17 -11
  24. ads/model/service/oci_datascience_model.py +12 -14
  25. ads/opctl/anomaly_detection.py +11 -0
  26. ads/opctl/backend/marketplace/helm_helper.py +13 -14
  27. ads/opctl/cli.py +4 -5
  28. ads/opctl/cmds.py +28 -32
  29. ads/opctl/config/merger.py +8 -11
  30. ads/opctl/config/resolver.py +25 -30
  31. ads/opctl/forecast.py +11 -0
  32. ads/opctl/operator/cli.py +9 -9
  33. ads/opctl/operator/common/backend_factory.py +56 -60
  34. ads/opctl/operator/common/const.py +5 -5
  35. ads/opctl/operator/lowcode/anomaly/const.py +8 -9
  36. ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +43 -48
  37. ads/opctl/operator/lowcode/forecast/__main__.py +5 -5
  38. ads/opctl/operator/lowcode/forecast/const.py +6 -6
  39. ads/opctl/operator/lowcode/forecast/model/arima.py +6 -3
  40. ads/opctl/operator/lowcode/forecast/model/automlx.py +53 -31
  41. ads/opctl/operator/lowcode/forecast/model/base_model.py +57 -30
  42. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +60 -2
  43. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +5 -2
  44. ads/opctl/operator/lowcode/forecast/model/prophet.py +28 -15
  45. ads/opctl/operator/lowcode/forecast/whatifserve/score.py +19 -11
  46. ads/opctl/operator/lowcode/pii/constant.py +6 -7
  47. ads/opctl/operator/lowcode/recommender/constant.py +12 -7
  48. ads/opctl/operator/runtime/marketplace_runtime.py +4 -10
  49. ads/opctl/operator/runtime/runtime.py +4 -6
  50. ads/pipeline/ads_pipeline_run.py +13 -25
  51. ads/pipeline/visualizer/graph_renderer.py +3 -4
  52. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/METADATA +6 -6
  53. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/RECORD +56 -52
  54. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/LICENSE.txt +0 -0
  55. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/WHEEL +0 -0
  56. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/entry_points.txt +0 -0
@@ -1,14 +1,13 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
- from ads.common.extended_enum import ExtendedEnumMeta
6
+ from ads.common.extended_enum import ExtendedEnum
8
7
  from ads.opctl.operator.lowcode.common.const import DataColumns
9
8
 
10
9
 
11
- class SupportedModels(str, metaclass=ExtendedEnumMeta):
10
+ class SupportedModels(ExtendedEnum):
12
11
  """Supported anomaly models."""
13
12
 
14
13
  AutoTS = "autots"
@@ -38,7 +37,7 @@ class SupportedModels(str, metaclass=ExtendedEnumMeta):
38
37
  BOCPD = "bocpd"
39
38
 
40
39
 
41
- class NonTimeADSupportedModels(str, metaclass=ExtendedEnumMeta):
40
+ class NonTimeADSupportedModels(ExtendedEnum):
42
41
  """Supported non time-based anomaly detection models."""
43
42
 
44
43
  OneClassSVM = "oneclasssvm"
@@ -48,7 +47,7 @@ class NonTimeADSupportedModels(str, metaclass=ExtendedEnumMeta):
48
47
  # DBScan = "dbscan"
49
48
 
50
49
 
51
- class TODSSubModels(str, metaclass=ExtendedEnumMeta):
50
+ class TODSSubModels(ExtendedEnum):
52
51
  """Supported TODS sub models."""
53
52
 
54
53
  OCSVM = "ocsvm"
@@ -78,7 +77,7 @@ TODS_MODEL_MAP = {
78
77
  }
79
78
 
80
79
 
81
- class MerlionADModels(str, metaclass=ExtendedEnumMeta):
80
+ class MerlionADModels(ExtendedEnum):
82
81
  """Supported Merlion AD sub models."""
83
82
 
84
83
  # point anomaly
@@ -126,7 +125,7 @@ MERLIONAD_MODEL_MAP = {
126
125
  }
127
126
 
128
127
 
129
- class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
128
+ class SupportedMetrics(ExtendedEnum):
130
129
  UNSUPERVISED_UNIFY95 = "unsupervised_unify95"
131
130
  UNSUPERVISED_UNIFY95_LOG_LOSS = "unsupervised_unify95_log_loss"
132
131
  UNSUPERVISED_N1_EXPERTS = "unsupervised_n-1_experts"
@@ -158,7 +157,7 @@ class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
158
157
  ELAPSED_TIME = "Elapsed Time"
159
158
 
160
159
 
161
- class OutputColumns(str, metaclass=ExtendedEnumMeta):
160
+ class OutputColumns(ExtendedEnum):
162
161
  ANOMALY_COL = "anomaly"
163
162
  SCORE_COL = "score"
164
163
  Series = DataColumns.Series
@@ -1,71 +1,66 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
  import ast
8
7
  import base64
9
- from typing import Optional, List, Dict
8
+ from typing import TYPE_CHECKING, Dict, List, Optional
10
9
 
11
10
  import oci
12
11
  import requests
13
- from typing import TYPE_CHECKING
14
12
 
15
13
  try:
16
14
  from kubernetes.client import (
17
- V1ServiceStatus,
18
- V1Service,
19
- V1LoadBalancerStatus,
20
15
  V1LoadBalancerIngress,
16
+ V1LoadBalancerStatus,
17
+ V1Service,
18
+ V1ServiceStatus,
21
19
  )
22
20
  except ImportError:
23
21
  if TYPE_CHECKING:
24
22
  from kubernetes.client import (
25
- V1ServiceStatus,
26
- V1Service,
27
- V1LoadBalancerStatus,
28
23
  V1LoadBalancerIngress,
24
+ V1LoadBalancerStatus,
25
+ V1Service,
26
+ V1ServiceStatus,
29
27
  )
30
28
 
31
- from oci.resource_manager.models import StackSummary, AssociatedResourceSummary
32
-
33
- from ads.opctl.operator.lowcode.feature_store_marketplace.models.apigw_config import (
34
- APIGatewayConfig,
35
- )
29
+ import click
30
+ from oci.resource_manager.models import AssociatedResourceSummary, StackSummary
36
31
 
37
- from ads.common.oci_client import OCIClientFactory
38
- from ads.opctl.operator.lowcode.feature_store_marketplace.const import (
39
- LISTING_ID,
40
- APIGW_STACK_NAME,
41
- STACK_URL,
42
- NLB_RULES_ADDRESS,
43
- NODES_RULES_ADDRESS,
44
- )
45
32
  from ads import logger
46
- import click
33
+ from ads.common import auth as authutil
34
+ from ads.common.oci_client import OCIClientFactory
47
35
  from ads.opctl import logger
48
-
49
36
  from ads.opctl.backend.marketplace.marketplace_utils import (
50
37
  Color,
51
38
  print_heading,
52
39
  print_ticker,
53
40
  )
54
- from ads.opctl.operator.lowcode.feature_store_marketplace.models.mysql_config import (
55
- MySqlConfig,
41
+ from ads.opctl.operator.lowcode.feature_store_marketplace.const import (
42
+ APIGW_STACK_NAME,
43
+ LISTING_ID,
44
+ NLB_RULES_ADDRESS,
45
+ NODES_RULES_ADDRESS,
46
+ STACK_URL,
47
+ )
48
+ from ads.opctl.operator.lowcode.feature_store_marketplace.models.apigw_config import (
49
+ APIGatewayConfig,
56
50
  )
57
-
58
51
  from ads.opctl.operator.lowcode.feature_store_marketplace.models.db_config import (
59
52
  DBConfig,
60
53
  )
61
- from ads.common import auth as authutil
54
+ from ads.opctl.operator.lowcode.feature_store_marketplace.models.mysql_config import (
55
+ MySqlConfig,
56
+ )
62
57
 
63
58
 
64
59
  def get_db_details() -> DBConfig:
65
60
  jdbc_url = "jdbc:mysql://{}/{}?createDatabaseIfNotExist=true"
66
61
  mysql_db_config = MySqlConfig()
67
62
  print_heading(
68
- f"MySQL database configuration",
63
+ "MySQL database configuration",
69
64
  colors=[Color.BOLD, Color.BLUE],
70
65
  prefix_newline_count=2,
71
66
  )
@@ -76,12 +71,12 @@ def get_db_details() -> DBConfig:
76
71
  "Is password provided as plain-text or via a Vault secret?\n"
77
72
  "(https://docs.oracle.com/en-us/iaas/Content/KeyManagement/Concepts/keyoverview.htm)",
78
73
  type=click.Choice(MySqlConfig.MySQLAuthType.values()),
79
- default=MySqlConfig.MySQLAuthType.BASIC.value,
74
+ default=MySqlConfig.MySQLAuthType.BASIC,
80
75
  )
81
76
  )
82
77
  if mysql_db_config.auth_type == MySqlConfig.MySQLAuthType.BASIC:
83
78
  basic_auth_config = MySqlConfig.BasicConfig()
84
- basic_auth_config.password = click.prompt(f"Password", hide_input=True)
79
+ basic_auth_config.password = click.prompt("Password", hide_input=True)
85
80
  mysql_db_config.basic_config = basic_auth_config
86
81
 
87
82
  elif mysql_db_config.auth_type == MySqlConfig.MySQLAuthType.VAULT:
@@ -176,12 +171,12 @@ def detect_or_create_stack(apigw_config: APIGatewayConfig):
176
171
  ).data
177
172
 
178
173
  if len(stacks) >= 1:
179
- print(f"Auto-detected feature store stack(s) in tenancy:")
174
+ print("Auto-detected feature store stack(s) in tenancy:")
180
175
  for stack in stacks:
181
176
  _print_stack_detail(stack)
182
177
  choices = {"1": "new", "2": "existing"}
183
178
  stack_provision_method = click.prompt(
184
- f"Select stack provisioning method:\n1.Create new stack\n2.Existing stack\n",
179
+ "Select stack provisioning method:\n1.Create new stack\n2.Existing stack\n",
185
180
  type=click.Choice(list(choices.keys())),
186
181
  show_choices=False,
187
182
  )
@@ -240,20 +235,20 @@ def get_api_gw_details(compartment_id: str) -> APIGatewayConfig:
240
235
 
241
236
 
242
237
  def get_nlb_id_from_service(service: "V1Service", apigw_config: APIGatewayConfig):
243
- status: "V1ServiceStatus" = service.status
244
- lb_status: "V1LoadBalancerStatus" = status.load_balancer
245
- lb_ingress: "V1LoadBalancerIngress" = lb_status.ingress[0]
238
+ status: V1ServiceStatus = service.status
239
+ lb_status: V1LoadBalancerStatus = status.load_balancer
240
+ lb_ingress: V1LoadBalancerIngress = lb_status.ingress[0]
246
241
  resource_client = OCIClientFactory(**authutil.default_signer()).create_client(
247
242
  oci.resource_search.ResourceSearchClient
248
243
  )
249
244
  search_details = oci.resource_search.models.FreeTextSearchDetails()
250
245
  search_details.matching_context_type = "NONE"
251
246
  search_details.text = lb_ingress.ip
252
- resources: List[
253
- oci.resource_search.models.ResourceSummary
254
- ] = resource_client.search_resources(
255
- search_details, tenant_id=apigw_config.root_compartment_id
256
- ).data.items
247
+ resources: List[oci.resource_search.models.ResourceSummary] = (
248
+ resource_client.search_resources(
249
+ search_details, tenant_id=apigw_config.root_compartment_id
250
+ ).data.items
251
+ )
257
252
  private_ips = list(filter(lambda obj: obj.resource_type == "PrivateIp", resources))
258
253
  if len(private_ips) != 1:
259
254
  return click.prompt(
@@ -264,12 +259,12 @@ def get_nlb_id_from_service(service: "V1Service", apigw_config: APIGatewayConfig
264
259
  nlb_client = OCIClientFactory(**authutil.default_signer()).create_client(
265
260
  oci.network_load_balancer.NetworkLoadBalancerClient
266
261
  )
267
- nlbs: List[
268
- oci.network_load_balancer.models.NetworkLoadBalancerSummary
269
- ] = nlb_client.list_network_load_balancers(
270
- compartment_id=nlb_private_ip.compartment_id,
271
- display_name=nlb_private_ip.display_name,
272
- ).data.items
262
+ nlbs: List[oci.network_load_balancer.models.NetworkLoadBalancerSummary] = (
263
+ nlb_client.list_network_load_balancers(
264
+ compartment_id=nlb_private_ip.compartment_id,
265
+ display_name=nlb_private_ip.display_name,
266
+ ).data.items
267
+ )
273
268
  if len(nlbs) != 1:
274
269
  return click.prompt(
275
270
  f"Please enter OCID of load balancer associated with ip: {lb_ingress.ip}"
@@ -1,7 +1,6 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
  import json
@@ -15,17 +14,17 @@ from ads.opctl import logger
15
14
  from ads.opctl.operator.common.const import ENV_OPERATOR_ARGS
16
15
  from ads.opctl.operator.common.utils import _parse_input_args
17
16
 
17
+ from .model.forecast_datasets import ForecastDatasets, ForecastResults
18
18
  from .operator_config import ForecastOperatorConfig
19
- from .model.forecast_datasets import ForecastDatasets
20
19
  from .whatifserve import ModelDeploymentManager
21
20
 
22
21
 
23
- def operate(operator_config: ForecastOperatorConfig) -> None:
22
+ def operate(operator_config: ForecastOperatorConfig) -> ForecastResults:
24
23
  """Runs the forecasting operator."""
25
24
  from .model.factory import ForecastOperatorModelFactory
26
25
 
27
26
  datasets = ForecastDatasets(operator_config)
28
- ForecastOperatorModelFactory.get_model(
27
+ results = ForecastOperatorModelFactory.get_model(
29
28
  operator_config, datasets
30
29
  ).generate_report()
31
30
  # saving to model catalog
@@ -36,6 +35,7 @@ def operate(operator_config: ForecastOperatorConfig) -> None:
36
35
  if spec.what_if_analysis.model_deployment:
37
36
  mdm.create_deployment()
38
37
  mdm.save_deployment_info()
38
+ return results
39
39
 
40
40
 
41
41
  def verify(spec: Dict, **kwargs: Dict) -> bool:
@@ -1,13 +1,13 @@
1
1
  #!/usr/bin/env python
2
2
 
3
- # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
4
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
 
6
- from ads.common.extended_enum import ExtendedEnumMeta
6
+ from ads.common.extended_enum import ExtendedEnum
7
7
  from ads.opctl.operator.lowcode.common.const import DataColumns
8
8
 
9
9
 
10
- class SupportedModels(str, metaclass=ExtendedEnumMeta):
10
+ class SupportedModels(ExtendedEnum):
11
11
  """Supported forecast models."""
12
12
 
13
13
  Prophet = "prophet"
@@ -19,7 +19,7 @@ class SupportedModels(str, metaclass=ExtendedEnumMeta):
19
19
  # Auto = "auto"
20
20
 
21
21
 
22
- class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
22
+ class SpeedAccuracyMode(ExtendedEnum):
23
23
  """
24
24
  Enum representing different modes based on time taken and accuracy for explainability.
25
25
  """
@@ -35,7 +35,7 @@ class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
35
35
  ratio[AUTOMLX] = 0 # constant
36
36
 
37
37
 
38
- class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
38
+ class SupportedMetrics(ExtendedEnum):
39
39
  """Supported forecast metrics."""
40
40
 
41
41
  MAPE = "MAPE"
@@ -62,7 +62,7 @@ class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
62
62
  ELAPSED_TIME = "Elapsed Time"
63
63
 
64
64
 
65
- class ForecastOutputColumns(str, metaclass=ExtendedEnumMeta):
65
+ class ForecastOutputColumns(ExtendedEnum):
66
66
  """The column names for the forecast.csv output file"""
67
67
 
68
68
  DATE = "Date"
@@ -116,7 +116,10 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
116
116
  lower_bound=self.get_horizon(forecast["yhat_lower"]).values,
117
117
  )
118
118
 
119
- self.models[s_id] = model
119
+ self.models[s_id] = {}
120
+ self.models[s_id]["model"] = model
121
+ self.models[s_id]["le"] = self.le[s_id]
122
+ self.models[s_id]["predict_component_cols"] = X_pred.columns
120
123
 
121
124
  params = vars(model).copy()
122
125
  for param in ["arima_res_", "endog_index_"]:
@@ -163,7 +166,7 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
163
166
  sec5_text = rc.Heading("ARIMA Model Parameters", level=2)
164
167
  blocks = [
165
168
  rc.Html(
166
- m.summary().as_html(),
169
+ m['model'].summary().as_html(),
167
170
  label=s_id if self.target_cat_col else None,
168
171
  )
169
172
  for i, (s_id, m) in enumerate(self.models.items())
@@ -251,7 +254,7 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
251
254
  def get_explain_predict_fn(self, series_id):
252
255
  def _custom_predict(
253
256
  data,
254
- model=self.models[series_id],
257
+ model=self.models[series_id]["model"],
255
258
  dt_column_name=self.datasets._datetime_column_name,
256
259
  target_col=self.original_target_column,
257
260
  ):
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
  import logging
5
5
  import os
@@ -56,8 +56,8 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
56
56
  )
57
57
  return model_kwargs_cleaned, time_budget
58
58
 
59
- def preprocess(self, data): # TODO: re-use self.le for explanations
60
- _, df_encoded = _label_encode_dataframe(
59
+ def preprocess(self, data, series_id): # TODO: re-use self.le for explanations
60
+ self.le[series_id], df_encoded = _label_encode_dataframe(
61
61
  data,
62
62
  no_encode={self.spec.datetime_column.name, self.original_target_column},
63
63
  )
@@ -66,8 +66,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
66
66
  @runtime_dependency(
67
67
  module="automlx",
68
68
  err_msg=(
69
- "Please run `pip3 install oracle-automlx>=23.4.1` and "
70
- "`pip3 install oracle-automlx[forecasting]>=23.4.1` "
69
+ "Please run `pip3 install oracle-automlx[forecasting]>=25.1.1` "
71
70
  "to install the required dependencies for automlx."
72
71
  ),
73
72
  )
@@ -105,7 +104,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
105
104
  engine_opts = (
106
105
  None
107
106
  if engine_type == "local"
108
- else ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
107
+ else {"ray_setup": {"_temp_dir": "/tmp/ray-temp"}}
109
108
  )
110
109
  init(
111
110
  engine=engine_type,
@@ -125,7 +124,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
125
124
  self.forecast_output.init_series_output(
126
125
  series_id=s_id, data_at_series=df
127
126
  )
128
- data = self.preprocess(df)
127
+ data = self.preprocess(df, s_id)
129
128
  data_i = self.drop_horizon(data)
130
129
  X_pred = self.get_horizon(data).drop(target, axis=1)
131
130
 
@@ -157,7 +156,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
157
156
  target
158
157
  ].values
159
158
 
160
- self.models[s_id] = model
159
+ self.models[s_id] = {}
160
+ self.models[s_id]["model"] = model
161
+ self.models[s_id]["le"] = self.le[s_id]
161
162
 
162
163
  # In case of Naive model, model.forecast function call does not return confidence intervals.
163
164
  if f"{target}_ci_upper" not in summary_frame:
@@ -218,7 +219,8 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
218
219
  other_sections = []
219
220
 
220
221
  if len(self.models) > 0:
221
- for s_id, m in models.items():
222
+ for s_id, artifacts in models.items():
223
+ m = artifacts["model"]
222
224
  selected_models[s_id] = {
223
225
  "series_id": s_id,
224
226
  "selected_model": m.selected_model_,
@@ -247,17 +249,18 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
247
249
  self.explain_model()
248
250
 
249
251
  global_explanation_section = None
250
- if self.spec.explanations_accuracy_mode != SpeedAccuracyMode.AUTOMLX:
251
- # Convert the global explanation data to a DataFrame
252
- global_explanation_df = pd.DataFrame(self.global_explanation)
253
252
 
254
- self.formatted_global_explanation = (
255
- global_explanation_df / global_explanation_df.sum(axis=0) * 100
256
- )
257
- self.formatted_global_explanation = self.formatted_global_explanation.rename(
258
- {self.spec.datetime_column.name: ForecastOutputColumns.DATE},
259
- axis=1,
260
- )
253
+ # Convert the global explanation data to a DataFrame
254
+ global_explanation_df = pd.DataFrame(self.global_explanation)
255
+
256
+ self.formatted_global_explanation = (
257
+ global_explanation_df / global_explanation_df.sum(axis=0) * 100
258
+ )
259
+
260
+ self.formatted_global_explanation.rename(
261
+ columns={self.spec.datetime_column.name: ForecastOutputColumns.DATE},
262
+ inplace=True,
263
+ )
261
264
 
262
265
  aggregate_local_explanations = pd.DataFrame()
263
266
  for s_id, local_ex_df in self.local_explanation.items():
@@ -269,11 +272,15 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
269
272
  self.formatted_local_explanation = aggregate_local_explanations
270
273
 
271
274
  if not self.target_cat_col:
272
- self.formatted_global_explanation = self.formatted_global_explanation.rename(
273
- {"Series 1": self.original_target_column},
274
- axis=1,
275
+ self.formatted_global_explanation = (
276
+ self.formatted_global_explanation.rename(
277
+ {"Series 1": self.original_target_column},
278
+ axis=1,
279
+ )
280
+ )
281
+ self.formatted_local_explanation.drop(
282
+ "Series", axis=1, inplace=True
275
283
  )
276
- self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
277
284
 
278
285
  # Create a markdown section for the global explainability
279
286
  global_explanation_section = rc.Block(
@@ -320,7 +327,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
320
327
  )
321
328
 
322
329
  def get_explain_predict_fn(self, series_id):
323
- selected_model = self.models[series_id]
330
+ selected_model = self.models[series_id]["model"]
324
331
 
325
332
  # If training date, use method below. If future date, use forecast!
326
333
  def _custom_predict_fn(
@@ -338,12 +345,12 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
338
345
  data[dt_column_name] = seconds_to_datetime(
339
346
  data[dt_column_name], dt_format=self.spec.datetime_column.format
340
347
  )
341
- data = self.preprocess(data)
348
+ data = self.preprocess(data, series_id)
342
349
  horizon_data = horizon_data.drop(target_col, axis=1)
343
350
  horizon_data[dt_column_name] = seconds_to_datetime(
344
351
  horizon_data[dt_column_name], dt_format=self.spec.datetime_column.format
345
352
  )
346
- horizon_data = self.preprocess(horizon_data)
353
+ horizon_data = self.preprocess(horizon_data, series_id)
347
354
 
348
355
  rows = []
349
356
  for i in range(data.shape[0]):
@@ -421,8 +428,10 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
421
428
  if self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
422
429
  # Use the MLExplainer class from AutoMLx to generate explanations
423
430
  explainer = automlx.MLExplainer(
424
- self.models[s_id],
425
- self.datasets.additional_data.get_data_for_series(series_id=s_id)
431
+ self.models[s_id]["model"],
432
+ self.datasets.additional_data.get_data_for_series(
433
+ series_id=s_id
434
+ )
426
435
  .drop(self.spec.datetime_column.name, axis=1)
427
436
  .head(-self.spec.horizon)
428
437
  if self.spec.additional_data
@@ -433,7 +442,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
433
442
 
434
443
  # Generate explanations for the forecast
435
444
  explanations = explainer.explain_prediction(
436
- X=self.datasets.additional_data.get_data_for_series(series_id=s_id)
445
+ X=self.datasets.additional_data.get_data_for_series(
446
+ series_id=s_id
447
+ )
437
448
  .drop(self.spec.datetime_column.name, axis=1)
438
449
  .tail(self.spec.horizon)
439
450
  if self.spec.additional_data
@@ -445,7 +456,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
445
456
  explanations_df = pd.concat(
446
457
  [exp.to_dataframe() for exp in explanations]
447
458
  )
448
- explanations_df["row"] = explanations_df.groupby("Feature").cumcount()
459
+ explanations_df["row"] = explanations_df.groupby(
460
+ "Feature"
461
+ ).cumcount()
449
462
  explanations_df = explanations_df.pivot(
450
463
  index="row", columns="Feature", values="Attribution"
451
464
  )
@@ -453,9 +466,18 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
453
466
 
454
467
  # Store the explanations in the local_explanation dictionary
455
468
  self.local_explanation[s_id] = explanations_df
469
+
470
+ self.global_explanation[s_id] = dict(
471
+ zip(
472
+ self.local_explanation[s_id].columns,
473
+ np.nanmean((self.local_explanation[s_id]), axis=0),
474
+ )
475
+ )
456
476
  else:
457
477
  # Fall back to the default explanation generation method
458
478
  super().explain_model()
459
479
  except Exception as e:
460
- logger.warning(f"Failed to generate explanations for series {s_id} with error: {e}.")
480
+ logger.warning(
481
+ f"Failed to generate explanations for series {s_id} with error: {e}."
482
+ )
461
483
  logger.debug(f"Full Traceback: {traceback.format_exc()}")