oracle-ads 2.12.11__py3-none-any.whl → 2.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. ads/aqua/__init__.py +7 -1
  2. ads/aqua/app.py +41 -27
  3. ads/aqua/client/client.py +48 -11
  4. ads/aqua/common/entities.py +28 -1
  5. ads/aqua/common/enums.py +32 -21
  6. ads/aqua/common/errors.py +3 -4
  7. ads/aqua/common/utils.py +10 -15
  8. ads/aqua/config/container_config.py +203 -0
  9. ads/aqua/config/evaluation/evaluation_service_config.py +5 -181
  10. ads/aqua/constants.py +1 -1
  11. ads/aqua/evaluation/constants.py +7 -7
  12. ads/aqua/evaluation/errors.py +3 -4
  13. ads/aqua/evaluation/evaluation.py +4 -4
  14. ads/aqua/extension/base_handler.py +4 -0
  15. ads/aqua/extension/model_handler.py +41 -27
  16. ads/aqua/extension/models/ws_models.py +5 -6
  17. ads/aqua/finetuning/constants.py +3 -3
  18. ads/aqua/finetuning/finetuning.py +2 -3
  19. ads/aqua/model/constants.py +7 -7
  20. ads/aqua/model/entities.py +2 -3
  21. ads/aqua/model/enums.py +4 -5
  22. ads/aqua/model/model.py +46 -29
  23. ads/aqua/modeldeployment/deployment.py +6 -14
  24. ads/aqua/modeldeployment/entities.py +5 -3
  25. ads/aqua/server/__init__.py +4 -0
  26. ads/aqua/server/__main__.py +24 -0
  27. ads/aqua/server/app.py +47 -0
  28. ads/aqua/server/aqua_spec.yml +1291 -0
  29. ads/aqua/ui.py +5 -199
  30. ads/common/auth.py +50 -28
  31. ads/common/extended_enum.py +52 -44
  32. ads/common/utils.py +91 -11
  33. ads/config.py +3 -0
  34. ads/llm/__init__.py +12 -8
  35. ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
  36. ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
  37. ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +32 -23
  38. ads/model/artifact_downloader.py +6 -4
  39. ads/model/common/utils.py +15 -3
  40. ads/model/datascience_model.py +422 -71
  41. ads/model/generic_model.py +3 -3
  42. ads/model/model_metadata.py +70 -24
  43. ads/model/model_version_set.py +5 -3
  44. ads/model/service/oci_datascience_model.py +487 -17
  45. ads/opctl/anomaly_detection.py +11 -0
  46. ads/opctl/backend/marketplace/helm_helper.py +13 -14
  47. ads/opctl/cli.py +4 -5
  48. ads/opctl/cmds.py +28 -32
  49. ads/opctl/config/merger.py +8 -11
  50. ads/opctl/config/resolver.py +25 -30
  51. ads/opctl/forecast.py +11 -0
  52. ads/opctl/operator/cli.py +9 -9
  53. ads/opctl/operator/common/backend_factory.py +56 -60
  54. ads/opctl/operator/common/const.py +5 -5
  55. ads/opctl/operator/common/utils.py +16 -0
  56. ads/opctl/operator/lowcode/anomaly/const.py +8 -9
  57. ads/opctl/operator/lowcode/common/data.py +5 -2
  58. ads/opctl/operator/lowcode/common/transformations.py +2 -12
  59. ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +43 -48
  60. ads/opctl/operator/lowcode/forecast/__main__.py +5 -5
  61. ads/opctl/operator/lowcode/forecast/const.py +6 -6
  62. ads/opctl/operator/lowcode/forecast/model/arima.py +6 -3
  63. ads/opctl/operator/lowcode/forecast/model/automlx.py +61 -31
  64. ads/opctl/operator/lowcode/forecast/model/base_model.py +66 -40
  65. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +79 -13
  66. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +5 -2
  67. ads/opctl/operator/lowcode/forecast/model/prophet.py +28 -15
  68. ads/opctl/operator/lowcode/forecast/model_evaluator.py +13 -15
  69. ads/opctl/operator/lowcode/forecast/schema.yaml +1 -1
  70. ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +7 -0
  71. ads/opctl/operator/lowcode/forecast/whatifserve/score.py +19 -11
  72. ads/opctl/operator/lowcode/pii/constant.py +6 -7
  73. ads/opctl/operator/lowcode/recommender/constant.py +12 -7
  74. ads/opctl/operator/runtime/marketplace_runtime.py +4 -10
  75. ads/opctl/operator/runtime/runtime.py +4 -6
  76. ads/pipeline/ads_pipeline_run.py +13 -25
  77. ads/pipeline/visualizer/graph_renderer.py +3 -4
  78. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info}/METADATA +18 -15
  79. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info}/RECORD +82 -74
  80. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info}/WHEEL +1 -1
  81. ads/aqua/config/evaluation/evaluation_service_model_config.py +0 -8
  82. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info}/entry_points.txt +0 -0
  83. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info/licenses}/LICENSE.txt +0 -0
