oracle-ads 2.12.9 (py3-none-any.whl) → 2.12.10 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. ads/aqua/__init__.py +4 -3
  2. ads/aqua/app.py +28 -16
  3. ads/aqua/client/__init__.py +3 -0
  4. ads/aqua/client/client.py +799 -0
  5. ads/aqua/common/enums.py +3 -0
  6. ads/aqua/common/utils.py +62 -2
  7. ads/aqua/data.py +2 -19
  8. ads/aqua/evaluation/evaluation.py +20 -12
  9. ads/aqua/extension/aqua_ws_msg_handler.py +14 -7
  10. ads/aqua/extension/base_handler.py +12 -9
  11. ads/aqua/extension/finetune_handler.py +8 -14
  12. ads/aqua/extension/model_handler.py +24 -2
  13. ads/aqua/finetuning/constants.py +5 -2
  14. ads/aqua/finetuning/entities.py +67 -17
  15. ads/aqua/finetuning/finetuning.py +69 -54
  16. ads/aqua/model/entities.py +3 -1
  17. ads/aqua/model/model.py +196 -98
  18. ads/aqua/modeldeployment/deployment.py +22 -10
  19. ads/cli.py +16 -8
  20. ads/common/auth.py +9 -9
  21. ads/llm/autogen/__init__.py +2 -0
  22. ads/llm/autogen/constants.py +15 -0
  23. ads/llm/autogen/reports/__init__.py +2 -0
  24. ads/llm/autogen/reports/base.py +67 -0
  25. ads/llm/autogen/reports/data.py +103 -0
  26. ads/llm/autogen/reports/session.py +526 -0
  27. ads/llm/autogen/reports/templates/chat_box.html +13 -0
  28. ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
  29. ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
  30. ads/llm/autogen/reports/utils.py +56 -0
  31. ads/llm/autogen/v02/__init__.py +4 -0
  32. ads/llm/autogen/{client_v02.py → v02/client.py} +23 -10
  33. ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
  34. ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
  35. ads/llm/autogen/v02/loggers/__init__.py +6 -0
  36. ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
  37. ads/llm/autogen/v02/loggers/session_logger.py +580 -0
  38. ads/llm/autogen/v02/loggers/utils.py +86 -0
  39. ads/llm/autogen/v02/runtime_logging.py +163 -0
  40. ads/llm/langchain/plugins/chat_models/oci_data_science.py +12 -11
  41. ads/model/__init__.py +11 -13
  42. ads/model/artifact.py +47 -8
  43. ads/model/extractor/embedding_onnx_extractor.py +80 -0
  44. ads/model/framework/embedding_onnx_model.py +438 -0
  45. ads/model/generic_model.py +26 -24
  46. ads/model/model_metadata.py +8 -7
  47. ads/opctl/config/merger.py +13 -14
  48. ads/opctl/operator/common/operator_config.py +4 -4
  49. ads/opctl/operator/lowcode/common/transformations.py +50 -8
  50. ads/opctl/operator/lowcode/common/utils.py +22 -6
  51. ads/opctl/operator/lowcode/forecast/__main__.py +10 -0
  52. ads/opctl/operator/lowcode/forecast/const.py +2 -0
  53. ads/opctl/operator/lowcode/forecast/model/arima.py +19 -13
  54. ads/opctl/operator/lowcode/forecast/model/automlx.py +129 -36
  55. ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
  56. ads/opctl/operator/lowcode/forecast/model/base_model.py +61 -14
  57. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +1 -1
  58. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +10 -3
  59. ads/opctl/operator/lowcode/forecast/model/prophet.py +25 -18
  60. ads/opctl/operator/lowcode/forecast/operator_config.py +31 -0
  61. ads/opctl/operator/lowcode/forecast/schema.yaml +76 -0
  62. ads/opctl/operator/lowcode/forecast/utils.py +4 -3
  63. ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
  64. ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +233 -0
  65. ads/opctl/operator/lowcode/forecast/whatifserve/score.py +238 -0
  66. ads/telemetry/base.py +18 -11
  67. ads/telemetry/client.py +33 -13
  68. ads/templates/schemas/openapi.json +1740 -0
  69. ads/templates/score_embedding_onnx.jinja2 +202 -0
  70. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/METADATA +9 -8
  71. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/RECORD +74 -48
  72. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/LICENSE.txt +0 -0
  73. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/WHEEL +0 -0
  74. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/entry_points.txt +0 -0
ads/opctl/operator/lowcode/forecast/model/base_model.py

@@ -28,6 +28,7 @@ from ads.opctl.operator.lowcode.common.utils import (
     seconds_to_datetime,
     write_data,
 )
+from ads.opctl.operator.lowcode.common.const import DataColumns
 from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import TestData
 from ads.opctl.operator.lowcode.forecast.utils import (
     _build_metrics_df,
@@ -43,11 +44,12 @@ from ads.opctl.operator.lowcode.forecast.utils import (
 
 from ..const import (
     AUTO_SELECT,
+    BACKTEST_REPORT_NAME,
     SUMMARY_METRICS_HORIZON_LIMIT,
     SpeedAccuracyMode,
     SupportedMetrics,
     SupportedModels,
-    BACKTEST_REPORT_NAME
+    BACKTEST_REPORT_NAME,
 )
 from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
 from .forecast_datasets import ForecastDatasets
@@ -69,7 +71,7 @@ class ForecastOperatorBaseModel(ABC):
         self.config: ForecastOperatorConfig = config
         self.spec: ForecastOperatorSpec = config.spec
         self.datasets: ForecastDatasets = datasets
-
+        self.target_cat_col = self.spec.target_category_columns
         self.full_data_dict = datasets.get_data_by_series()
 
         self.test_eval_metrics = None
@@ -124,6 +126,9 @@ class ForecastOperatorBaseModel(ABC):
 
         if self.spec.generate_report or self.spec.generate_metrics:
             self.eval_metrics = self.generate_train_metrics()
+            if not self.target_cat_col:
+                self.eval_metrics.rename({"Series 1": self.original_target_column},
+                                         axis=1, inplace=True)
 
         if self.spec.test_data:
             try:
@@ -134,6 +139,9 @@ class ForecastOperatorBaseModel(ABC):
                 ) = self._test_evaluate_metrics(
                     elapsed_time=elapsed_time,
                 )
+                if not self.target_cat_col:
+                    self.test_eval_metrics.rename({"Series 1": self.original_target_column},
+                                                  axis=1, inplace=True)
             except Exception:
                 logger.warn("Unable to generate Test Metrics.")
                 logger.debug(f"Full Traceback: {traceback.format_exc()}")
@@ -179,7 +187,7 @@ class ForecastOperatorBaseModel(ABC):
             first_5_rows_blocks = [
                 rc.DataTable(
                     df.head(5),
-                    label=s_id,
+                    label=s_id if self.target_cat_col else None,
                     index=True,
                 )
                 for s_id, df in self.full_data_dict.items()
@@ -188,7 +196,7 @@ class ForecastOperatorBaseModel(ABC):
             last_5_rows_blocks = [
                 rc.DataTable(
                     df.tail(5),
-                    label=s_id,
+                    label=s_id if self.target_cat_col else None,
                     index=True,
                 )
                 for s_id, df in self.full_data_dict.items()
@@ -197,7 +205,7 @@ class ForecastOperatorBaseModel(ABC):
             data_summary_blocks = [
                 rc.DataTable(
                     df.describe(),
-                    label=s_id,
+                    label=s_id if self.target_cat_col else None,
                     index=True,
                 )
                 for s_id, df in self.full_data_dict.items()
@@ -215,17 +223,17 @@ class ForecastOperatorBaseModel(ABC):
                 rc.Block(
                     first_10_title,
                     # series_subtext,
-                    rc.Select(blocks=first_5_rows_blocks),
+                    rc.Select(blocks=first_5_rows_blocks) if self.target_cat_col else first_5_rows_blocks[0],
                 ),
                 rc.Block(
                     last_10_title,
                     # series_subtext,
-                    rc.Select(blocks=last_5_rows_blocks),
+                    rc.Select(blocks=last_5_rows_blocks) if self.target_cat_col else last_5_rows_blocks[0],
                 ),
                 rc.Block(
                     summary_title,
                     # series_subtext,
-                    rc.Select(blocks=data_summary_blocks),
+                    rc.Select(blocks=data_summary_blocks) if self.target_cat_col else data_summary_blocks[0],
                 ),
                 rc.Separator(),
             )
@@ -259,7 +267,11 @@ class ForecastOperatorBaseModel(ABC):
         output_dir = self.spec.output_directory.url
         file_path = f"{output_dir}/{BACKTEST_REPORT_NAME}"
         if self.spec.model == AUTO_SELECT:
-            backtest_sections.append(rc.Heading("Auto-Select Backtesting and Performance Metrics", level=2))
+            backtest_sections.append(
+                rc.Heading(
+                    "Auto-Select Backtesting and Performance Metrics", level=2
+                )
+            )
             if not os.path.exists(file_path):
                 failure_msg = rc.Text(
                     "auto-select could not be executed. Please check the "
@@ -268,15 +280,23 @@ class ForecastOperatorBaseModel(ABC):
                 backtest_sections.append(failure_msg)
             else:
                 backtest_stats = pd.read_csv(file_path)
-                model_metric_map = backtest_stats.drop(columns=['metric', 'backtest'])
-                average_dict = {k: round(v, 4) for k, v in model_metric_map.mean().to_dict().items()}
+                model_metric_map = backtest_stats.drop(
+                    columns=["metric", "backtest"]
+                )
+                average_dict = {
+                    k: round(v, 4)
+                    for k, v in model_metric_map.mean().to_dict().items()
+                }
                 best_model = min(average_dict, key=average_dict.get)
                 summary_text = rc.Text(
                     f"Overall, the average {self.spec.metric} scores for the models are {average_dict}, with"
-                    f" {best_model} being identified as the top-performing model during backtesting.")
+                    f" {best_model} being identified as the top-performing model during backtesting."
+                )
                 backtest_table = rc.DataTable(backtest_stats, index=True)
                 liner_plot = get_auto_select_plot(backtest_stats)
-                backtest_sections.extend([backtest_table, summary_text, liner_plot])
+                backtest_sections.extend(
+                    [backtest_table, summary_text, liner_plot]
+                )
 
         forecast_plots = []
         if len(self.forecast_output.list_series_ids()) > 0:
@@ -288,6 +308,7 @@ class ForecastOperatorBaseModel(ABC):
                 horizon=self.spec.horizon,
                 test_data=test_data,
                 ci_interval_width=self.spec.confidence_interval_width,
+                target_category_column=self.target_cat_col
             )
             if (
                 series_name is not None
@@ -301,7 +322,14 @@ class ForecastOperatorBaseModel(ABC):
             forecast_plots = [forecast_text, forecast_sec]
 
         yaml_appendix_title = rc.Heading("Reference: YAML File", level=2)
-        yaml_appendix = rc.Yaml(self.config.to_dict())
+        config_dict = self.config.to_dict()
+        # pop the data incase it isn't json serializable
+        config_dict["spec"]["historical_data"].pop("data")
+        if config_dict["spec"].get("additional_data"):
+            config_dict["spec"]["additional_data"].pop("data")
+        if config_dict["spec"].get("test_data"):
+            config_dict["spec"]["test_data"].pop("data")
+        yaml_appendix = rc.Yaml(config_dict)
         report_sections = (
             [summary]
             + backtest_sections
@@ -463,6 +491,7 @@ class ForecastOperatorBaseModel(ABC):
                 f2.write(f1.read())
 
         # forecast csv report
+        result_df = result_df if self.target_cat_col else result_df.drop(DataColumns.Series, axis=1)
         write_data(
             data=result_df,
             filename=os.path.join(unique_output_dir, self.spec.forecast_filename),
@@ -637,6 +666,13 @@ class ForecastOperatorBaseModel(ABC):
             storage_options=storage_options,
         )
 
+    def _validate_automlx_explanation_mode(self):
+        if self.spec.model != SupportedModels.AutoMLX and self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
+            raise ValueError(
+                "AUTOMLX explanation accuracy mode is only supported for AutoMLX models. "
+                "Please select mode other than AUTOMLX from the available explanations_accuracy_mode options"
+            )
+
     @runtime_dependency(
         module="shap",
         err_msg=(
@@ -665,6 +701,9 @@ class ForecastOperatorBaseModel(ABC):
             )
             ratio = SpeedAccuracyMode.ratio[self.spec.explanations_accuracy_mode]
 
+            # validate the automlx mode is use for automlx model
+            self._validate_automlx_explanation_mode()
+
             for s_id, data_i in self.datasets.get_data_by_series(
                 include_horizon=False
             ).items():
@@ -699,6 +738,14 @@ class ForecastOperatorBaseModel(ABC):
                     logger.warn(
                         "No explanations generated. Ensure that additional data has been provided."
                     )
+                elif (
+                    self.spec.model == SupportedModels.AutoMLX
+                    and self.spec.explanations_accuracy_mode
+                    == SpeedAccuracyMode.AUTOMLX
+                ):
+                    logger.warning(
+                        "Global explanations not available for AutoMLX models with inherent explainability"
+                    )
                 else:
                     self.global_explanation[s_id] = dict(
                         zip(
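Most of the base_model.py changes above branch on self.target_cat_col, i.e. whether spec.target_category_columns is set. Below is an illustrative sketch of the two spec variants (field names follow the forecast operator schema; file paths and column names are placeholders): the single-series variant is what triggers the "Series 1" renaming in the metrics tables and the dropped Series column in the forecast CSV.

    # Multi-series spec: report tables keep per-series tabs labeled by series id.
    spec:
      historical_data:
        url: sales.csv                # placeholder
      datetime_column:
        name: date
      target_column: units_sold
      target_category_columns:
        - store_id
      horizon: 14

    # Single-series spec: no target_category_columns, so "Series 1" is renamed to the
    # target column in the metrics tables and the Series column is dropped from the CSV.
    spec:
      historical_data:
        url: sales.csv                # placeholder
      datetime_column:
        name: date
      target_column: units_sold
      horizon: 14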
ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py

@@ -167,7 +167,7 @@ class ForecastDatasets:
                 self.historical_data.data,
                 self.additional_data.data,
             ],
-            axis=1,
+            axis=1
         )
 
     def get_data_by_series(self, include_horizon=True):
ads/opctl/operator/lowcode/forecast/model/neuralprophet.py

@@ -360,7 +360,7 @@ class NeuralProphetOperatorModel(ForecastOperatorBaseModel):
                 pd.Series(
                     m.state_dict(),
                     index=m.state_dict().keys(),
-                    name=s_id,
+                    name=s_id if self.target_cat_col else self.original_target_column,
                 )
             )
         all_model_states = pd.concat(model_states, axis=1)
@@ -373,6 +373,13 @@ class NeuralProphetOperatorModel(ForecastOperatorBaseModel):
             # If the key is present, call the "explain_model" method
             self.explain_model()
 
+            if not self.target_cat_col:
+                self.formatted_global_explanation = self.formatted_global_explanation.rename(
+                    {"Series 1": self.original_target_column},
+                    axis=1,
+                )
+                self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
+
             # Create a markdown section for the global explainability
             global_explanation_section = rc.Block(
                 rc.Heading("Global Explainability", level=2),
@@ -385,14 +392,14 @@ class NeuralProphetOperatorModel(ForecastOperatorBaseModel):
             blocks = [
                 rc.DataTable(
                     local_ex_df.drop("Series", axis=1),
-                    label=s_id,
+                    label=s_id if self.target_cat_col else None,
                     index=True,
                 )
                 for s_id, local_ex_df in self.local_explanation.items()
             ]
             local_explanation_section = rc.Block(
                 rc.Heading("Local Explanation of Models", level=2),
-                rc.Select(blocks=blocks),
+                rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
             )
 
             # Append the global explanation text and section to the "all_sections" list
ads/opctl/operator/lowcode/forecast/model/prophet.py

@@ -256,6 +256,7 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
                 self.outputs[s_id], include_legend=True
             ),
             series_ids=series_ids,
+            target_category_column=self.target_cat_col
         )
         section_1 = rc.Block(
             rc.Heading("Forecast Overview", level=2),
@@ -268,6 +269,7 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
         sec2 = _select_plot_list(
             lambda s_id: self.models[s_id].plot_components(self.outputs[s_id]),
             series_ids=series_ids,
+            target_category_column=self.target_cat_col
         )
         section_2 = rc.Block(
             rc.Heading("Forecast Broken Down by Trend Component", level=2), sec2
@@ -281,7 +283,9 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
                 sec3_figs[s_id].gca(), self.models[s_id], self.outputs[s_id]
             )
         sec3 = _select_plot_list(
-            lambda s_id: sec3_figs[s_id], series_ids=series_ids
+            lambda s_id: sec3_figs[s_id],
+            series_ids=series_ids,
+            target_category_column=self.target_cat_col
         )
         section_3 = rc.Block(rc.Heading("Forecast Changepoints", level=2), sec3)
 
@@ -295,7 +299,7 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
                 pd.Series(
                     m.seasonalities,
                     index=pd.Index(m.seasonalities.keys(), dtype="object"),
-                    name=s_id,
+                    name=s_id if self.target_cat_col else self.original_target_column,
                     dtype="object",
                 )
             )
@@ -316,15 +320,6 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
                 global_explanation_df / global_explanation_df.sum(axis=0) * 100
             )
 
-            # Create a markdown section for the global explainability
-            global_explanation_section = rc.Block(
-                rc.Heading("Global Explanation of Models", level=2),
-                rc.Text(
-                    "The following tables provide the feature attribution for the global explainability."
-                ),
-                rc.DataTable(self.formatted_global_explanation, index=True),
-            )
-
             aggregate_local_explanations = pd.DataFrame()
             for s_id, local_ex_df in self.local_explanation.items():
                 local_ex_df_copy = local_ex_df.copy()
@@ -334,17 +329,33 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
                 )
             self.formatted_local_explanation = aggregate_local_explanations
 
+            if not self.target_cat_col:
+                self.formatted_global_explanation = self.formatted_global_explanation.rename(
+                    {"Series 1": self.original_target_column},
+                    axis=1,
+                )
+                self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
+
+            # Create a markdown section for the global explainability
+            global_explanation_section = rc.Block(
+                rc.Heading("Global Explanation of Models", level=2),
+                rc.Text(
+                    "The following tables provide the feature attribution for the global explainability."
+                ),
+                rc.DataTable(self.formatted_global_explanation, index=True),
+            )
+
             blocks = [
                 rc.DataTable(
                     local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
-                    label=s_id,
+                    label=s_id if self.target_cat_col else None,
                     index=True,
                 )
                 for s_id, local_ex_df in self.local_explanation.items()
             ]
             local_explanation_section = rc.Block(
                 rc.Heading("Local Explanation of Models", level=2),
-                rc.Select(blocks=blocks),
+                rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
             )
 
             # Append the global explanation text and section to the "all_sections" list
@@ -358,11 +369,7 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
             logger.debug(f"Full Traceback: {traceback.format_exc()}")
 
         model_description = rc.Text(
-            "Prophet is a procedure for forecasting time series data based on an additive "
-            "model where non-linear trends are fit with yearly, weekly, and daily seasonality, "
-            "plus holiday effects. It works best with time series that have strong seasonal "
-            "effects and several seasons of historical data. Prophet is robust to missing "
-            "data and shifts in the trend, and typically handles outliers well."
+            """Prophet is a procedure for forecasting time series data based on an additive model where non-linear trends are fit with yearly, weekly, and daily seasonality, plus holiday effects. It works best with time series that have strong seasonal effects and several seasons of historical data. Prophet is robust to missing data and shifts in the trend, and typically handles outliers well."""
         )
         other_sections = all_sections
 
ads/opctl/operator/lowcode/forecast/operator_config.py

@@ -18,6 +18,35 @@ from ads.opctl.operator.lowcode.common.utils import find_output_dirname
 
 from .const import SpeedAccuracyMode, SupportedMetrics, SupportedModels
 
+@dataclass
+class AutoScaling(DataClassSerializable):
+    """Class representing simple autoscaling policy"""
+    minimum_instance: int = 1
+    maximum_instance: int = None
+    cool_down_in_seconds: int = 600
+    scale_in_threshold: int = 10
+    scale_out_threshold: int = 80
+    scaling_metric: str = "CPU_UTILIZATION"
+
+@dataclass(repr=True)
+class ModelDeploymentServer(DataClassSerializable):
+    """Class representing model deployment server specification for whatif-analysis."""
+    display_name: str = None
+    initial_shape: str = None
+    description: str = None
+    log_group: str = None
+    log_id: str = None
+    auto_scaling: AutoScaling = field(default_factory=AutoScaling)
+
+
+@dataclass(repr=True)
+class WhatIfAnalysis(DataClassSerializable):
+    """Class representing operator specification for whatif-analysis."""
+    model_display_name: str = None
+    compartment_id: str = None
+    project_id: str = None
+    model_deployment: ModelDeploymentServer = field(default_factory=ModelDeploymentServer)
+
 
 @dataclass(repr=True)
 class TestData(InputData):
@@ -90,12 +119,14 @@ class ForecastOperatorSpec(DataClassSerializable):
     confidence_interval_width: float = None
     metric: str = None
     tuning: Tuning = field(default_factory=Tuning)
+    what_if_analysis: WhatIfAnalysis = field(default_factory=WhatIfAnalysis)
 
     def __post_init__(self):
         """Adjusts the specification details."""
         self.output_directory = self.output_directory or OutputDirectory(
             url=find_output_dirname(self.output_directory)
         )
+        self.generate_model_pickle = True if self.generate_model_pickle or self.what_if_analysis else False
         self.metric = (self.metric or "").lower() or SupportedMetrics.SMAPE.lower()
         self.model = self.model or SupportedModels.Prophet
         self.confidence_interval_width = self.confidence_interval_width or 0.80
ads/opctl/operator/lowcode/forecast/schema.yaml

@@ -37,6 +37,9 @@ spec:
         nullable: true
         required: false
         type: dict
+      data:
+        nullable: true
+        required: false
       format:
         allowed:
           - csv
@@ -48,6 +51,7 @@ spec:
           - sql_query
           - hdf
           - tsv
+          - pandas
         required: false
         type: string
       columns:
@@ -92,6 +96,9 @@ spec:
         nullable: true
         required: false
         type: dict
+      data:
+        nullable: true
+        required: false
       format:
         allowed:
           - csv
@@ -103,6 +110,7 @@ spec:
           - sql_query
           - hdf
           - tsv
+          - pandas
         required: false
         type: string
       columns:
@@ -146,6 +154,9 @@ spec:
         nullable: true
         required: false
         type: dict
+      data:
+        nullable: true
+        required: false
       format:
         allowed:
           - csv
@@ -157,6 +168,7 @@ spec:
           - sql_query
           - hdf
           - tsv
+          - pandas
         required: false
         type: string
       columns:
@@ -332,6 +344,7 @@ spec:
         - HIGH_ACCURACY
        - BALANCED
        - FAST_APPROXIMATE
+        - AUTOMLX
 
     generate_report:
       type: boolean
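For orientation, a minimal spec opting into the new mode might look like the sketch below. It is assembled from this schema change and from the _validate_automlx_explanation_mode check added in base_model.py, which only accepts the AUTOMLX accuracy mode when the model itself is AutoMLX; the data path, column names, and the usual kind/type/version header are illustrative assumptions.

    kind: operator
    type: forecast
    version: v1
    spec:
      historical_data:
        url: data.csv                         # placeholder
      datetime_column:
        name: ds
      target_column: y
      horizon: 3
      model: automlx
      generate_explanations: true
      explanations_accuracy_mode: AUTOMLX     # valid only together with model: automlx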
@@ -340,6 +353,69 @@ spec:
       meta:
         description: "Report file generation can be enabled using this flag. Defaults to true."
 
+    what_if_analysis:
+      type: dict
+      required: false
+      schema:
+        model_deployment:
+          type: dict
+          required: false
+          meta: "If model_deployment id is not specified, a new model deployment is created; otherwise, the model is linked to the specified model deployment."
+          schema:
+            id:
+              type: string
+              required: false
+            display_name:
+              type: string
+              required: false
+            initial_shape:
+              type: string
+              required: false
+            description:
+              type: string
+              required: false
+            log_group:
+              type: string
+              required: true
+            log_id:
+              type: string
+              required: true
+            auto_scaling:
+              type: dict
+              required: false
+              schema:
+                minimum_instance:
+                  type: integer
+                  required: true
+                maximum_instance:
+                  type: integer
+                  required: true
+                scale_in_threshold:
+                  type: integer
+                  required: true
+                scale_out_threshold:
+                  type: integer
+                  required: true
+                scaling_metric:
+                  type: string
+                  required: true
+                cool_down_in_seconds:
+                  type: integer
+                  required: true
+        model_display_name:
+          type: string
+          required: true
+        project_id:
+          type: string
+          required: false
+          meta: "If not provided, The project OCID from config.PROJECT_OCID is used"
+        compartment_id:
+          type: string
+          required: false
+          meta: "If not provided, The compartment OCID from config.NB_SESSION_COMPARTMENT_OCID is used."
+      meta:
+        description: "When enabled, the models are saved to the model catalog. Defaults to false."
+
     generate_metrics:
       type: boolean
       required: false
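As a sketch only, a what_if_analysis block that satisfies this schema could look like the following; all OCIDs, display names, and the compute shape are placeholders. Note that, per the __post_init__ change in operator_config.py above, specifying what_if_analysis also forces generate_model_pickle on so the trained models can be cataloged and served by the new whatifserve module.

    what_if_analysis:
      model_display_name: forecast-whatif-model          # placeholder
      compartment_id: ocid1.compartment.oc1..example     # optional; defaults to config.NB_SESSION_COMPARTMENT_OCID
      project_id: ocid1.datascienceproject.oc1..example  # optional; defaults to config.PROJECT_OCID
      model_deployment:
        display_name: forecast-whatif-md                 # placeholder
        initial_shape: VM.Standard.E4.Flex               # placeholder shape
        log_group: ocid1.loggroup.oc1..example
        log_id: ocid1.log.oc1..example
        auto_scaling:
          minimum_instance: 1
          maximum_instance: 2
          scale_in_threshold: 10
          scale_out_threshold: 80
          scaling_metric: CPU_UTILIZATION
          cool_down_in_seconds: 600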
ads/opctl/operator/lowcode/forecast/utils.py

@@ -250,8 +250,8 @@ def evaluate_train_metrics(output):
     return total_metrics
 
 
-def _select_plot_list(fn, series_ids):
-    blocks = [rc.Widget(fn(s_id=s_id), label=s_id) for s_id in series_ids]
+def _select_plot_list(fn, series_ids, target_category_column):
+    blocks = [rc.Widget(fn(s_id=s_id), label=s_id if target_category_column else None) for s_id in series_ids]
     return rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
 
 
@@ -283,6 +283,7 @@ def get_forecast_plots(
     horizon,
     test_data=None,
     ci_interval_width=0.95,
+    target_category_column=None
 ):
     def plot_forecast_plotly(s_id):
         fig = go.Figure()
@@ -379,7 +380,7 @@ def get_forecast_plots(
         )
         return fig
 
-    return _select_plot_list(plot_forecast_plotly, forecast_output.list_series_ids(), target_category_column)
+    return _select_plot_list(plot_forecast_plotly, forecast_output.list_series_ids(), target_category_column)
 
 
 def convert_target(target: str, target_col: str):
ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py (new file)

@@ -0,0 +1,7 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+
+from .deployment_manager import ModelDeploymentManager