oracle-ads 2.12.9__py3-none-any.whl → 2.12.10rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. ads/aqua/__init__.py +4 -4
  2. ads/aqua/common/enums.py +3 -0
  3. ads/aqua/common/utils.py +62 -2
  4. ads/aqua/data.py +2 -19
  5. ads/aqua/extension/finetune_handler.py +8 -14
  6. ads/aqua/extension/model_handler.py +19 -2
  7. ads/aqua/finetuning/constants.py +5 -2
  8. ads/aqua/finetuning/entities.py +64 -17
  9. ads/aqua/finetuning/finetuning.py +38 -54
  10. ads/aqua/model/entities.py +2 -1
  11. ads/aqua/model/model.py +61 -23
  12. ads/common/auth.py +9 -9
  13. ads/llm/autogen/__init__.py +2 -0
  14. ads/llm/autogen/constants.py +15 -0
  15. ads/llm/autogen/reports/__init__.py +2 -0
  16. ads/llm/autogen/reports/base.py +67 -0
  17. ads/llm/autogen/reports/data.py +103 -0
  18. ads/llm/autogen/reports/session.py +526 -0
  19. ads/llm/autogen/reports/templates/chat_box.html +13 -0
  20. ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
  21. ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
  22. ads/llm/autogen/reports/utils.py +56 -0
  23. ads/llm/autogen/v02/__init__.py +4 -0
  24. ads/llm/autogen/{client_v02.py → v02/client.py} +23 -10
  25. ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
  26. ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
  27. ads/llm/autogen/v02/loggers/__init__.py +6 -0
  28. ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
  29. ads/llm/autogen/v02/loggers/session_logger.py +580 -0
  30. ads/llm/autogen/v02/loggers/utils.py +86 -0
  31. ads/llm/autogen/v02/runtime_logging.py +163 -0
  32. ads/llm/langchain/plugins/chat_models/oci_data_science.py +12 -11
  33. ads/model/__init__.py +11 -13
  34. ads/model/artifact.py +47 -8
  35. ads/model/extractor/embedding_onnx_extractor.py +80 -0
  36. ads/model/framework/embedding_onnx_model.py +438 -0
  37. ads/model/generic_model.py +26 -24
  38. ads/model/model_metadata.py +8 -7
  39. ads/opctl/config/merger.py +13 -14
  40. ads/opctl/operator/common/operator_config.py +4 -4
  41. ads/opctl/operator/lowcode/common/transformations.py +12 -5
  42. ads/opctl/operator/lowcode/common/utils.py +11 -5
  43. ads/opctl/operator/lowcode/forecast/const.py +2 -0
  44. ads/opctl/operator/lowcode/forecast/model/arima.py +19 -13
  45. ads/opctl/operator/lowcode/forecast/model/automlx.py +129 -36
  46. ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
  47. ads/opctl/operator/lowcode/forecast/model/base_model.py +61 -14
  48. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +10 -3
  49. ads/opctl/operator/lowcode/forecast/model/prophet.py +25 -18
  50. ads/opctl/operator/lowcode/forecast/schema.yaml +13 -0
  51. ads/opctl/operator/lowcode/forecast/utils.py +4 -3
  52. ads/telemetry/base.py +18 -11
  53. ads/telemetry/client.py +33 -13
  54. ads/templates/schemas/openapi.json +1740 -0
  55. ads/templates/score_embedding_onnx.jinja2 +202 -0
  56. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/METADATA +7 -8
  57. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/RECORD +60 -39
  58. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/LICENSE.txt +0 -0
  59. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/WHEEL +0 -0
  60. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/entry_points.txt +0 -0
@@ -1,7 +1,6 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8; -*-
3
2
 
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
 
@@ -11,15 +10,16 @@ from dataclasses import dataclass
11
10
  from typing import Any, Dict, List
12
11
 
13
12
  from ads.common.serializer import DataClassSerializable
