oracle-ads 2.12.8__py3-none-any.whl → 2.12.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. ads/aqua/__init__.py +4 -3
  2. ads/aqua/app.py +40 -18
  3. ads/aqua/client/__init__.py +3 -0
  4. ads/aqua/client/client.py +799 -0
  5. ads/aqua/common/enums.py +3 -0
  6. ads/aqua/common/utils.py +62 -2
  7. ads/aqua/data.py +2 -19
  8. ads/aqua/evaluation/entities.py +6 -0
  9. ads/aqua/evaluation/evaluation.py +45 -15
  10. ads/aqua/extension/aqua_ws_msg_handler.py +14 -7
  11. ads/aqua/extension/base_handler.py +12 -9
  12. ads/aqua/extension/deployment_handler.py +8 -4
  13. ads/aqua/extension/finetune_handler.py +8 -14
  14. ads/aqua/extension/model_handler.py +30 -6
  15. ads/aqua/extension/ui_handler.py +13 -1
  16. ads/aqua/finetuning/constants.py +5 -2
  17. ads/aqua/finetuning/entities.py +73 -17
  18. ads/aqua/finetuning/finetuning.py +110 -82
  19. ads/aqua/model/entities.py +5 -1
  20. ads/aqua/model/model.py +230 -104
  21. ads/aqua/modeldeployment/deployment.py +35 -11
  22. ads/aqua/modeldeployment/entities.py +7 -4
  23. ads/aqua/ui.py +24 -2
  24. ads/cli.py +16 -8
  25. ads/common/auth.py +9 -9
  26. ads/llm/autogen/__init__.py +2 -0
  27. ads/llm/autogen/constants.py +15 -0
  28. ads/llm/autogen/reports/__init__.py +2 -0
  29. ads/llm/autogen/reports/base.py +67 -0
  30. ads/llm/autogen/reports/data.py +103 -0
  31. ads/llm/autogen/reports/session.py +526 -0
  32. ads/llm/autogen/reports/templates/chat_box.html +13 -0
  33. ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
  34. ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
  35. ads/llm/autogen/reports/utils.py +56 -0
  36. ads/llm/autogen/v02/__init__.py +4 -0
  37. ads/llm/autogen/{client_v02.py → v02/client.py} +23 -10
  38. ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
  39. ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
  40. ads/llm/autogen/v02/loggers/__init__.py +6 -0
  41. ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
  42. ads/llm/autogen/v02/loggers/session_logger.py +580 -0
  43. ads/llm/autogen/v02/loggers/utils.py +86 -0
  44. ads/llm/autogen/v02/runtime_logging.py +163 -0
  45. ads/llm/guardrails/base.py +6 -5
  46. ads/llm/langchain/plugins/chat_models/oci_data_science.py +46 -20
  47. ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +38 -11
  48. ads/model/__init__.py +11 -13
  49. ads/model/artifact.py +47 -8
  50. ads/model/extractor/embedding_onnx_extractor.py +80 -0
  51. ads/model/framework/embedding_onnx_model.py +438 -0
  52. ads/model/generic_model.py +26 -24
  53. ads/model/model_metadata.py +8 -7
  54. ads/opctl/config/merger.py +13 -14
  55. ads/opctl/operator/common/operator_config.py +4 -4
  56. ads/opctl/operator/lowcode/common/transformations.py +50 -8
  57. ads/opctl/operator/lowcode/common/utils.py +22 -6
  58. ads/opctl/operator/lowcode/forecast/__main__.py +10 -0
  59. ads/opctl/operator/lowcode/forecast/const.py +3 -0
  60. ads/opctl/operator/lowcode/forecast/model/arima.py +19 -13
  61. ads/opctl/operator/lowcode/forecast/model/automlx.py +129 -36
  62. ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
  63. ads/opctl/operator/lowcode/forecast/model/base_model.py +58 -17
  64. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +1 -1
  65. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +10 -3
  66. ads/opctl/operator/lowcode/forecast/model/prophet.py +25 -18
  67. ads/opctl/operator/lowcode/forecast/model_evaluator.py +3 -2
  68. ads/opctl/operator/lowcode/forecast/operator_config.py +31 -0
  69. ads/opctl/operator/lowcode/forecast/schema.yaml +76 -0
  70. ads/opctl/operator/lowcode/forecast/utils.py +8 -6
  71. ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
  72. ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +233 -0
  73. ads/opctl/operator/lowcode/forecast/whatifserve/score.py +238 -0
  74. ads/telemetry/base.py +18 -11
  75. ads/telemetry/client.py +33 -13
  76. ads/templates/schemas/openapi.json +1740 -0
  77. ads/templates/score_embedding_onnx.jinja2 +202 -0
  78. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10.dist-info}/METADATA +11 -10
  79. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10.dist-info}/RECORD +82 -56
  80. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10.dist-info}/LICENSE.txt +0 -0
  81. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10.dist-info}/WHEEL +0 -0
  82. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10.dist-info}/entry_points.txt +0 -0
ads/model/model_metadata.py
@@ -1,5 +1,4 @@
  #!/usr/bin/env python
- # -*- coding: utf-8 -*--
 
  # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -11,20 +10,21 @@ import sys
  from abc import ABC, abstractmethod
  from dataclasses import dataclass, field, fields
  from pathlib import Path
- from typing import Dict, List, Tuple, Union, Optional, Any
+ from typing import Any, Dict, List, Optional, Tuple, Union
 
- import ads.dataset.factory as factory
  import fsspec
  import git
  import oci.data_science.models
  import pandas as pd
  import yaml
+ from oci.util import to_dict
+
  from ads.common import logger
  from ads.common.error import ChangesNotCommitted
  from ads.common.extended_enum import ExtendedEnumMeta
- from ads.common.serializer import DataClassSerializable
  from ads.common.object_storage_details import ObjectStorageDetails
- from oci.util import to_dict
+ from ads.common.serializer import DataClassSerializable
+ from ads.dataset import factory
 
  try:
      from yaml import CDumper as dumper
@@ -173,6 +173,7 @@ class Framework(str, metaclass=ExtendedEnumMeta):
      WORD2VEC = "word2vec"
      ENSEMBLE = "ensemble"
      SPARK = "pyspark"
+     EMBEDDING_ONNX = "embedding_onnx"
      OTHER = "other"
 
 
@@ -1398,7 +1399,7 @@ class ModelCustomMetadata(ModelMetadata):
          if (
              not data
              or not isinstance(data, Dict)
-             or not "data" in data
+             or "data" not in data
              or not isinstance(data["data"], List)
          ):
              raise ValueError(
@@ -1550,7 +1551,7 @@ class ModelTaxonomyMetadata(ModelMetadata):
          if (
              not data
              or not isinstance(data, Dict)
-             or not "data" in data
+             or "data" not in data
              or not isinstance(data["data"], List)
          ):
              raise ValueError(

ads/opctl/config/merger.py
@@ -1,35 +1,33 @@
  #!/usr/bin/env python
- # -*- coding: utf-8; -*-
 
- # Copyright (c) 2022, 2023 Oracle and/or its affiliates.
+ # Copyright (c) 2022, 2024 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
  import os
  from string import Template
  from typing import Dict
- import json
 
  import yaml
 
  from ads.common.auth import AuthType, ResourcePrincipal
  from ads.opctl import logger
  from ads.opctl.config.base import ConfigProcessor
- from ads.opctl.config.utils import read_from_ini, _DefaultNoneDict
- from ads.opctl.utils import is_in_notebook_session, get_service_pack_prefix
+ from ads.opctl.config.utils import _DefaultNoneDict, read_from_ini
  from ads.opctl.constants import (
-     DEFAULT_PROFILE,
-     DEFAULT_OCI_CONFIG_FILE,
-     DEFAULT_CONDA_PACK_FOLDER,
-     DEFAULT_ADS_CONFIG_FOLDER,
-     ADS_JOBS_CONFIG_FILE_NAME,
      ADS_CONFIG_FILE_NAME,
-     ADS_ML_PIPELINE_CONFIG_FILE_NAME,
      ADS_DATAFLOW_CONFIG_FILE_NAME,
+     ADS_JOBS_CONFIG_FILE_NAME,
      ADS_LOCAL_BACKEND_CONFIG_FILE_NAME,
+     ADS_ML_PIPELINE_CONFIG_FILE_NAME,
      ADS_MODEL_DEPLOYMENT_CONFIG_FILE_NAME,
-     DEFAULT_NOTEBOOK_SESSION_CONDA_DIR,
      BACKEND_NAME,
+     DEFAULT_ADS_CONFIG_FOLDER,
+     DEFAULT_CONDA_PACK_FOLDER,
+     DEFAULT_NOTEBOOK_SESSION_CONDA_DIR,
+     DEFAULT_OCI_CONFIG_FILE,
+     DEFAULT_PROFILE,
  )
+ from ads.opctl.utils import get_service_pack_prefix, is_in_notebook_session
 
 
  class ConfigMerger(ConfigProcessor):
@@ -41,8 +39,9 @@ class ConfigMerger(ConfigProcessor):
      """
 
      def process(self, **kwargs) -> None:
-         config_string = Template(json.dumps(self.config)).safe_substitute(os.environ)
-         self.config = json.loads(config_string)
+         for key, value in self.config.items():
+             if isinstance(value, str):  # Substitute only if the value is a string
+                 self.config[key] = Template(value).safe_substitute(os.environ)
 
          if "runtime" not in self.config:
              self.config["runtime"] = {}
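
The rewritten process() no longer round-trips the whole config through json.dumps/json.loads; it substitutes environment variables only into top-level string values, so nested sections are left untouched and non-string values cannot be corrupted by templating. A minimal standalone sketch of the new behavior (not the ads API; TENANCY and UNDEFINED_VAR are made-up names):

import os
from string import Template

# Per-key substitution: only top-level string values are templated, and
# safe_substitute leaves unknown variables intact instead of raising KeyError.
os.environ["TENANCY"] = "ocid1.tenancy.oc1..example"

config = {
    "compartment_id": "$TENANCY",              # top-level string: substituted
    "name": "$UNDEFINED_VAR",                  # undefined variable: left as-is
    "execution": {"oci_profile": "$PROFILE"},  # nested dict: no longer substituted
}

for key, value in config.items():
    if isinstance(value, str):
        config[key] = Template(value).safe_substitute(os.environ)

print(config["compartment_id"])  # ocid1.tenancy.oc1..example
print(config["name"])            # $UNDEFINED_VAR
print(config["execution"])       # {'oci_profile': '$PROFILE'} (unchanged)
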
ads/opctl/operator/common/operator_config.py
@@ -1,7 +1,6 @@
  #!/usr/bin/env python
- # -*- coding: utf-8; -*-
 
- # Copyright (c) 2023 Oracle and/or its affiliates.
+ # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 
@@ -11,15 +10,16 @@ from dataclasses import dataclass
  from typing import Any, Dict, List
 
  from ads.common.serializer import DataClassSerializable
-
- from ads.opctl.operator.common.utils import OperatorValidator
  from ads.opctl.operator.common.errors import InvalidParameterError
+ from ads.opctl.operator.common.utils import OperatorValidator
+
 
  @dataclass(repr=True)
  class InputData(DataClassSerializable):
      """Class representing operator specification input data details."""
 
      connect_args: Dict = None
+     data: Dict = None
      format: str = None
      columns: List[str] = None
      url: str = None

ads/opctl/operator/lowcode/common/transformations.py
@@ -1,10 +1,11 @@
  #!/usr/bin/env python
 
- # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
  from abc import ABC
 
+ import numpy as np
  import pandas as pd
 
  from ads.opctl import logger
@@ -14,6 +15,7 @@ from ads.opctl.operator.lowcode.common.errors import (
      InvalidParameterError,
  )
  from ads.opctl.operator.lowcode.common.utils import merge_category_columns
+ from ads.opctl.operator.lowcode.forecast.operator_config import ForecastOperatorSpec
 
 
  class Transformations(ABC):
@@ -33,6 +35,7 @@ class Transformations(ABC):
          self.dataset_info = dataset_info
          self.target_category_columns = dataset_info.target_category_columns
          self.target_column_name = dataset_info.target_column
+         self.raw_column_names = None
          self.dt_column_name = (
              dataset_info.datetime_column.name if dataset_info.datetime_column else None
          )
@@ -59,7 +62,8 @@
 
          """
          clean_df = self._remove_trailing_whitespace(data)
-         # clean_df = self._normalize_column_names(clean_df)
+         if isinstance(self.dataset_info, ForecastOperatorSpec):
+             clean_df = self._clean_column_names(clean_df)
          if self.name == "historical_data":
              self._check_historical_dataset(clean_df)
              clean_df = self._set_series_id_column(clean_df)
@@ -97,8 +101,36 @@
      def _remove_trailing_whitespace(self, df):
          return df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)
 
-     # def _normalize_column_names(self, df):
-     #     return df.rename(columns=lambda x: re.sub("[^A-Za-z0-9_]+", "", x))
+     def _clean_column_names(self, df):
+         """
+         Remove all whitespaces from column names in a DataFrame and store the original names.
+
+         Parameters:
+             df (pd.DataFrame): The DataFrame whose column names need to be cleaned.
+
+         Returns:
+             pd.DataFrame: The DataFrame with cleaned column names.
+         """
+
+         self.raw_column_names = {
+             col: col.replace(" ", "") for col in df.columns if " " in col
+         }
+         df.columns = [self.raw_column_names.get(col, col) for col in df.columns]
+
+         if self.target_column_name:
+             self.target_column_name = self.raw_column_names.get(
+                 self.target_column_name, self.target_column_name
+             )
+         self.dt_column_name = self.raw_column_names.get(
+             self.dt_column_name, self.dt_column_name
+         )
+
+         if self.target_category_columns:
+             self.target_category_columns = [
+                 self.raw_column_names.get(col, col)
+                 for col in self.target_category_columns
+             ]
+         return df
 
      def _set_series_id_column(self, df):
          self._target_category_columns_map = {}
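
The new _clean_column_names strips spaces from column names, records the original-to-cleaned mapping in raw_column_names, and rewrites the target, datetime, and category column references through the same mapping. The rename logic in isolation (a sketch; the column names are invented):

import pandas as pd

# Sketch of the whitespace-stripping rename that _clean_column_names applies.
df = pd.DataFrame(columns=["Order Date", "Product ID", "Sales"])

raw_column_names = {col: col.replace(" ", "") for col in df.columns if " " in col}
df.columns = [raw_column_names.get(col, col) for col in df.columns]

print(raw_column_names)  # {'Order Date': 'OrderDate', 'Product ID': 'ProductID'}
print(list(df.columns))  # ['OrderDate', 'ProductID', 'Sales']
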
@@ -209,23 +241,33 @@
          -------
          A new Pandas DataFrame with treated outliears.
          """
-         df["z_score"] = (
+         return df
+         df["__z_score"] = (
              df[self.target_column_name]
              .groupby(DataColumns.Series)
              .transform(lambda x: (x - x.mean()) / x.std())
          )
-         outliers_mask = df["z_score"].abs() > 3
+         outliers_mask = df["__z_score"].abs() > 3
+
+         if df[self.target_column_name].dtype == np.int:
+             df[self.target_column_name].astype(np.float)
+
          df.loc[outliers_mask, self.target_column_name] = (
              df[self.target_column_name]
              .groupby(DataColumns.Series)
-             .transform(lambda x: x.mean())
+             .transform(lambda x: np.median(x))
          )
-         return df.drop("z_score", axis=1)
+         df_ret = df.drop("__z_score", axis=1)
+         return df_ret
 
      def _check_historical_dataset(self, df):
          expected_names = [self.target_column_name, self.dt_column_name] + (
              self.target_category_columns if self.target_category_columns else []
          )
+
+         if self.raw_column_names:
+             expected_names.extend(list(self.raw_column_names.values()))
+
          if set(df.columns) != set(expected_names):
              raise DataMismatchError(
                  f"Expected {self.name} to have columns: {expected_names}, but instead found column names: {df.columns}. Is the {self.name} path correct?"

ads/opctl/operator/lowcode/common/utils.py
@@ -1,6 +1,6 @@
  #!/usr/bin/env python
 
- # Copyright (c) 2024 Oracle and/or its affiliates.
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
  import logging
@@ -12,6 +12,7 @@ from typing import List, Union
 
  import fsspec
  import oracledb
+ import json
  import pandas as pd
 
  from ads.common.object_storage_details import ObjectStorageDetails
@@ -40,6 +41,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
      if data_spec is None:
          raise InvalidParameterError("No details provided for this data source.")
      filename = data_spec.url
+     data = data_spec.data
      format = data_spec.format
      columns = data_spec.columns
      connect_args = data_spec.connect_args
@@ -51,9 +53,12 @@ def load_data(data_spec, storage_options=None, **kwargs):
          default_signer() if ObjectStorageDetails.is_oci_path(filename) else {}
      )
      if vault_secret_id is not None and connect_args is None:
-         connect_args = dict()
+         connect_args = {}
 
-     if filename is not None:
+     if data is not None:
+         if format == "spark":
+             data = data.toPandas()
+     elif filename is not None:
          if not format:
              _, format = os.path.splitext(filename)
              format = format[1:]
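
load_data now accepts an in-memory dataset via the new data_spec.data field (added to InputData above) and gives it precedence over url; when format is "spark", the object is converted with toPandas(). A sketch of the precedence with a stand-in namespace (the field values are invented; the real spec class is InputData):

from types import SimpleNamespace

import pandas as pd

# Hypothetical stand-in for the operator's InputData spec.
data_spec = SimpleNamespace(
    data=pd.DataFrame({"ds": ["2024-01-01"], "y": [1.0]}),
    format=None,  # "spark" would trigger a .toPandas() conversion
    url=None,
)

# Mirrors the new precedence in load_data: in-memory data wins over url.
if data_spec.data is not None:
    data = data_spec.data
    if data_spec.format == "spark":
        data = data.toPandas()  # only valid for a pyspark DataFrame
elif data_spec.url is not None:
    data = pd.read_csv(data_spec.url)  # simplified; the operator dispatches on format
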
@@ -98,7 +103,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
          except Exception as e:
              raise Exception(
                  f"Could not retrieve database credentials from vault {vault_secret_id}: {e}"
-             )
+             ) from e
 
          con = oracledb.connect(**connect_args)
          if table_name is not None:
@@ -121,7 +126,8 @@ def load_data(data_spec, storage_options=None, **kwargs):
      return data
 
 
- def write_data(data, filename, format, storage_options, index=False, **kwargs):
+ def write_data(data, filename, format, storage_options=None, index=False, **kwargs):
+     disable_print()
      if not format:
          _, format = os.path.splitext(filename)
          format = format[1:]
@@ -130,11 +136,21 @@ def write_data(data, filename, format, storage_options, index=False, **kwargs):
          return call_pandas_fsspec(
              write_fn, filename, index=index, storage_options=storage_options, **kwargs
          )
-     raise OperatorYamlContentError(
+     enable_print()
+     raise InvalidParameterError(
          f"The format {format} is not currently supported for writing data. Please change the format parameter for the data output: {filename} ."
      )
 
 
+ def write_simple_json(data, path):
+     if ObjectStorageDetails.is_oci_path(path):
+         storage_options = default_signer()
+     else:
+         storage_options = {}
+     with fsspec.open(path, mode="w", **storage_options) as f:
+         json.dump(data, f, indent=4)
+
+
  def merge_category_columns(data, target_category_columns):
      result = data.apply(
          lambda x: "__".join([str(x[col]) for col in target_category_columns]), axis=1

ads/opctl/operator/lowcode/forecast/__main__.py
@@ -17,6 +17,7 @@ from ads.opctl.operator.common.utils import _parse_input_args
 
  from .operator_config import ForecastOperatorConfig
  from .model.forecast_datasets import ForecastDatasets
+ from .whatifserve import ModelDeploymentManager
 
 
  def operate(operator_config: ForecastOperatorConfig) -> None:
  def operate(operator_config: ForecastOperatorConfig) -> None:
@@ -27,6 +28,15 @@ def operate(operator_config: ForecastOperatorConfig) -> None:
27
28
  ForecastOperatorModelFactory.get_model(
28
29
  operator_config, datasets
29
30
  ).generate_report()
31
+ # saving to model catalog
32
+ spec = operator_config.spec
33
+ if spec.what_if_analysis and datasets.additional_data:
34
+ mdm = ModelDeploymentManager(spec, datasets.additional_data)
35
+ mdm.save_to_catalog()
36
+ if spec.what_if_analysis.model_deployment:
37
+ mdm.create_deployment()
38
+ mdm.save_deployment_info()
39
+
30
40
 
31
41
  def verify(spec: Dict, **kwargs: Dict) -> bool:
32
42
  """Verifies the forecasting operator config."""

ads/opctl/operator/lowcode/forecast/const.py
@@ -27,10 +27,12 @@ class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
      HIGH_ACCURACY = "HIGH_ACCURACY"
      BALANCED = "BALANCED"
      FAST_APPROXIMATE = "FAST_APPROXIMATE"
+     AUTOMLX = "AUTOMLX"
      ratio = {}
      ratio[HIGH_ACCURACY] = 1  # 100 % data used for generating explanations
      ratio[BALANCED] = 0.5  # 50 % data used for generating explanations
      ratio[FAST_APPROXIMATE] = 0  # constant
+     ratio[AUTOMLX] = 0  # constant
 
 
  class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
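
ratio maps each mode to the fraction of rows sampled when computing explanations; the new AUTOMLX mode, like FAST_APPROXIMATE, maps to 0 because it hands explanation work to AutoMLx itself (see explain_model in automlx.py below). A hypothetical helper showing how a ratio of 0 degenerates to a constant-size sample (explanation_rows and its floor parameter are invented for illustration, not the operator's actual rule):

# Illustrative only: how a ratio-style mode could translate to a row budget.
ratio = {"HIGH_ACCURACY": 1.0, "BALANCED": 0.5, "FAST_APPROXIMATE": 0.0, "AUTOMLX": 0.0}

def explanation_rows(n_rows: int, mode: str, floor: int = 100) -> int:
    """Hypothetical helper: a ratio of 0 falls back to a small constant sample."""
    return max(floor, int(ratio[mode] * n_rows)) if ratio[mode] else floor

print(explanation_rows(10_000, "BALANCED"))  # 5000
print(explanation_rows(10_000, "AUTOMLX"))   # 100
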
@@ -87,3 +89,4 @@ SUMMARY_METRICS_HORIZON_LIMIT = 10
  PROPHET_INTERNAL_DATE_COL = "ds"
  RENDER_LIMIT = 5000
  AUTO_SELECT = "auto-select"
+ BACKTEST_REPORT_NAME = "back_test.csv"

ads/opctl/operator/lowcode/forecast/model/arima.py
@@ -164,11 +164,11 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
          blocks = [
              rc.Html(
                  m.summary().as_html(),
-                 label=s_id,
+                 label=s_id if self.target_cat_col else None,
              )
              for i, (s_id, m) in enumerate(self.models.items())
          ]
-         sec5 = rc.Select(blocks=blocks)
+         sec5 = rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
          all_sections = [sec5_text, sec5]
 
          if self.spec.generate_explanations:
@@ -188,6 +188,21 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
                      axis=1,
                  )
              )
+             aggregate_local_explanations = pd.DataFrame()
+             for s_id, local_ex_df in self.local_explanation.items():
+                 local_ex_df_copy = local_ex_df.copy()
+                 local_ex_df_copy["Series"] = s_id
+                 aggregate_local_explanations = pd.concat(
+                     [aggregate_local_explanations, local_ex_df_copy], axis=0
+                 )
+             self.formatted_local_explanation = aggregate_local_explanations
+
+             if not self.target_cat_col:
+                 self.formatted_global_explanation = self.formatted_global_explanation.rename(
+                     {"Series 1": self.original_target_column},
+                     axis=1,
+                 )
+                 self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
 
              # Create a markdown section for the global explainability
              global_explanation_section = rc.Block(
@@ -198,26 +213,17 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
                  rc.DataTable(self.formatted_global_explanation, index=True),
              )
 
-             aggregate_local_explanations = pd.DataFrame()
-             for s_id, local_ex_df in self.local_explanation.items():
-                 local_ex_df_copy = local_ex_df.copy()
-                 local_ex_df_copy["Series"] = s_id
-                 aggregate_local_explanations = pd.concat(
-                     [aggregate_local_explanations, local_ex_df_copy], axis=0
-                 )
-             self.formatted_local_explanation = aggregate_local_explanations
-
              blocks = [
                  rc.DataTable(
                      local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
-                     label=s_id,
+                     label=s_id if self.target_cat_col else None,
                      index=True,
                  )
                  for s_id, local_ex_df in self.local_explanation.items()
              ]
              local_explanation_section = rc.Block(
                  rc.Heading("Local Explanation of Models", level=2),
-                 rc.Select(blocks=blocks),
+                 rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
              )
 
              # Append the global explanation text and section to the "all_sections" list
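
The report changes in arima.py (and in automlx.py below) follow one pattern: tab labels are set only when a target category column exists, and a one-element rc.Select collapses to the bare block, so single-series reports no longer render a tab strip with a single tab. The pattern in isolation (a sketch, assuming the report-creator package these files import as rc; the data and series id are invented):

import pandas as pd
import report_creator as rc

# Sketch of the single- vs multi-series report pattern: tabs only when
# more than one series exists, labels only when a category column is set.
local_explanation = {"Series 1": pd.DataFrame({"feature": ["a"], "attr": [1.0]})}
target_cat_col = None  # single-series run

blocks = [
    rc.DataTable(df, label=s_id if target_cat_col else None, index=True)
    for s_id, df in local_explanation.items()
]
section = rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
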

ads/opctl/operator/lowcode/forecast/model/automlx.py
@@ -17,6 +17,7 @@ from ads.opctl.operator.lowcode.common.utils import (
  from ads.opctl.operator.lowcode.forecast.const import (
      AUTOMLX_METRIC_MAP,
      ForecastOutputColumns,
+     SpeedAccuracyMode,
      SupportedModels,
  )
  from ads.opctl.operator.lowcode.forecast.utils import _label_encode_dataframe
@@ -81,22 +82,6 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
 
          from automlx import Pipeline, init
 
-         cpu_count = os.cpu_count()
-         try:
-             if cpu_count < 4:
-                 engine = "local"
-                 engine_opts = None
-             else:
-                 engine = "ray"
-                 engine_opts = ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
-             init(
-                 engine=engine,
-                 engine_opts=engine_opts,
-                 loglevel=logging.CRITICAL,
-             )
-         except Exception as e:
-             logger.info(f"Error. Has Ray already been initialized? Skipping. {e}")
-
          full_data_dict = self.datasets.get_data_by_series()
 
          self.models = {}
@@ -112,6 +97,26 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
          # Clean up kwargs for pass through
          model_kwargs_cleaned, time_budget = self.set_kwargs()
 
+         cpu_count = os.cpu_count()
+         try:
+             engine_type = model_kwargs_cleaned.pop(
+                 "engine", "local" if cpu_count <= 4 else "ray"
+             )
+             engine_opts = (
+                 None
+                 if engine_type == "local"
+                 else ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
+             )
+             init(
+                 engine=engine_type,
+                 engine_opts=engine_opts,
+                 loglevel=logging.CRITICAL,
+             )
+         except Exception as e:
+             logger.info(
+                 f"Error initializing automlx. Has Ray already been initialized? Skipping. {e}"
+             )
+
          for s_id, df in full_data_dict.items():
              try:
                  logger.debug(f"Running automlx on series {s_id}")
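
Moving engine initialization after set_kwargs() means an "engine" key supplied through the model kwargs now overrides the CPU-count heuristic (local for 4 or fewer CPUs, Ray otherwise; note the boundary also moved from < 4 to <= 4). The resolution logic in isolation (a sketch; the kwargs dict is invented):

import os

# Sketch of the new engine resolution: an explicit "engine" in model kwargs
# wins; otherwise fall back to a CPU-count heuristic.
model_kwargs = {"engine": "local"}

cpu_count = os.cpu_count()
engine_type = model_kwargs.pop("engine", "local" if cpu_count <= 4 else "ray")
engine_opts = (
    None if engine_type == "local" else ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
)
print(engine_type)  # "local", regardless of cpu_count, because it was set explicitly
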
@@ -223,6 +228,8 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
              selected_models.items(), columns=["series_id", "best_selected_model"]
          )
          selected_df = selected_models_df["best_selected_model"].apply(pd.Series)
+         if not self.target_cat_col:
+             selected_df = selected_df.drop("series_id", axis=1)
          selected_models_section = rc.Block(
              rc.Heading("Selected Models Overview", level=2),
              rc.Text(
@@ -239,27 +246,18 @@
              # If the key is present, call the "explain_model" method
              self.explain_model()
 
-             # Convert the global explanation data to a DataFrame
-             global_explanation_df = pd.DataFrame(self.global_explanation)
+             global_explanation_section = None
+             if self.spec.explanations_accuracy_mode != SpeedAccuracyMode.AUTOMLX:
+                 # Convert the global explanation data to a DataFrame
+                 global_explanation_df = pd.DataFrame(self.global_explanation)
 
-             self.formatted_global_explanation = (
-                 global_explanation_df / global_explanation_df.sum(axis=0) * 100
-             )
-             self.formatted_global_explanation = (
-                 self.formatted_global_explanation.rename(
+                 self.formatted_global_explanation = (
+                     global_explanation_df / global_explanation_df.sum(axis=0) * 100
+                 )
+                 self.formatted_global_explanation = self.formatted_global_explanation.rename(
                      {self.spec.datetime_column.name: ForecastOutputColumns.DATE},
                      axis=1,
                  )
-             )
-
-             # Create a markdown section for the global explainability
-             global_explanation_section = rc.Block(
-                 rc.Heading("Global Explanation of Models", level=2),
-                 rc.Text(
-                     "The following tables provide the feature attribution for the global explainability."
-                 ),
-                 rc.DataTable(self.formatted_global_explanation, index=True),
-             )
 
              aggregate_local_explanations = pd.DataFrame()
              for s_id, local_ex_df in self.local_explanation.items():
@@ -270,22 +268,41 @@
              )
              self.formatted_local_explanation = aggregate_local_explanations
 
+             if not self.target_cat_col:
+                 self.formatted_global_explanation = self.formatted_global_explanation.rename(
+                     {"Series 1": self.original_target_column},
+                     axis=1,
+                 )
+                 self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
+
+             # Create a markdown section for the global explainability
+             global_explanation_section = rc.Block(
+                 rc.Heading("Global Explanation of Models", level=2),
+                 rc.Text(
+                     "The following tables provide the feature attribution for the global explainability."
+                 ),
+                 rc.DataTable(self.formatted_global_explanation, index=True),
+             )
+
              blocks = [
                  rc.DataTable(
                      local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
-                     label=s_id,
+                     label=s_id if self.target_cat_col else None,
                      index=True,
                  )
                  for s_id, local_ex_df in self.local_explanation.items()
              ]
              local_explanation_section = rc.Block(
                  rc.Heading("Local Explanation of Models", level=2),
-                 rc.Select(blocks=blocks),
+                 rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
              )
 
              # Append the global explanation text and section to the "other_sections" list
+             if global_explanation_section:
+                 other_sections.append(global_explanation_section)
+
+             # Append the local explanation text and section to the "other_sections" list
              other_sections = other_sections + [
-                 global_explanation_section,
                  local_explanation_section,
              ]
          except Exception as e:
@@ -366,3 +383,79 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
          return self.models.get(self.series_id).forecast(
              X=data_temp, periods=data_temp.shape[0]
          )[self.series_id]
+
+     @runtime_dependency(
+         module="automlx",
+         err_msg=(
+             "Please run `python3 -m pip install automlx` to install the required dependencies for model explanation."
+         ),
+     )
+     def explain_model(self):
+         """
+         Generates explanations for the model using the AutoMLx library.
+
+         Parameters
+         ----------
+         None
+
+         Returns
+         -------
+         None
+
+         Notes
+         -----
+         This function works by generating local explanations for each series in the dataset.
+         It uses the ``MLExplainer`` class from the AutoMLx library to generate feature attributions
+         for each series. The feature attributions are then stored in the ``self.local_explanation`` dictionary.
+
+         If the accuracy mode is set to AutoMLX, it uses the AutoMLx library to generate explanations.
+         Otherwise, it falls back to the default explanation generation method.
+         """
+         import automlx
+
+         # Loop through each series in the dataset
+         for s_id, data_i in self.datasets.get_data_by_series(
+             include_horizon=False
+         ).items():
+             try:
+                 if self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
+                     # Use the MLExplainer class from AutoMLx to generate explanations
+                     explainer = automlx.MLExplainer(
+                         self.models[s_id],
+                         self.datasets.additional_data.get_data_for_series(series_id=s_id)
+                         .drop(self.spec.datetime_column.name, axis=1)
+                         .head(-self.spec.horizon)
+                         if self.spec.additional_data
+                         else None,
+                         pd.DataFrame(data_i[self.spec.target_column]),
+                         task="forecasting",
+                     )
+
+                     # Generate explanations for the forecast
+                     explanations = explainer.explain_prediction(
+                         X=self.datasets.additional_data.get_data_for_series(series_id=s_id)
+                         .drop(self.spec.datetime_column.name, axis=1)
+                         .tail(self.spec.horizon)
+                         if self.spec.additional_data
+                         else None,
+                         forecast_timepoints=list(range(self.spec.horizon + 1)),
+                     )
+
+                     # Convert the explanations to a DataFrame
+                     explanations_df = pd.concat(
+                         [exp.to_dataframe() for exp in explanations]
+                     )
+                     explanations_df["row"] = explanations_df.groupby("Feature").cumcount()
+                     explanations_df = explanations_df.pivot(
+                         index="row", columns="Feature", values="Attribution"
+                     )
+                     explanations_df = explanations_df.reset_index(drop=True)
+
+                     # Store the explanations in the local_explanation dictionary
+                     self.local_explanation[s_id] = explanations_df
+                 else:
+                     # Fall back to the default explanation generation method
+                     super().explain_model()
+             except Exception as e:
+                 logger.warning(f"Failed to generate explanations for series {s_id} with error: {e}.")
+                 logger.debug(f"Full Traceback: {traceback.format_exc()}")

ads/opctl/operator/lowcode/forecast/model/autots.py
@@ -242,6 +242,7 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
                  self.models.df_wide_numeric, series=s_id
              ),
              self.datasets.list_series_ids(),
+             target_category_column=self.target_cat_col
          )
          section_1 = rc.Block(
              rc.Heading("Forecast Overview", level=2),