@@ -9,6 +9,7 @@ import os
9
9
  import sys
10
10
  import time
11
11
  import traceback
12
+ import uuid
12
13
  from string import Template
13
14
  from typing import Any, Dict, List, Tuple
14
15
 
@@ -17,6 +18,7 @@ import yaml
17
18
  from cerberus import Validator
18
19
 
19
20
  from ads.opctl import logger, utils
21
+ from ads.common.oci_logging import OCILog
20
22
 
21
23
  CONTAINER_NETWORK = "CONTAINER_NETWORK"
22
24
 
@@ -190,3 +192,17 @@ def print_traceback():
190
192
  if logger.level == logging.DEBUG:
191
193
  ex_type, ex, tb = sys.exc_info()
192
194
  traceback.print_tb(tb)
195
+
196
+
197
+ def create_log_in_log_group(compartment_id, log_group_id, auth, log_name=None):
198
+ """
199
+ Creates a log within a given log group
200
+ """
201
+ if not log_name:
202
+ log_name = f"log-{int(time.time())}-{uuid.uuid4()}"
203
+ log = OCILog(display_name=log_name,
204
+ log_group_id=log_group_id,
205
+ compartment_id=compartment_id,
206
+ **auth).create()
207
+ logger.info(f"Created log with log OCID {log.id}")
208
+ return log.id
@@ -1,14 +1,13 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
- from ads.common.extended_enum import ExtendedEnumMeta
6
+ from ads.common.extended_enum import ExtendedEnum
8
7
  from ads.opctl.operator.lowcode.common.const import DataColumns
9
8
 
10
9
 
11
- class SupportedModels(str, metaclass=ExtendedEnumMeta):
10
+ class SupportedModels(ExtendedEnum):
12
11
  """Supported anomaly models."""
13
12
 
14
13
  AutoTS = "autots"
@@ -38,7 +37,7 @@ class SupportedModels(str, metaclass=ExtendedEnumMeta):
38
37
  BOCPD = "bocpd"
39
38
 
40
39
 
41
- class NonTimeADSupportedModels(str, metaclass=ExtendedEnumMeta):
40
+ class NonTimeADSupportedModels(ExtendedEnum):
42
41
  """Supported non time-based anomaly detection models."""
43
42
 
44
43
  OneClassSVM = "oneclasssvm"
@@ -48,7 +47,7 @@ class NonTimeADSupportedModels(str, metaclass=ExtendedEnumMeta):
48
47
  # DBScan = "dbscan"
49
48
 
50
49
 
51
- class TODSSubModels(str, metaclass=ExtendedEnumMeta):
50
+ class TODSSubModels(ExtendedEnum):
52
51
  """Supported TODS sub models."""
53
52
 
54
53
  OCSVM = "ocsvm"
@@ -78,7 +77,7 @@ TODS_MODEL_MAP = {
78
77
  }
79
78
 
80
79
 
81
- class MerlionADModels(str, metaclass=ExtendedEnumMeta):
80
+ class MerlionADModels(ExtendedEnum):
82
81
  """Supported Merlion AD sub models."""
83
82
 
84
83
  # point anomaly
@@ -126,7 +125,7 @@ MERLIONAD_MODEL_MAP = {
126
125
  }
