oracle-ads 2.12.8__py3-none-any.whl → 2.12.10rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. ads/aqua/__init__.py +4 -4
  2. ads/aqua/app.py +12 -2
  3. ads/aqua/common/enums.py +3 -0
  4. ads/aqua/common/utils.py +62 -2
  5. ads/aqua/data.py +2 -19
  6. ads/aqua/evaluation/entities.py +6 -0
  7. ads/aqua/evaluation/evaluation.py +25 -3
  8. ads/aqua/extension/deployment_handler.py +8 -4
  9. ads/aqua/extension/finetune_handler.py +8 -14
  10. ads/aqua/extension/model_handler.py +25 -6
  11. ads/aqua/extension/ui_handler.py +13 -1
  12. ads/aqua/finetuning/constants.py +5 -2
  13. ads/aqua/finetuning/entities.py +70 -17
  14. ads/aqua/finetuning/finetuning.py +79 -82
  15. ads/aqua/model/entities.py +4 -1
  16. ads/aqua/model/model.py +95 -29
  17. ads/aqua/modeldeployment/deployment.py +13 -1
  18. ads/aqua/modeldeployment/entities.py +7 -4
  19. ads/aqua/ui.py +24 -2
  20. ads/common/auth.py +9 -9
  21. ads/llm/autogen/__init__.py +2 -0
  22. ads/llm/autogen/constants.py +15 -0
  23. ads/llm/autogen/reports/__init__.py +2 -0
  24. ads/llm/autogen/reports/base.py +67 -0
  25. ads/llm/autogen/reports/data.py +103 -0
  26. ads/llm/autogen/reports/session.py +526 -0
  27. ads/llm/autogen/reports/templates/chat_box.html +13 -0
  28. ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
  29. ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
  30. ads/llm/autogen/reports/utils.py +56 -0
  31. ads/llm/autogen/v02/__init__.py +4 -0
  32. ads/llm/autogen/{client_v02.py → v02/client.py} +23 -10
  33. ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
  34. ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
  35. ads/llm/autogen/v02/loggers/__init__.py +6 -0
  36. ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
  37. ads/llm/autogen/v02/loggers/session_logger.py +580 -0
  38. ads/llm/autogen/v02/loggers/utils.py +86 -0
  39. ads/llm/autogen/v02/runtime_logging.py +163 -0
  40. ads/llm/guardrails/base.py +6 -5
  41. ads/llm/langchain/plugins/chat_models/oci_data_science.py +46 -20
  42. ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +38 -11
  43. ads/model/__init__.py +11 -13
  44. ads/model/artifact.py +47 -8
  45. ads/model/extractor/embedding_onnx_extractor.py +80 -0
  46. ads/model/framework/embedding_onnx_model.py +438 -0
  47. ads/model/generic_model.py +26 -24
  48. ads/model/model_metadata.py +8 -7
  49. ads/opctl/config/merger.py +13 -14
  50. ads/opctl/operator/common/operator_config.py +4 -4
  51. ads/opctl/operator/lowcode/common/transformations.py +12 -5
  52. ads/opctl/operator/lowcode/common/utils.py +11 -5
  53. ads/opctl/operator/lowcode/forecast/const.py +3 -0
  54. ads/opctl/operator/lowcode/forecast/model/arima.py +19 -13
  55. ads/opctl/operator/lowcode/forecast/model/automlx.py +129 -36
  56. ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
  57. ads/opctl/operator/lowcode/forecast/model/base_model.py +58 -17
  58. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +10 -3
  59. ads/opctl/operator/lowcode/forecast/model/prophet.py +25 -18
  60. ads/opctl/operator/lowcode/forecast/model_evaluator.py +3 -2
  61. ads/opctl/operator/lowcode/forecast/schema.yaml +13 -0
  62. ads/opctl/operator/lowcode/forecast/utils.py +8 -6
  63. ads/telemetry/base.py +18 -11
  64. ads/telemetry/client.py +33 -13
  65. ads/templates/schemas/openapi.json +1740 -0
  66. ads/templates/score_embedding_onnx.jinja2 +202 -0
  67. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10rc0.dist-info}/METADATA +9 -10
  68. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10rc0.dist-info}/RECORD +71 -50
  69. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10rc0.dist-info}/LICENSE.txt +0 -0
  70. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10rc0.dist-info}/WHEEL +0 -0
  71. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10rc0.dist-info}/entry_points.txt +0 -0
ads/model/model_metadata.py

@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*--
 
 # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -11,20 +10,21 @@ import sys
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field, fields
 from pathlib import Path
-from typing import Dict, List, Tuple, Union, Optional, Any
+from typing import Any, Dict, List, Optional, Tuple, Union
 
-import ads.dataset.factory as factory
 import fsspec
 import git
 import oci.data_science.models
 import pandas as pd
 import yaml
+from oci.util import to_dict
+
 from ads.common import logger
 from ads.common.error import ChangesNotCommitted
 from ads.common.extended_enum import ExtendedEnumMeta
-from ads.common.serializer import DataClassSerializable
 from ads.common.object_storage_details import ObjectStorageDetails
-from oci.util import to_dict
+from ads.common.serializer import DataClassSerializable
+from ads.dataset import factory
 
 try:
     from yaml import CDumper as dumper
@@ -173,6 +173,7 @@ class Framework(str, metaclass=ExtendedEnumMeta):
     WORD2VEC = "word2vec"
     ENSEMBLE = "ensemble"
     SPARK = "pyspark"
+    EMBEDDING_ONNX = "embedding_onnx"
     OTHER = "other"
 
 
@@ -1398,7 +1399,7 @@ class ModelCustomMetadata(ModelMetadata):
         if (
             not data
             or not isinstance(data, Dict)
-            or not "data" in data
+            or "data" not in data
             or not isinstance(data["data"], List)
         ):
             raise ValueError(
@@ -1550,7 +1551,7 @@ class ModelTaxonomyMetadata(ModelMetadata):
         if (
             not data
             or not isinstance(data, Dict)
-            or not "data" in data
+            or "data" not in data
             or not isinstance(data["data"], List)
         ):
             raise ValueError(
ads/opctl/config/merger.py

@@ -1,35 +1,33 @@
 #!/usr/bin/env python
-# -*- coding: utf-8; -*-
 
-# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
+# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 import os
 from string import Template
 from typing import Dict
-import json
 
 import yaml
 
 from ads.common.auth import AuthType, ResourcePrincipal
 from ads.opctl import logger
 from ads.opctl.config.base import ConfigProcessor
-from ads.opctl.config.utils import read_from_ini, _DefaultNoneDict
-from ads.opctl.utils import is_in_notebook_session, get_service_pack_prefix
+from ads.opctl.config.utils import _DefaultNoneDict, read_from_ini
 from ads.opctl.constants import (
-    DEFAULT_PROFILE,
-    DEFAULT_OCI_CONFIG_FILE,
-    DEFAULT_CONDA_PACK_FOLDER,
-    DEFAULT_ADS_CONFIG_FOLDER,
-    ADS_JOBS_CONFIG_FILE_NAME,
     ADS_CONFIG_FILE_NAME,
-    ADS_ML_PIPELINE_CONFIG_FILE_NAME,
     ADS_DATAFLOW_CONFIG_FILE_NAME,
+    ADS_JOBS_CONFIG_FILE_NAME,
     ADS_LOCAL_BACKEND_CONFIG_FILE_NAME,
+    ADS_ML_PIPELINE_CONFIG_FILE_NAME,
     ADS_MODEL_DEPLOYMENT_CONFIG_FILE_NAME,
-    DEFAULT_NOTEBOOK_SESSION_CONDA_DIR,
     BACKEND_NAME,
+    DEFAULT_ADS_CONFIG_FOLDER,
+    DEFAULT_CONDA_PACK_FOLDER,
+    DEFAULT_NOTEBOOK_SESSION_CONDA_DIR,
+    DEFAULT_OCI_CONFIG_FILE,
+    DEFAULT_PROFILE,
 )
+from ads.opctl.utils import get_service_pack_prefix, is_in_notebook_session
 
 
 class ConfigMerger(ConfigProcessor):
@@ -41,8 +39,9 @@ class ConfigMerger(ConfigProcessor):
     """
 
     def process(self, **kwargs) -> None:
-        config_string = Template(json.dumps(self.config)).safe_substitute(os.environ)
-        self.config = json.loads(config_string)
+        for key, value in self.config.items():
+            if isinstance(value, str):  # Substitute only if the value is a string
+                self.config[key] = Template(value).safe_substitute(os.environ)
 
         if "runtime" not in self.config:
             self.config["runtime"] = {}
ads/opctl/operator/common/operator_config.py

@@ -1,7 +1,6 @@
 #!/usr/bin/env python
-# -*- coding: utf-8; -*-
 
-# Copyright (c) 2023 Oracle and/or its affiliates.
+# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 
@@ -11,15 +10,16 @@ from dataclasses import dataclass
 from typing import Any, Dict, List
 
 from ads.common.serializer import DataClassSerializable
-
-from ads.opctl.operator.common.utils import OperatorValidator
 from ads.opctl.operator.common.errors import InvalidParameterError
+from ads.opctl.operator.common.utils import OperatorValidator
+
 
 @dataclass(repr=True)
 class InputData(DataClassSerializable):
     """Class representing operator specification input data details."""
 
     connect_args: Dict = None
+    data: Dict = None
     format: str = None
     columns: List[str] = None
     url: str = None
ads/opctl/operator/lowcode/common/transformations.py

@@ -1,10 +1,11 @@
 #!/usr/bin/env python
 
-# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
+# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 from abc import ABC
 
+import numpy as np
 import pandas as pd
 
 from ads.opctl import logger
@@ -209,18 +210,24 @@ class Transformations(ABC):
         -------
         A new Pandas DataFrame with treated outliears.
         """
-        df["z_score"] = (
+        return df
+        df["__z_score"] = (
             df[self.target_column_name]
             .groupby(DataColumns.Series)
             .transform(lambda x: (x - x.mean()) / x.std())
         )
-        outliers_mask = df["z_score"].abs() > 3
+        outliers_mask = df["__z_score"].abs() > 3
+
+        if df[self.target_column_name].dtype == np.int:
+            df[self.target_column_name].astype(np.float)
+
         df.loc[outliers_mask, self.target_column_name] = (
             df[self.target_column_name]
             .groupby(DataColumns.Series)
-            .transform(lambda x: x.mean())
+            .transform(lambda x: np.median(x))
         )
-        return df.drop("z_score", axis=1)
+        df_ret = df.drop("__z_score", axis=1)
+        return df_ret
 
     def _check_historical_dataset(self, df):
         expected_names = [self.target_column_name, self.dt_column_name] + (
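
Note the first added line: return df short-circuits the method, so everything below it is dead code in this release (the unreachable branch also still references the np.int/np.float aliases that NumPy 1.24 removed). The remainder nonetheless documents the intended technique: per-series z-scores, a |z| > 3 mask, and median replacement instead of the previous mean. A self-contained sketch of that technique on made-up data:

import pandas as pd

# Toy data: series "a" carries one extreme point, series "b" is well behaved.
df = pd.DataFrame({
    "Series": ["a"] * 12 + ["b"] * 12,
    "target": [1.0, 2.0] * 5 + [1.0, 100.0] + [10.0, 11.0] * 6,
})

# Per-series z-score of the target column
z = df.groupby("Series")["target"].transform(lambda x: (x - x.mean()) / x.std())
outliers = z.abs() > 3

# Replace flagged points with their series' median (computed before replacement)
medians = df.groupby("Series")["target"].transform("median")
df.loc[outliers, "target"] = medians[outliers]

print(df.loc[outliers])  # the 100.0 in series "a" becomes 1.5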
ads/opctl/operator/lowcode/common/utils.py

@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 
-# Copyright (c) 2024 Oracle and/or its affiliates.
+# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 import logging
@@ -40,6 +40,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
     if data_spec is None:
         raise InvalidParameterError("No details provided for this data source.")
     filename = data_spec.url
+    data = data_spec.data
     format = data_spec.format
     columns = data_spec.columns
     connect_args = data_spec.connect_args
@@ -51,9 +52,12 @@ def load_data(data_spec, storage_options=None, **kwargs):
         default_signer() if ObjectStorageDetails.is_oci_path(filename) else {}
     )
     if vault_secret_id is not None and connect_args is None:
-        connect_args = dict()
+        connect_args = {}
 
-    if filename is not None:
+    if data is not None:
+        if format == "spark":
+            data = data.toPandas()
+    elif filename is not None:
         if not format:
             _, format = os.path.splitext(filename)
             format = format[1:]
@@ -98,7 +102,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
         except Exception as e:
             raise Exception(
                 f"Could not retrieve database credentials from vault {vault_secret_id}: {e}"
-            )
+            ) from e
 
         con = oracledb.connect(**connect_args)
         if table_name is not None:
@@ -122,6 +126,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
 
 
 def write_data(data, filename, format, storage_options, index=False, **kwargs):
+    disable_print()
     if not format:
         _, format = os.path.splitext(filename)
         format = format[1:]
@@ -130,7 +135,8 @@ def write_data(data, filename, format, storage_options, index=False, **kwargs):
         return call_pandas_fsspec(
             write_fn, filename, index=index, storage_options=storage_options, **kwargs
         )
-    raise OperatorYamlContentError(
+    enable_print()
+    raise InvalidParameterError(
         f"The format {format} is not currently supported for writing data. Please change the format parameter for the data output: {filename} ."
    )
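
Together with the new data field on InputData above, this lets operator code hand load_data an already-loaded dataset instead of a URL, with Spark frames collected via toPandas(). A simplified sketch of the new dispatch order (illustrative, not the full function):

import pandas as pd

def load_data_sketch(data=None, url=None, format=None):
    """In-memory data takes precedence over a url, mirroring the new branch order."""
    if data is not None:
        if format == "spark":
            data = data.toPandas()  # collect a Spark DataFrame to the driver
        return data
    elif url is not None:
        read_fn = getattr(pd, f"read_{format}")  # e.g. pd.read_csv for "csv"
        return read_fn(url)
    raise ValueError("Either data or url must be provided.")

df = load_data_sketch(data=pd.DataFrame({"y": [1.0, 2.0]}))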
 
ads/opctl/operator/lowcode/forecast/const.py

@@ -27,10 +27,12 @@ class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
     HIGH_ACCURACY = "HIGH_ACCURACY"
     BALANCED = "BALANCED"
     FAST_APPROXIMATE = "FAST_APPROXIMATE"
+    AUTOMLX = "AUTOMLX"
     ratio = {}
     ratio[HIGH_ACCURACY] = 1  # 100 % data used for generating explanations
     ratio[BALANCED] = 0.5  # 50 % data used for generating explanations
     ratio[FAST_APPROXIMATE] = 0  # constant
+    ratio[AUTOMLX] = 0  # constant
 
 
 class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
@@ -87,3 +89,4 @@ SUMMARY_METRICS_HORIZON_LIMIT = 10
 PROPHET_INTERNAL_DATE_COL = "ds"
 RENDER_LIMIT = 5000
 AUTO_SELECT = "auto-select"
+BACKTEST_REPORT_NAME = "back_test.csv"
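
The ratio table maps each accuracy mode to the fraction of rows used when building explanations; the new AUTOMLX mode joins FAST_APPROXIMATE at 0, i.e. a constant-size sample rather than a proportion. A sketch of how such a table is typically consulted (the helper and its floor value are illustrative, not the operator's actual code):

ratio = {
    "HIGH_ACCURACY": 1,     # 100% of the data used for explanations
    "BALANCED": 0.5,        # 50% of the data
    "FAST_APPROXIMATE": 0,  # constant-size sample
    "AUTOMLX": 0,           # constant-size sample
}

def explanation_rows(n_rows: int, mode: str, floor: int = 100) -> int:
    """Illustrative: scale by the mode's ratio, else fall back to a constant floor."""
    fraction = ratio[mode]
    return int(n_rows * fraction) if fraction > 0 else min(n_rows, floor)

print(explanation_rows(10_000, "BALANCED"))  # 5000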
ads/opctl/operator/lowcode/forecast/model/arima.py

@@ -164,11 +164,11 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
         blocks = [
             rc.Html(
                 m.summary().as_html(),
-                label=s_id,
+                label=s_id if self.target_cat_col else None,
             )
             for i, (s_id, m) in enumerate(self.models.items())
         ]
-        sec5 = rc.Select(blocks=blocks)
+        sec5 = rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
         all_sections = [sec5_text, sec5]
 
         if self.spec.generate_explanations:
@@ -188,6 +188,21 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
                         axis=1,
                     )
                 )
+                aggregate_local_explanations = pd.DataFrame()
+                for s_id, local_ex_df in self.local_explanation.items():
+                    local_ex_df_copy = local_ex_df.copy()
+                    local_ex_df_copy["Series"] = s_id
+                    aggregate_local_explanations = pd.concat(
+                        [aggregate_local_explanations, local_ex_df_copy], axis=0
+                    )
+                self.formatted_local_explanation = aggregate_local_explanations
+
+                if not self.target_cat_col:
+                    self.formatted_global_explanation = self.formatted_global_explanation.rename(
+                        {"Series 1": self.original_target_column},
+                        axis=1,
+                    )
+                    self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
 
                 # Create a markdown section for the global explainability
                 global_explanation_section = rc.Block(
@@ -198,26 +213,17 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
                     rc.DataTable(self.formatted_global_explanation, index=True),
                 )
 
-                aggregate_local_explanations = pd.DataFrame()
-                for s_id, local_ex_df in self.local_explanation.items():
-                    local_ex_df_copy = local_ex_df.copy()
-                    local_ex_df_copy["Series"] = s_id
-                    aggregate_local_explanations = pd.concat(
-                        [aggregate_local_explanations, local_ex_df_copy], axis=0
-                    )
-                self.formatted_local_explanation = aggregate_local_explanations
-
                 blocks = [
                     rc.DataTable(
                         local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
-                        label=s_id,
+                        label=s_id if self.target_cat_col else None,
                         index=True,
                     )
                     for s_id, local_ex_df in self.local_explanation.items()
                 ]
                 local_explanation_section = rc.Block(
                     rc.Heading("Local Explanation of Models", level=2),
-                    rc.Select(blocks=blocks),
+                    rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
                 )
 
                 # Append the global explanation text and section to the "all_sections" list
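
Both report hunks apply the same pattern: tab labels are set only when a target category column exists, and rc.Select is used only when there is more than one block, so single-series reports render the lone table directly instead of a one-tab selector. A minimal sketch of the pattern (assuming report_creator is importable as rc, as in the operator code):

import pandas as pd
import report_creator as rc

series_tables = {"Series 1": pd.DataFrame({"forecast": [1.0, 2.0, 3.0]})}
target_cat_col = None  # single-series runs have no category column

blocks = [
    rc.DataTable(df, label=s_id if target_cat_col else None, index=True)
    for s_id, df in series_tables.items()
]
# Wrap in a tabbed Select only when there are multiple series.
section = rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]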
ads/opctl/operator/lowcode/forecast/model/automlx.py

@@ -17,6 +17,7 @@ from ads.opctl.operator.lowcode.common.utils import (
 from ads.opctl.operator.lowcode.forecast.const import (
     AUTOMLX_METRIC_MAP,
     ForecastOutputColumns,
+    SpeedAccuracyMode,
     SupportedModels,
 )
 from ads.opctl.operator.lowcode.forecast.utils import _label_encode_dataframe
@@ -81,22 +82,6 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
 
         from automlx import Pipeline, init
 
-        cpu_count = os.cpu_count()
-        try:
-            if cpu_count < 4:
-                engine = "local"
-                engine_opts = None
-            else:
-                engine = "ray"
-                engine_opts = ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
-            init(
-                engine=engine,
-                engine_opts=engine_opts,
-                loglevel=logging.CRITICAL,
-            )
-        except Exception as e:
-            logger.info(f"Error. Has Ray already been initialized? Skipping. {e}")
-
         full_data_dict = self.datasets.get_data_by_series()
 
         self.models = {}
@@ -112,6 +97,26 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
         # Clean up kwargs for pass through
         model_kwargs_cleaned, time_budget = self.set_kwargs()
 
+        cpu_count = os.cpu_count()
+        try:
+            engine_type = model_kwargs_cleaned.pop(
+                "engine", "local" if cpu_count <= 4 else "ray"
+            )
+            engine_opts = (
+                None
+                if engine_type == "local"
+                else ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
+            )
+            init(
+                engine=engine_type,
+                engine_opts=engine_opts,
+                loglevel=logging.CRITICAL,
+            )
+        except Exception as e:
+            logger.info(
+                f"Error initializing automlx. Has Ray already been initialized? Skipping. {e}"
+            )
+
         for s_id, df in full_data_dict.items():
             try:
                 logger.debug(f"Running automlx on series {s_id}")
@@ -223,6 +228,8 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
             selected_models.items(), columns=["series_id", "best_selected_model"]
         )
         selected_df = selected_models_df["best_selected_model"].apply(pd.Series)
+        if not self.target_cat_col:
+            selected_df = selected_df.drop("series_id", axis=1)
         selected_models_section = rc.Block(
             rc.Heading("Selected Models Overview", level=2),
             rc.Text(
@@ -239,27 +246,18 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
             # If the key is present, call the "explain_model" method
             self.explain_model()
 
-            # Convert the global explanation data to a DataFrame
-            global_explanation_df = pd.DataFrame(self.global_explanation)
+            global_explanation_section = None
+            if self.spec.explanations_accuracy_mode != SpeedAccuracyMode.AUTOMLX:
+                # Convert the global explanation data to a DataFrame
+                global_explanation_df = pd.DataFrame(self.global_explanation)
 
-            self.formatted_global_explanation = (
-                global_explanation_df / global_explanation_df.sum(axis=0) * 100
-            )
-            self.formatted_global_explanation = (
-                self.formatted_global_explanation.rename(
+                self.formatted_global_explanation = (
+                    global_explanation_df / global_explanation_df.sum(axis=0) * 100
+                )
+                self.formatted_global_explanation = self.formatted_global_explanation.rename(
                     {self.spec.datetime_column.name: ForecastOutputColumns.DATE},
                     axis=1,
                 )
-            )
-
-            # Create a markdown section for the global explainability
-            global_explanation_section = rc.Block(
-                rc.Heading("Global Explanation of Models", level=2),
-                rc.Text(
-                    "The following tables provide the feature attribution for the global explainability."
-                ),
-                rc.DataTable(self.formatted_global_explanation, index=True),
-            )
 
             aggregate_local_explanations = pd.DataFrame()
             for s_id, local_ex_df in self.local_explanation.items():
@@ -270,22 +268,41 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
             )
             self.formatted_local_explanation = aggregate_local_explanations
 
+            if not self.target_cat_col:
+                self.formatted_global_explanation = self.formatted_global_explanation.rename(
+                    {"Series 1": self.original_target_column},
+                    axis=1,
+                )
+                self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
+
+            # Create a markdown section for the global explainability
+            global_explanation_section = rc.Block(
+                rc.Heading("Global Explanation of Models", level=2),
+                rc.Text(
+                    "The following tables provide the feature attribution for the global explainability."
+                ),
+                rc.DataTable(self.formatted_global_explanation, index=True),
+            )
+
             blocks = [
                 rc.DataTable(
                     local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
-                    label=s_id,
+                    label=s_id if self.target_cat_col else None,
                     index=True,
                 )
                 for s_id, local_ex_df in self.local_explanation.items()
             ]
             local_explanation_section = rc.Block(
                 rc.Heading("Local Explanation of Models", level=2),
-                rc.Select(blocks=blocks),
+                rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
             )
 
             # Append the global explanation text and section to the "other_sections" list
+            if global_explanation_section:
+                other_sections.append(global_explanation_section)
+
+            # Append the local explanation text and section to the "other_sections" list
             other_sections = other_sections + [
-                global_explanation_section,
                 local_explanation_section,
             ]
         except Exception as e:
@@ -366,3 +383,79 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
         return self.models.get(self.series_id).forecast(
             X=data_temp, periods=data_temp.shape[0]
         )[self.series_id]
+
+    @runtime_dependency(
+        module="automlx",
+        err_msg=(
+            "Please run `python3 -m pip install automlx` to install the required dependencies for model explanation."
+        ),
+    )
+    def explain_model(self):
+        """
+        Generates explanations for the model using the AutoMLx library.
+
+        Parameters
+        ----------
+        None
+
+        Returns
+        -------
+        None
+
+        Notes
+        -----
+        This function works by generating local explanations for each series in the dataset.
+        It uses the ``MLExplainer`` class from the AutoMLx library to generate feature attributions
+        for each series. The feature attributions are then stored in the ``self.local_explanation`` dictionary.
+
+        If the accuracy mode is set to AutoMLX, it uses the AutoMLx library to generate explanations.
+        Otherwise, it falls back to the default explanation generation method.
+        """
+        import automlx
+
+        # Loop through each series in the dataset
+        for s_id, data_i in self.datasets.get_data_by_series(
+            include_horizon=False
+        ).items():
+            try:
+                if self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
+                    # Use the MLExplainer class from AutoMLx to generate explanations
+                    explainer = automlx.MLExplainer(
+                        self.models[s_id],
+                        self.datasets.additional_data.get_data_for_series(series_id=s_id)
+                        .drop(self.spec.datetime_column.name, axis=1)
+                        .head(-self.spec.horizon)
+                        if self.spec.additional_data
+                        else None,
+                        pd.DataFrame(data_i[self.spec.target_column]),
+                        task="forecasting",
+                    )
+
+                    # Generate explanations for the forecast
+                    explanations = explainer.explain_prediction(
+                        X=self.datasets.additional_data.get_data_for_series(series_id=s_id)
+                        .drop(self.spec.datetime_column.name, axis=1)
+                        .tail(self.spec.horizon)
+                        if self.spec.additional_data
+                        else None,
+                        forecast_timepoints=list(range(self.spec.horizon + 1)),
+                    )
+
+                    # Convert the explanations to a DataFrame
+                    explanations_df = pd.concat(
+                        [exp.to_dataframe() for exp in explanations]
+                    )
+                    explanations_df["row"] = explanations_df.groupby("Feature").cumcount()
+                    explanations_df = explanations_df.pivot(
+                        index="row", columns="Feature", values="Attribution"
+                    )
+                    explanations_df = explanations_df.reset_index(drop=True)
+
+                    # Store the explanations in the local_explanation dictionary
+                    self.local_explanation[s_id] = explanations_df
+                else:
+                    # Fall back to the default explanation generation method
+                    super().explain_model()
+            except Exception as e:
+                logger.warning(f"Failed to generate explanations for series {s_id} with error: {e}.")
+                logger.debug(f"Full Traceback: {traceback.format_exc()}")
ads/opctl/operator/lowcode/forecast/model/autots.py

@@ -242,6 +242,7 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
                 self.models.df_wide_numeric, series=s_id
             ),
             self.datasets.list_series_ids(),
+            target_category_column=self.target_cat_col
         )
         section_1 = rc.Block(
             rc.Heading("Forecast Overview", level=2),