14
-
15
- from ads.opctl.operator.common.utils import OperatorValidator
16
13
  from ads.opctl.operator.common.errors import InvalidParameterError
14
+ from ads.opctl.operator.common.utils import OperatorValidator
15
+
17
16
 
18
17
  @dataclass(repr=True)
19
18
  class InputData(DataClassSerializable):
20
19
  """Class representing operator specification input data details."""
21
20
 
22
21
  connect_args: Dict = None
22
+ data: Dict = None
23
23
  format: str = None
24
24
  columns: List[str] = None
25
25
  url: str = None
@@ -1,10 +1,11 @@
1
1
  #!/usr/bin/env python
2
2
 
3
- # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
4
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
 
6
6
  from abc import ABC
7
7
 
8
+ import numpy as np
8
9
  import pandas as pd
9
10
 
10
11
  from ads.opctl import logger
@@ -209,18 +210,24 @@ class Transformations(ABC):
209
210
  -------
210
211
  A new Pandas DataFrame with treated outliers.
211
212
  """
212
- df["z_score"] = (
213
+ return df
214
+ df["__z_score"] = (
213
215
  df[self.target_column_name]
214
216
  .groupby(DataColumns.Series)
215
217
  .transform(lambda x: (x - x.mean()) / x.std())
216
218
  )
217
- outliers_mask = df["z_score"].abs() > 3
219
+ outliers_mask = df["__z_score"].abs() > 3
220
+
221
+ if df[self.target_column_name].dtype == np.int:
222
+ df[self.target_column_name].astype(np.float)
223
+
218
224
  df.loc[outliers_mask, self.target_column_name] = (
219
225
  df[self.target_column_name]
220
226
  .groupby(DataColumns.Series)
221
- .transform(lambda x: x.mean())
227
+ .transform(lambda x: np.median(x))
222
228
  )
223
- return df.drop("z_score", axis=1)
229
+ df_ret = df.drop("__z_score", axis=1)
230
+ return df_ret
224
231
 
225
232
  def _check_historical_dataset(self, df):
226
233
  expected_names = [self.target_column_name, self.dt_column_name] + (
@@ -1,6 +1,6 @@
1
1
  #!/usr/bin/env python
2
2
 
3
- # Copyright (c) 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
4
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
 
6
6
  import logging
@@ -40,6 +40,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
40
40
  if data_spec is None:
41
41
  raise InvalidParameterError("No details provided for this data source.")
42
42
  filename = data_spec.url
43
+ data = data_spec.data
43
44
  format = data_spec.format
44
45
  columns = data_spec.columns
45
46
  connect_args = data_spec.connect_args
@@ -51,9 +52,12 @@ def load_data(data_spec, storage_options=None, **kwargs):
51
52
  default_signer() if ObjectStorageDetails.is_oci_path(filename) else {}
52
53
  )
53
54
  if vault_secret_id is not None and connect_args is None:
54
- connect_args = dict()
55
+ connect_args = {}
55
56
 
56
- if filename is not None:
57
+ if data is not None:
58
+ if format == "spark":
59
+ data = data.toPandas()
60
+ elif filename is not None:
57
61
  if not format:
58
62
  _, format = os.path.splitext(filename)
59
63
  format = format[1:]
@@ -98,7 +102,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
98
102
  except Exception as e:
99
103
  raise Exception(
100
104
  f"Could not retrieve database credentials from vault {vault_secret_id}: {e}"
101
- )
105
+ ) from e
102
106
 
103
107
  con = oracledb.connect(**connect_args)
104
108
  if table_name is not None:
@@ -122,6 +126,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
122
126
 
123
127
 
124
128
  def write_data(data, filename, format, storage_options, index=False, **kwargs):
129
+ disable_print()
125
130
  if not format:
126
131
  _, format = os.path.splitext(filename)
127
132
  format = format[1:]
@@ -130,7 +135,8 @@ def write_data(data, filename, format, storage_options, index=False, **kwargs):
130
135
  return call_pandas_fsspec(
131
136
  write_fn, filename, index=index, storage_options=storage_options, **kwargs
132
137
  )
133
- raise OperatorYamlContentError(
138
+ enable_print()
139
+ raise InvalidParameterError(
134
140
  f"The format {format} is not currently supported for writing data. Please change the format parameter for the data output: {filename} ."
135
141
  )
136
142
 
@@ -27,10 +27,12 @@ class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
27
27
  HIGH_ACCURACY = "HIGH_ACCURACY"
28
28
  BALANCED = "BALANCED"
29
29
  FAST_APPROXIMATE = "FAST_APPROXIMATE"
30
+ AUTOMLX = "AUTOMLX"
30
31
  ratio = {}
31
32
  ratio[HIGH_ACCURACY] = 1 # 100 % data used for generating explanations
32
33
  ratio[BALANCED] = 0.5 # 50 % data used for generating explanations
33
34
  ratio[FAST_APPROXIMATE] = 0 # constant
35
+ ratio[AUTOMLX] = 0 # constant
34
36
 
35
37
 
36
38
  class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
@@ -164,11 +164,11 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
164
164
  blocks = [
165
165
  rc.Html(
166
166
  m.summary().as_html(),
167
- label=s_id,
167
+ label=s_id if self.target_cat_col else None,
168
168
  )
169
169
  for i, (s_id, m) in enumerate(self.models.items())
170
170
  ]
171
- sec5 = rc.Select(blocks=blocks)
171
+ sec5 = rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
172
172
  all_sections = [sec5_text, sec5]
173
173
 
174
174
  if self.spec.generate_explanations:
@@ -188,6 +188,21 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
188
188
  axis=1,
189
189
  )
190
190
  )
191
+ aggregate_local_explanations = pd.DataFrame()
192
+ for s_id, local_ex_df in self.local_explanation.items():
193
+ local_ex_df_copy = local_ex_df.copy()
194
+ local_ex_df_copy["Series"] = s_id
195
+ aggregate_local_explanations = pd.concat(
196
+ [aggregate_local_explanations, local_ex_df_copy], axis=0
197
+ )
198
+ self.formatted_local_explanation = aggregate_local_explanations
199
+
200
+ if not self.target_cat_col:
201
+ self.formatted_global_explanation = self.formatted_global_explanation.rename(
202
+ {"Series 1": self.original_target_column},
203
+ axis=1,
204
+ )
205
+ self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
191
206
 
192
207
  # Create a markdown section for the global explainability
193
208
  global_explanation_section = rc.Block(
@@ -198,26 +213,17 @@ class ArimaOperatorModel(ForecastOperatorBaseModel):
198
213
  rc.DataTable(self.formatted_global_explanation, index=True),
199
214
  )
200
215
 
201
- aggregate_local_explanations = pd.DataFrame()
202
- for s_id, local_ex_df in self.local_explanation.items():
203
- local_ex_df_copy = local_ex_df.copy()
204
- local_ex_df_copy["Series"] = s_id
205
- aggregate_local_explanations = pd.concat(
206
- [aggregate_local_explanations, local_ex_df_copy], axis=0
207
- )
208
- self.formatted_local_explanation = aggregate_local_explanations
209
-
210
216
  blocks = [
211
217
  rc.DataTable(
212
218
  local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
213
- label=s_id,
219
+ label=s_id if self.target_cat_col else None,
214
220
  index=True,
215
221
  )
216
222
  for s_id, local_ex_df in self.local_explanation.items()
217
223
  ]
218
224
  local_explanation_section = rc.Block(
219
225
  rc.Heading("Local Explanation of Models", level=2),
220
- rc.Select(blocks=blocks),
226
+ rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
221
227
  )
222
228
 
223
229
  # Append the global explanation text and section to the "all_sections" list
@@ -17,6 +17,7 @@ from ads.opctl.operator.lowcode.common.utils import (
17
17
  from ads.opctl.operator.lowcode.forecast.const import (
18
18
  AUTOMLX_METRIC_MAP,
19
19
  ForecastOutputColumns,
20
+ SpeedAccuracyMode,
20
21
  SupportedModels,
21
22
  )
22
23
  from ads.opctl.operator.lowcode.forecast.utils import _label_encode_dataframe
@@ -81,22 +82,6 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
81
82
 
82
83
  from automlx import Pipeline, init
83
84
 
84
- cpu_count = os.cpu_count()
85
- try:
86
- if cpu_count < 4:
87
- engine = "local"
88
- engine_opts = None
89
- else:
90
- engine = "ray"
91
- engine_opts = ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
92
- init(
93
- engine=engine,
94
- engine_opts=engine_opts,
95
- loglevel=logging.CRITICAL,
96
- )
97
- except Exception as e:
98
- logger.info(f"Error. Has Ray already been initialized? Skipping. {e}")
99
-
100
85
  full_data_dict = self.datasets.get_data_by_series()
101
86
 
102
87
  self.models = {}
@@ -112,6 +97,26 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
112
97
  # Clean up kwargs for pass through
113
98
  model_kwargs_cleaned, time_budget = self.set_kwargs()
114
99
 
100
+ cpu_count = os.cpu_count()
101
+ try:
102
+ engine_type = model_kwargs_cleaned.pop(
103
+ "engine", "local" if cpu_count <= 4 else "ray"
104
+ )
105
+ engine_opts = (
106
+ None
107
+ if engine_type == "local"
108
+ else ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
109
+ )
110
+ init(
111
+ engine=engine_type,
112
+ engine_opts=engine_opts,
113
+ loglevel=logging.CRITICAL,
114
+ )
115
+ except Exception as e:
116
+ logger.info(
117
+ f"Error initializing automlx. Has Ray already been initialized? Skipping. {e}"
118
+ )
119
+
115
120
  for s_id, df in full_data_dict.items():
116
121
  try:
117
122
  logger.debug(f"Running automlx on series {s_id}")
@@ -223,6 +228,8 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
223
228
  selected_models.items(), columns=["series_id", "best_selected_model"]
224
229
  )
225
230
  selected_df = selected_models_df["best_selected_model"].apply(pd.Series)
231
+ if not self.target_cat_col:
232
+ selected_df = selected_df.drop("series_id", axis=1)
226
233
  selected_models_section = rc.Block(
227
234
  rc.Heading("Selected Models Overview", level=2),
228
235
  rc.Text(
@@ -239,27 +246,18 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
239
246
  # If the key is present, call the "explain_model" method
240
247
  self.explain_model()
241
248
 
242
- # Convert the global explanation data to a DataFrame
243
- global_explanation_df = pd.DataFrame(self.global_explanation)
249
+ global_explanation_section = None
250
+ if self.spec.explanations_accuracy_mode != SpeedAccuracyMode.AUTOMLX:
251
+ # Convert the global explanation data to a DataFrame
252
+ global_explanation_df = pd.DataFrame(self.global_explanation)
244
253
 
245
- self.formatted_global_explanation = (
246
- global_explanation_df / global_explanation_df.sum(axis=0) * 100
247
- )
248
- self.formatted_global_explanation = (
249
- self.formatted_global_explanation.rename(
254
+ self.formatted_global_explanation = (
255
+ global_explanation_df / global_explanation_df.sum(axis=0) * 100
256
+ )
257
+ self.formatted_global_explanation = self.formatted_global_explanation.rename(
250
258
  {self.spec.datetime_column.name: ForecastOutputColumns.DATE},
251
259
  axis=1,
252
260
  )
253
- )
254
-
255
- # Create a markdown section for the global explainability
256
- global_explanation_section = rc.Block(
257
- rc.Heading("Global Explanation of Models", level=2),
258
- rc.Text(
259
- "The following tables provide the feature attribution for the global explainability."
260
- ),
261
- rc.DataTable(self.formatted_global_explanation, index=True),
262
- )
263
261
 
264
262
  aggregate_local_explanations = pd.DataFrame()
265
263
  for s_id, local_ex_df in self.local_explanation.items():
@@ -270,22 +268,41 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
270
268
  )
271
269
  self.formatted_local_explanation = aggregate_local_explanations
272
270
 
271
+ if not self.target_cat_col:
272
+ self.formatted_global_explanation = self.formatted_global_explanation.rename(
273
+ {"Series 1": self.original_target_column},
274
+ axis=1,
275
+ )
276
+ self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
277
+
278
+ # Create a markdown section for the global explainability
279
+ global_explanation_section = rc.Block(
280
+ rc.Heading("Global Explanation of Models", level=2),
281
+ rc.Text(
282
+ "The following tables provide the feature attribution for the global explainability."
283
+ ),
284
+ rc.DataTable(self.formatted_global_explanation, index=True),
285
+ )
286
+
273
287
  blocks = [
274
288
  rc.DataTable(
275
289
  local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
276
- label=s_id,
290
+ label=s_id if self.target_cat_col else None,
277
291
  index=True,
278
292
  )
279
293
  for s_id, local_ex_df in self.local_explanation.items()
280
294
  ]
281
295
  local_explanation_section = rc.Block(
282
296
  rc.Heading("Local Explanation of Models", level=2),
283
- rc.Select(blocks=blocks),
297
+ rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
284
298
  )
285
299
 
286
300
  # Append the global explanation text and section to the "other_sections" list
301
+ if global_explanation_section:
302
+ other_sections.append(global_explanation_section)
303
+
304
+ # Append the local explanation text and section to the "other_sections" list
287
305
  other_sections = other_sections + [
288
- global_explanation_section,
289
306
  local_explanation_section,
290
307
  ]
291
308
  except Exception as e:
@@ -366,3 +383,79 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
366
383
  return self.models.get(self.series_id).forecast(
367
384
  X=data_temp, periods=data_temp.shape[0]
368
385
  )[self.series_id]
386
+
387
+ @runtime_dependency(
388
+ module="automlx",
389
+ err_msg=(
390
+ "Please run `python3 -m pip install automlx` to install the required dependencies for model explanation."
391
+ ),
392
+ )
393
+ def explain_model(self):
394
+ """
395
+ Generates explanations for the model using the AutoMLx library.
396
+
397
+ Parameters
398
+ ----------
399
+ None
400
+
401
+ Returns
402
+ -------
403
+ None
404
+
405
+ Notes
406
+ -----
407
+ This function works by generating local explanations for each series in the dataset.
408
+ It uses the ``MLExplainer`` class from the AutoMLx library to generate feature attributions
409
+ for each series. The feature attributions are then stored in the ``self.local_explanation`` dictionary.
410
+
411
+ If the accuracy mode is set to AutoMLX, it uses the AutoMLx library to generate explanations.
412
+ Otherwise, it falls back to the default explanation generation method.
413
+ """
414
+ import automlx
415
+
416
+ # Loop through each series in the dataset
417
+ for s_id, data_i in self.datasets.get_data_by_series(
418
+ include_horizon=False
419
+ ).items():
420
+ try:
421
+ if self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
422
+ # Use the MLExplainer class from AutoMLx to generate explanations
423
+ explainer = automlx.MLExplainer(
424
+ self.models[s_id],
425
+ self.datasets.additional_data.get_data_for_series(series_id=s_id)
426
+ .drop(self.spec.datetime_column.name, axis=1)
427
+ .head(-self.spec.horizon)
428
+ if self.spec.additional_data
429
+ else None,
430
+ pd.DataFrame(data_i[self.spec.target_column]),
431
+ task="forecasting",
432
+ )
433
+
434
+ # Generate explanations for the forecast
435
+ explanations = explainer.explain_prediction(
436
+ X=self.datasets.additional_data.get_data_for_series(series_id=s_id)
437
+ .drop(self.spec.datetime_column.name, axis=1)
438
+ .tail(self.spec.horizon)
439
+ if self.spec.additional_data
440
+ else None,
441
+ forecast_timepoints=list(range(self.spec.horizon + 1)),
442
+ )
443
+
444
+ # Convert the explanations to a DataFrame
445
+ explanations_df = pd.concat(
446
+ [exp.to_dataframe() for exp in explanations]
447
+ )
448
+ explanations_df["row"] = explanations_df.groupby("Feature").cumcount()
449
+ explanations_df = explanations_df.pivot(
450
+ index="row", columns="Feature", values="Attribution"
451
+ )
452
+ explanations_df = explanations_df.reset_index(drop=True)
453
+
454
+ # Store the explanations in the local_explanation dictionary
455
+ self.local_explanation[s_id] = explanations_df
456
+ else:
457
+ # Fall back to the default explanation generation method
458
+ super().explain_model()
459
+ except Exception as e:
460
+ logger.warning(f"Failed to generate explanations for series {s_id} with error: {e}.")
461
+ logger.debug(f"Full Traceback: {traceback.format_exc()}")
@@ -242,6 +242,7 @@ class AutoTSOperatorModel(ForecastOperatorBaseModel):
242
242
  self.models.df_wide_numeric, series=s_id
243
243
  ),
244
244
  self.datasets.list_series_ids(),
245
+ target_category_column=self.target_cat_col
245
246
  )
246
247
  section_1 = rc.Block(
247
248
  rc.Heading("Forecast Overview", level=2),
@@ -28,6 +28,7 @@ from ads.opctl.operator.lowcode.common.utils import (
28
28
  seconds_to_datetime,
29
29
  write_data,
30
30
  )
31
+ from ads.opctl.operator.lowcode.common.const import DataColumns
31
32
  from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import TestData
32
33
  from ads.opctl.operator.lowcode.forecast.utils import (
33
34
  _build_metrics_df,
@@ -43,11 +44,12 @@ from ads.opctl.operator.lowcode.forecast.utils import (
43
44
 
44
45
  from ..const import (
45
46
  AUTO_SELECT,
47
+ BACKTEST_REPORT_NAME,
46
48
  SUMMARY_METRICS_HORIZON_LIMIT,
47
49
  SpeedAccuracyMode,
48
50
  SupportedMetrics,
49
51
  SupportedModels,
50
- BACKTEST_REPORT_NAME
52
+ BACKTEST_REPORT_NAME,
51
53
  )
52
54
  from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
53
55
  from .forecast_datasets import ForecastDatasets
@@ -69,7 +71,7 @@ class ForecastOperatorBaseModel(ABC):
69
71
  self.config: ForecastOperatorConfig = config
70
72
  self.spec: ForecastOperatorSpec = config.spec
71
73
  self.datasets: ForecastDatasets = datasets
72
-
74
+ self.target_cat_col = self.spec.target_category_columns
73
75
  self.full_data_dict = datasets.get_data_by_series()
74
76
 
75
77
  self.test_eval_metrics = None
@@ -124,6 +126,9 @@ class ForecastOperatorBaseModel(ABC):
124
126
 
125
127
  if self.spec.generate_report or self.spec.generate_metrics:
126
128
  self.eval_metrics = self.generate_train_metrics()
129
+ if not self.target_cat_col:
130
+ self.eval_metrics.rename({"Series 1": self.original_target_column},
131
+ axis=1, inplace=True)
127
132
 
128
133
  if self.spec.test_data:
129
134
  try:
@@ -134,6 +139,9 @@ class ForecastOperatorBaseModel(ABC):
134
139
  ) = self._test_evaluate_metrics(
135
140
  elapsed_time=elapsed_time,
136
141
  )
142
+ if not self.target_cat_col:
143
+ self.test_eval_metrics.rename({"Series 1": self.original_target_column},
144
+ axis=1, inplace=True)
137
145
  except Exception:
138
146
  logger.warn("Unable to generate Test Metrics.")
139
147
  logger.debug(f"Full Traceback: {traceback.format_exc()}")
@@ -179,7 +187,7 @@ class ForecastOperatorBaseModel(ABC):
179
187
  first_5_rows_blocks = [
180
188
  rc.DataTable(
181
189
  df.head(5),
182
- label=s_id,
190
+ label=s_id if self.target_cat_col else None,
183
191
  index=True,
184
192
  )
185
193
  for s_id, df in self.full_data_dict.items()
@@ -188,7 +196,7 @@ class ForecastOperatorBaseModel(ABC):
188
196
  last_5_rows_blocks = [
189
197
  rc.DataTable(
190
198
  df.tail(5),
191
- label=s_id,
199
+ label=s_id if self.target_cat_col else None,
192
200
  index=True,
193
201
  )
194
202
  for s_id, df in self.full_data_dict.items()
@@ -197,7 +205,7 @@ class ForecastOperatorBaseModel(ABC):
197
205
  data_summary_blocks = [
198
206
  rc.DataTable(
199
207
  df.describe(),
200
- label=s_id,
208
+ label=s_id if self.target_cat_col else None,
201
209
  index=True,
202
210
  )
203
211
  for s_id, df in self.full_data_dict.items()
@@ -215,17 +223,17 @@ class ForecastOperatorBaseModel(ABC):
215
223
  rc.Block(
216
224
  first_10_title,
217
225
  # series_subtext,
218
- rc.Select(blocks=first_5_rows_blocks),
226
+ rc.Select(blocks=first_5_rows_blocks) if self.target_cat_col else first_5_rows_blocks[0],
219
227
  ),
220
228
  rc.Block(
221
229
  last_10_title,
222
230
  # series_subtext,
223
- rc.Select(blocks=last_5_rows_blocks),
231
+ rc.Select(blocks=last_5_rows_blocks) if self.target_cat_col else last_5_rows_blocks[0],
224
232
  ),
225
233
  rc.Block(
226
234
  summary_title,
227
235
  # series_subtext,
228
- rc.Select(blocks=data_summary_blocks),
236
+ rc.Select(blocks=data_summary_blocks) if self.target_cat_col else data_summary_blocks[0],
229
237
  ),
230
238
  rc.Separator(),
231
239
  )
@@ -259,7 +267,11 @@ class ForecastOperatorBaseModel(ABC):
259
267
  output_dir = self.spec.output_directory.url
260
268
  file_path = f"{output_dir}/{BACKTEST_REPORT_NAME}"
261
269
  if self.spec.model == AUTO_SELECT:
262
- backtest_sections.append(rc.Heading("Auto-Select Backtesting and Performance Metrics", level=2))
270
+ backtest_sections.append(
271
+ rc.Heading(
272
+ "Auto-Select Backtesting and Performance Metrics", level=2
273
+ )
274
+ )
263
275
  if not os.path.exists(file_path):
264
276
  failure_msg = rc.Text(
265
277
  "auto-select could not be executed. Please check the "
@@ -268,15 +280,23 @@ class ForecastOperatorBaseModel(ABC):
268
280
  backtest_sections.append(failure_msg)
269
281
  else:
270
282
  backtest_stats = pd.read_csv(file_path)
271
- model_metric_map = backtest_stats.drop(columns=['metric', 'backtest'])
272
- average_dict = {k: round(v, 4) for k, v in model_metric_map.mean().to_dict().items()}
283
+ model_metric_map = backtest_stats.drop(
284
+ columns=["metric", "backtest"]
285
+ )
286
+ average_dict = {
287
+ k: round(v, 4)
288
+ for k, v in model_metric_map.mean().to_dict().items()
289
+ }
273
290
  best_model = min(average_dict, key=average_dict.get)
274
291
  summary_text = rc.Text(
275
292
  f"Overall, the average {self.spec.metric} scores for the models are {average_dict}, with"
276
- f" {best_model} being identified as the top-performing model during backtesting.")
293
+ f" {best_model} being identified as the top-performing model during backtesting."
294
+ )
277
295
  backtest_table = rc.DataTable(backtest_stats, index=True)
278
296
  liner_plot = get_auto_select_plot(backtest_stats)
279
- backtest_sections.extend([backtest_table, summary_text, liner_plot])
297
+ backtest_sections.extend(
298
+ [backtest_table, summary_text, liner_plot]
299
+ )
280
300
 
281
301
  forecast_plots = []
282
302
  if len(self.forecast_output.list_series_ids()) > 0:
@@ -288,6 +308,7 @@ class ForecastOperatorBaseModel(ABC):
288
308
  horizon=self.spec.horizon,
289
309
  test_data=test_data,
290
310
  ci_interval_width=self.spec.confidence_interval_width,
311
+ target_category_column=self.target_cat_col
291
312
  )
292
313
  if (
293
314
  series_name is not None
@@ -301,7 +322,14 @@ class ForecastOperatorBaseModel(ABC):
301
322
  forecast_plots = [forecast_text, forecast_sec]
302
323
 
303
324
  yaml_appendix_title = rc.Heading("Reference: YAML File", level=2)
304
- yaml_appendix = rc.Yaml(self.config.to_dict())
325
+ config_dict = self.config.to_dict()
326
+ # pop the data in case it isn't JSON serializable
327
+ config_dict["spec"]["historical_data"].pop("data")
328
+ if config_dict["spec"].get("additional_data"):
329
+ config_dict["spec"]["additional_data"].pop("data")
330
+ if config_dict["spec"].get("test_data"):
331
+ config_dict["spec"]["test_data"].pop("data")
332
+ yaml_appendix = rc.Yaml(config_dict)
305
333
  report_sections = (
306
334
  [summary]
307
335
  + backtest_sections
@@ -463,6 +491,7 @@ class ForecastOperatorBaseModel(ABC):
463
491
  f2.write(f1.read())
464
492
 
465
493
  # forecast csv report
494
+ result_df = result_df if self.target_cat_col else result_df.drop(DataColumns.Series, axis=1)
466
495
  write_data(
467
496
  data=result_df,
468
497
  filename=os.path.join(unique_output_dir, self.spec.forecast_filename),
@@ -637,6 +666,13 @@ class ForecastOperatorBaseModel(ABC):
637
666
  storage_options=storage_options,
638
667
  )
639
668
 
669
+ def _validate_automlx_explanation_mode(self):
670
+ if self.spec.model != SupportedModels.AutoMLX and self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
671
+ raise ValueError(
672
+ "AUTOMLX explanation accuracy mode is only supported for AutoMLX models. "
673
+ "Please select mode other than AUTOMLX from the available explanations_accuracy_mode options"
674
+ )
675
+
640
676
  @runtime_dependency(
641
677
  module="shap",
642
678
  err_msg=(
@@ -665,6 +701,9 @@ class ForecastOperatorBaseModel(ABC):
665
701
  )
666
702
  ratio = SpeedAccuracyMode.ratio[self.spec.explanations_accuracy_mode]
667
703
 
704
+ # validate that the AUTOMLX explanation mode is only used with the AutoMLX model
705
+ self._validate_automlx_explanation_mode()
706
+
668
707
  for s_id, data_i in self.datasets.get_data_by_series(
669
708
  include_horizon=False
670
709
  ).items():
@@ -699,6 +738,14 @@ class ForecastOperatorBaseModel(ABC):
699
738
  logger.warn(
700
739
  "No explanations generated. Ensure that additional data has been provided."
701
740
  )
741
+ elif (
742
+ self.spec.model == SupportedModels.AutoMLX
743
+ and self.spec.explanations_accuracy_mode
744
+ == SpeedAccuracyMode.AUTOMLX
745
+ ):
746
+ logger.warning(
747
+ "Global explanations not available for AutoMLX models with inherent explainability"
748
+ )
702
749
  else:
703
750
  self.global_explanation[s_id] = dict(
704
751
  zip(