127
126
 
128
127
 
129
- class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
128
+ class SupportedMetrics(ExtendedEnum):
130
129
  UNSUPERVISED_UNIFY95 = "unsupervised_unify95"
131
130
  UNSUPERVISED_UNIFY95_LOG_LOSS = "unsupervised_unify95_log_loss"
132
131
  UNSUPERVISED_N1_EXPERTS = "unsupervised_n-1_experts"
@@ -158,7 +157,7 @@ class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
158
157
  ELAPSED_TIME = "Elapsed Time"
159
158
 
160
159
 
161
- class OutputColumns(str, metaclass=ExtendedEnumMeta):
160
+ class OutputColumns(ExtendedEnum):
162
161
  ANOMALY_COL = "anomaly"
163
162
  SCORE_COL = "score"
164
163
  Series = DataColumns.Series
@@ -19,13 +19,16 @@ from .transformations import Transformations
19
19
 
20
20
 
21
21
  class AbstractData(ABC):
22
- def __init__(self, spec: dict, name="input_data"):
22
+ def __init__(self, spec, name="input_data", data=None):
23
23
  self.Transformations = Transformations
24
24
  self.data = None
25
25
  self._data_dict = dict()
26
26
  self.name = name
27
27
  self.spec = spec
28
- self.load_transform_ingest_data(spec)
28
+ if data is not None:
29
+ self.data = data
30
+ else:
31
+ self.load_transform_ingest_data(spec)
29
32
 
30
33
  def get_raw_data_by_cat(self, category):
31
34
  mapping = self._data_transformer.get_target_category_columns_map()
@@ -31,7 +31,6 @@ class Transformations(ABC):
31
31
  dataset_info : ForecastOperatorConfig
32
32
  """
33
33
  self.name = name
34
- self.has_artificial_series = False
35
34
  self.dataset_info = dataset_info
36
35
  self.target_category_columns = dataset_info.target_category_columns
37
36
  self.target_column_name = dataset_info.target_column
@@ -136,7 +135,6 @@ class Transformations(ABC):
136
135
  self._target_category_columns_map = {}
137
136
  if not self.target_category_columns:
138
137
  df[DataColumns.Series] = "Series 1"
139
- self.has_artificial_series = True
140
138
  else:
141
139
  df[DataColumns.Series] = merge_category_columns(
142
140
  df, self.target_category_columns
@@ -209,7 +207,7 @@ class Transformations(ABC):
209
207
 
210
208
  def _missing_value_imputation_add(self, df):
211
209
  """
212
- Function fills missing values in the pandas dataframe using liner interpolation
210
+ Function fills missing values with zero
213
211
 
214
212
  Parameters
215
213
  ----------
@@ -219,15 +217,7 @@ class Transformations(ABC):
219
217
  -------
220
218
  A new Pandas DataFrame without missing values.
221
219
  """
222
- # find columns that all all NA and replace with 0
223
- for col in df.columns:
224
- # find next int not in list
225
- i = 0
226
- vals = df[col].unique()
227
- while i in vals:
228
- i = i + 1
229
- df[col] = df[col].fillna(0)
230
- return df
220
+ return df.fillna(0)
231
221
 
232
222
  def _outlier_treatment(self, df):
233
223
  """
@@ -1,71 +1,66 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
  import ast
8
7
  import base64
9
- from typing import Optional, List, Dict
8
+ from typing import TYPE_CHECKING, Dict, List, Optional
10
9
 
11
10
  import oci
12
11
  import requests
13
- from typing import TYPE_CHECKING
14
12
 
15
13
  try:
16
14
  from kubernetes.client import (
17
- V1ServiceStatus,
18
- V1Service,
19
- V1LoadBalancerStatus,
20
15
  V1LoadBalancerIngress,
16
+ V1LoadBalancerStatus,
17
+ V1Service,
18
+ V1ServiceStatus,
21
19
  )
22
20
  except ImportError:
23
21
  if TYPE_CHECKING:
24
22
  from kubernetes.client import (
25
- V1ServiceStatus,
26
- V1Service,
27
- V1LoadBalancerStatus,
28
23
  V1LoadBalancerIngress,
24
+ V1LoadBalancerStatus,
25
+ V1Service,
26
+ V1ServiceStatus,
29
27
  )
30
28
 
31
- from oci.resource_manager.models import StackSummary, AssociatedResourceSummary
32
-
33
- from ads.opctl.operator.lowcode.feature_store_marketplace.models.apigw_config import (
34
- APIGatewayConfig,
35
- )
29
+ import click
30
+ from oci.resource_manager.models import AssociatedResourceSummary, StackSummary
36
31
 
37
- from ads.common.oci_client import OCIClientFactory
38
- from ads.opctl.operator.lowcode.feature_store_marketplace.const import (
39
- LISTING_ID,
40
- APIGW_STACK_NAME,
41
- STACK_URL,
42
- NLB_RULES_ADDRESS,
43
- NODES_RULES_ADDRESS,
44
- )
45
32
  from ads import logger
46
- import click
33
+ from ads.common import auth as authutil
34
+ from ads.common.oci_client import OCIClientFactory
47
35
  from ads.opctl import logger
48
-
49
36
  from ads.opctl.backend.marketplace.marketplace_utils import (
50
37
  Color,
51
38
  print_heading,
52
39
  print_ticker,
53
40
  )
54
- from ads.opctl.operator.lowcode.feature_store_marketplace.models.mysql_config import (
55
- MySqlConfig,
41
+ from ads.opctl.operator.lowcode.feature_store_marketplace.const import (
42
+ APIGW_STACK_NAME,
43
+ LISTING_ID,
44
+ NLB_RULES_ADDRESS,
45
+ NODES_RULES_ADDRESS,
46
+ STACK_URL,
47
+ )
48
+ from ads.opctl.operator.lowcode.feature_store_marketplace.models.apigw_config import (
49
+ APIGatewayConfig,
56
50
  )
57
-
58
51
  from ads.opctl.operator.lowcode.feature_store_marketplace.models.db_config import (
59
52
  DBConfig,
60
53
  )
61
- from ads.common import auth as authutil
54
+ from ads.opctl.operator.lowcode.feature_store_marketplace.models.mysql_config import (
55
+ MySqlConfig,
56
+ )
62
57
 
63
58
 
64
59
  def get_db_details() -> DBConfig:
65
60
  jdbc_url = "jdbc:mysql://{}/{}?createDatabaseIfNotExist=true"
66
61
  mysql_db_config = MySqlConfig()
67
62
  print_heading(
68
- f"MySQL database configuration",
63
+ "MySQL database configuration",
69
64
  colors=[Color.BOLD, Color.BLUE],
70
65
  prefix_newline_count=2,
71
66
  )
@@ -76,12 +71,12 @@ def get_db_details() -> DBConfig:
76
71
  "Is password provided as plain-text or via a Vault secret?\n"
77
72
  "(https://docs.oracle.com/en-us/iaas/Content/KeyManagement/Concepts/keyoverview.htm)",
78
73
  type=click.Choice(MySqlConfig.MySQLAuthType.values()),
79
- default=MySqlConfig.MySQLAuthType.BASIC.value,
74
+ default=MySqlConfig.MySQLAuthType.BASIC,
80
75
  )
81
76
  )
82
77
  if mysql_db_config.auth_type == MySqlConfig.MySQLAuthType.BASIC:
83
78
  basic_auth_config = MySqlConfig.BasicConfig()
84
- basic_auth_config.password = click.prompt(f"Password", hide_input=True)
79
+ basic_auth_config.password = click.prompt("Password", hide_input=True)
85
80
  mysql_db_config.basic_config = basic_auth_config
86
81
 
87
82
  elif mysql_db_config.auth_type == MySqlConfig.MySQLAuthType.VAULT:
@@ -176,12 +171,12 @@ def detect_or_create_stack(apigw_config: APIGatewayConfig):
176
171
  ).data
177
172
 
178
173
  if len(stacks) >= 1:
179
- print(f"Auto-detected feature store stack(s) in tenancy:")
174
+ print("Auto-detected feature store stack(s) in tenancy:")
180
175
  for stack in stacks:
181
176
  _print_stack_detail(stack)
182
177
  choices = {"1": "new", "2": "existing"}
183
178
  stack_provision_method = click.prompt(
184
- f"Select stack provisioning method:\n1.Create new stack\n2.Existing stack\n",
179
+ "Select stack provisioning method:\n1.Create new stack\n2.Existing stack\n",
185
180
  type=click.Choice(list(choices.keys())),
186
181
  show_choices=False,
187
182
  )
@@ -240,20 +235,20 @@ def get_api_gw_details(compartment_id: str) -> APIGatewayConfig:
240
235
 
241
236
 
242
237
  def get_nlb_id_from_service(service: "V1Service", apigw_config: APIGatewayConfig):
243
- status: "V1ServiceStatus" = service.status
244
- lb_status: "V1LoadBalancerStatus" = status.load_balancer
245
- lb_ingress: "V1LoadBalancerIngress" = lb_status.ingress[0]
238
+ status: V1ServiceStatus = service.status
239
+ lb_status: V1LoadBalancerStatus = status.load_balancer
240
+ lb_ingress: V1LoadBalancerIngress = lb_status.ingress[0]
246
241
  resource_client = OCIClientFactory(**authutil.default_signer()).create_client(
247
242
  oci.resource_search.ResourceSearchClient
248
243
  )
249
244
  search_details = oci.resource_search.models.FreeTextSearchDetails()
250
245
  search_details.matching_context_type = "NONE"
251
246
  search_details.text = lb_ingress.ip
252
- resources: List[
253
- oci.resource_search.models.ResourceSummary
254
- ] = resource_client.search_resources(
255
- search_details, tenant_id=apigw_config.root_compartment_id
256
- ).data.items
247
+ resources: List[oci.resource_search.models.ResourceSummary] = (
248
+ resource_client.search_resources(
249
+ search_details, tenant_id=apigw_config.root_compartment_id
250
+ ).data.items
251
+ )
257
252
  private_ips = list(filter(lambda obj: obj.resource_type == "PrivateIp", resources))
258
253
  if len(private_ips) != 1:
259
254
  return click.prompt(
@@ -264,12 +259,12 @@ def get_nlb_id_from_service(service: "V1Service", apigw_config: APIGatewayConfig
264
259
  nlb_client = OCIClientFactory(**authutil.default_signer()).create_client(
265
260
  oci.network_load_balancer.NetworkLoadBalancerClient
266
261
  )
267
- nlbs: List[
268
- oci.network_load_balancer.models.NetworkLoadBalancerSummary
269
- ] = nlb_client.list_network_load_balancers(
270
- compartment_id=nlb_private_ip.compartment_id,
271
- display_name=nlb_private_ip.display_name,
272
- ).data.items
262
+ nlbs: List[oci.network_load_balancer.models.NetworkLoadBalancerSummary] = (
263
+ nlb_client.list_network_load_balancers(
264
+ compartment_id=nlb_private_ip.compartment_id,
265
+ display_name=nlb_private_ip.display_name,
266
+ ).data.items
267
+ )
273
268
  if len(nlbs) != 1:
274
269
  return click.prompt(
275
270
  f"Please enter OCID of load balancer associated with ip: {lb_ingress.ip}"
@@ -1,7 +1,6 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
  import json
@@ -15,17 +14,17 @@ from ads.opctl import logger
15
14
  from ads.opctl.operator.common.const import ENV_OPERATOR_ARGS
16
15
  from ads.opctl.operator.common.utils import _parse_input_args
17
16
 
17
+ from .model.forecast_datasets import ForecastDatasets, ForecastResults
18
18
  from .operator_config import ForecastOperatorConfig
19
- from .model.forecast_datasets import ForecastDatasets
20
19
  from .whatifserve import ModelDeploymentManager
21
20
 
22
21
 
23
- def operate(operator_config: ForecastOperatorConfig) -> None:
22
+ def operate(operator_config: ForecastOperatorConfig) -> ForecastResults:
24
23
  """Runs the forecasting operator."""
25
24
  from .model.factory import ForecastOperatorModelFactory
26
25
 
27
26
  datasets = ForecastDatasets(operator_config)
28
- ForecastOperatorModelFactory.get_model(
27
+ results = ForecastOperatorModelFactory.get_model(
29
28
  operator_config, datasets
30
29
  ).generate_report()
31
30
  # saving to model catalog
@@ -36,6 +35,7 @@ def operate(operator_config: ForecastOperatorConfig) -> None:
36
35
  if spec.what_if_analysis.model_deployment:
37
36
  mdm.create_deployment()
38
37
  mdm.save_deployment_info()
38
+ return results
39
39
 
40
40
 
41
41
  def verify(spec: Dict, **kwargs: Dict) -> bool:
@@ -1,13 +1,13 @@
1
1
  #!/usr/bin/env python
2
2
 
3
- # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
4
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
 
6
- from ads.common.extended_enum import ExtendedEnumMeta
6
+ from ads.common.extended_enum import ExtendedEnum
7
7
  from ads.opctl.operator.lowcode.common.const import DataColumns
8
8
 
9
9
 
10
- class SupportedModels(str, metaclass=ExtendedEnumMeta):
10
+ class SupportedModels(ExtendedEnum):
11
11
  """Supported forecast models."""
12
12
 
13
13
  Prophet = "prophet"
@@ -19,7 +19,7 @@ class SupportedModels(str, metaclass=ExtendedEnumMeta):
19
19
  # Auto = "auto"
20
20
 
21
21
 
22
- class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
22
+ class SpeedAccuracyMode(ExtendedEnum):
23
23
  """
24
24
  Enum representing different modes based on time taken and accuracy for explainability.
25
25
  """
@@ -35,7 +35,7 @@ class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
35
35
  ratio[AUTOMLX] = 0 # constant
36
36
 
37
37
 
38
- class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
38
+ class SupportedMetrics(ExtendedEnum):
39
39
  """Supported forecast metrics."""
40
40
 
41
41
  MAPE = "MAPE"
@@ -62,7 +62,7 @@ class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
62
62
  ELAPSED_TIME = "Elapsed Time"
63
63
 
64
64
 
65
- class ForecastOutputColumns(str, metaclass=ExtendedEnumMeta):
65
+ class ForecastOutputColumns(ExtendedEnum):
66
66
  """The column names for the forecast.csv output file"""
67
67
 
68
68
  DATE = "Date"
@@ -116,7 +116,10 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
116
116
  lower_bound=self.get_horizon(forecast["yhat_lower"]).values,
117
117
  )
118
118
 
119
- self.models[s_id] = model
119
+ self.models[s_id] = {}
120
+ self.models[s_id]["model"] = model
121
+ self.models[s_id]["le"] = self.le[s_id]
122
+ self.models[s_id]["predict_component_cols"] = X_pred.columns
120
123
 
121
124
  params = vars(model).copy()
122
125
  for param in ["arima_res_", "endog_index_"]:
@@ -163,7 +166,7 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
163
166
  sec5_text = rc.Heading("ARIMA Model Parameters", level=2)
164
167
  blocks = [
165
168
  rc.Html(
166
- m.summary().as_html(),
169
+ m['model'].summary().as_html(),
167
170
  label=s_id if self.target_cat_col else None,
168
171
  )
169
172
  for i, (s_id, m) in enumerate(self.models.items())
@@ -251,7 +254,7 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
251
254
  def get_explain_predict_fn(self, series_id):
252
255
  def _custom_predict(
253
256
  data,
254
- model=self.models[series_id],
257
+ model=self.models[series_id]["model"],
255
258
  dt_column_name=self.datasets._datetime_column_name,
256
259
  target_col=self.original_target_column,
257
260
  ):
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
  import logging
5
5
  import os
@@ -56,8 +56,8 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
56
56
  )
57
57
  return model_kwargs_cleaned, time_budget
58
58
 
59
- def preprocess(self, data): # TODO: re-use self.le for explanations
60
- _, df_encoded = _label_encode_dataframe(
59
+ def preprocess(self, data, series_id): # TODO: re-use self.le for explanations
60
+ self.le[series_id], df_encoded = _label_encode_dataframe(
61
61
  data,
62
62
  no_encode={self.spec.datetime_column.name, self.original_target_column},
63
63
  )
@@ -66,8 +66,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
66
66
  @runtime_dependency(
67
67
  module="automlx",
68
68
  err_msg=(
69
- "Please run `pip3 install oracle-automlx>=23.4.1` and "
70
- "`pip3 install oracle-automlx[forecasting]>=23.4.1` "
69
+ "Please run `pip3 install oracle-automlx[forecasting]>=25.1.1` "
71
70
  "to install the required dependencies for automlx."
72
71
  ),
73
72
  )
@@ -105,7 +104,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
105
104
  engine_opts = (
106
105
  None
107
106
  if engine_type == "local"
108
- else ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
107
+ else {"ray_setup": {"_temp_dir": "/tmp/ray-temp"}}
109
108
  )
110
109
  init(
111
110
  engine=engine_type,
@@ -125,7 +124,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
125
124
  self.forecast_output.init_series_output(
126
125
  series_id=s_id, data_at_series=df
127
126
  )
128
- data = self.preprocess(df)
127
+ data = self.preprocess(df, s_id)
129
128
  data_i = self.drop_horizon(data)
130
129
  X_pred = self.get_horizon(data).drop(target, axis=1)
131
130
 
@@ -157,7 +156,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
157
156
  target
158
157
  ].values
159
158
 
160
- self.models[s_id] = model
159
+ self.models[s_id] = {}
160
+ self.models[s_id]["model"] = model
161
+ self.models[s_id]["le"] = self.le[s_id]
161
162
 
162
163
  # In case of Naive model, model.forecast function call does not return confidence intervals.
163
164
  if f"{target}_ci_upper" not in summary_frame:
@@ -218,7 +219,8 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
218
219
  other_sections = []
219
220
 
220
221
  if len(self.models) > 0:
221
- for s_id, m in models.items():
222
+ for s_id, artifacts in models.items():
223
+ m = artifacts["model"]
222
224
  selected_models[s_id] = {
223
225
  "series_id": s_id,
224
226
  "selected_model": m.selected_model_,
@@ -247,17 +249,17 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
247
249
  self.explain_model()
248
250
 
249
251
  global_explanation_section = None
250
- if self.spec.explanations_accuracy_mode != SpeedAccuracyMode.AUTOMLX:
251
- # Convert the global explanation data to a DataFrame
252
- global_explanation_df = pd.DataFrame(self.global_explanation)
252
+ # Convert the global explanation data to a DataFrame
253
+ global_explanation_df = pd.DataFrame(self.global_explanation)
253
254
 
254
- self.formatted_global_explanation = (
255
- global_explanation_df / global_explanation_df.sum(axis=0) * 100
256
- )
257
- self.formatted_global_explanation = self.formatted_global_explanation.rename(
258
- {self.spec.datetime_column.name: ForecastOutputColumns.DATE},
259
- axis=1,
260
- )
255
+ self.formatted_global_explanation = (
256
+ global_explanation_df / global_explanation_df.sum(axis=0) * 100
257
+ )
258
+
259
+ self.formatted_global_explanation.rename(
260
+ columns={self.spec.datetime_column.name: ForecastOutputColumns.DATE},
261
+ inplace=True,
262
+ )
261
263
 
262
264
  aggregate_local_explanations = pd.DataFrame()
263
265
  for s_id, local_ex_df in self.local_explanation.items():
@@ -269,11 +271,15 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
269
271
  self.formatted_local_explanation = aggregate_local_explanations
270
272
 
271
273
  if not self.target_cat_col:
272
- self.formatted_global_explanation = self.formatted_global_explanation.rename(
273
- {"Series 1": self.original_target_column},
274
- axis=1,
274
+ self.formatted_global_explanation = (
275
+ self.formatted_global_explanation.rename(
276
+ {"Series 1": self.original_target_column},
277
+ axis=1,
278
+ )
279
+ )
280
+ self.formatted_local_explanation.drop(
281
+ "Series", axis=1, inplace=True
275
282
  )
276
- self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
277
283
 
278
284
  # Create a markdown section for the global explainability
279
285
  global_explanation_section = rc.Block(
@@ -320,7 +326,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
320
326
  )
321
327
 
322
328
  def get_explain_predict_fn(self, series_id):
323
- selected_model = self.models[series_id]
329
+ selected_model = self.models[series_id]["model"]
324
330
 
325
331
  # If training date, use method below. If future date, use forecast!
326
332
  def _custom_predict_fn(
@@ -338,12 +344,12 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
338
344
  data[dt_column_name] = seconds_to_datetime(
339
345
  data[dt_column_name], dt_format=self.spec.datetime_column.format
340
346
  )
341
- data = self.preprocess(data)
347
+ data = self.preprocess(data, series_id)
342
348
  horizon_data = horizon_data.drop(target_col, axis=1)
343
349
  horizon_data[dt_column_name] = seconds_to_datetime(
344
350
  horizon_data[dt_column_name], dt_format=self.spec.datetime_column.format
345
351
  )
346
- horizon_data = self.preprocess(horizon_data)
352
+ horizon_data = self.preprocess(horizon_data, series_id)
347
353
 
348
354
  rows = []
349
355
  for i in range(data.shape[0]):
@@ -421,8 +427,10 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
421
427
  if self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
422
428
  # Use the MLExplainer class from AutoMLx to generate explanations
423
429
  explainer = automlx.MLExplainer(
424
- self.models[s_id],
425
- self.datasets.additional_data.get_data_for_series(series_id=s_id)
430
+ self.models[s_id]["model"],
431
+ self.datasets.additional_data.get_data_for_series(
432
+ series_id=s_id
433
+ )
426
434
  .drop(self.spec.datetime_column.name, axis=1)
427
435
  .head(-self.spec.horizon)
428
436
  if self.spec.additional_data
@@ -433,7 +441,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
433
441
 
434
442
  # Generate explanations for the forecast
435
443
  explanations = explainer.explain_prediction(
436
- X=self.datasets.additional_data.get_data_for_series(series_id=s_id)
444
+ X=self.datasets.additional_data.get_data_for_series(
445
+ series_id=s_id
446
+ )
437
447
  .drop(self.spec.datetime_column.name, axis=1)
438
448
  .tail(self.spec.horizon)
439
449
  if self.spec.additional_data
@@ -445,7 +455,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
445
455
  explanations_df = pd.concat(
446
456
  [exp.to_dataframe() for exp in explanations]
447
457
  )
448
- explanations_df["row"] = explanations_df.groupby("Feature").cumcount()
458
+ explanations_df["row"] = explanations_df.groupby(
459
+ "Feature"
460
+ ).cumcount()
449
461
  explanations_df = explanations_df.pivot(
450
462
  index="row", columns="Feature", values="Attribution"
451
463
  )
@@ -453,9 +465,27 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
453
465
 
454
466
  # Store the explanations in the local_explanation dictionary
455
467
  self.local_explanation[s_id] = explanations_df
468
+
469
+ self.global_explanation[s_id] = dict(
470
+ zip(
471
+ self.local_explanation[s_id].columns,
472
+ np.nanmean(np.abs(self.local_explanation[s_id]), axis=0),
473
+ )
474
+ )
456
475
  else:
457
476
  # Fall back to the default explanation generation method
458
477
  super().explain_model()
459
478
  except Exception as e:
460
- logger.warning(f"Failed to generate explanations for series {s_id} with error: {e}.")
479
+ if s_id in self.errors_dict:
480
+ self.errors_dict[s_id]["explainer_error"] = str(e)
481
+ self.errors_dict[s_id]["explainer_error_trace"] = traceback.format_exc()
482
+ else:
483
+ self.errors_dict[s_id] = {
484
+ "model_name": self.spec.model,
485
+ "explainer_error": str(e),
486
+ "explainer_error_trace": traceback.format_exc(),
487
+ }
488
+ logger.warning(
489
+ f"Failed to generate explanations for series {s_id} with error: {e}."
490
+ )
461
491
  logger.debug(f"Full Traceback: {traceback.format_exc()}")