oracle-ads 2.12.9__py3-none-any.whl → 2.12.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +4 -3
- ads/aqua/app.py +28 -16
- ads/aqua/client/__init__.py +3 -0
- ads/aqua/client/client.py +799 -0
- ads/aqua/common/enums.py +3 -0
- ads/aqua/common/utils.py +62 -2
- ads/aqua/data.py +2 -19
- ads/aqua/evaluation/evaluation.py +20 -12
- ads/aqua/extension/aqua_ws_msg_handler.py +14 -7
- ads/aqua/extension/base_handler.py +12 -9
- ads/aqua/extension/finetune_handler.py +8 -14
- ads/aqua/extension/model_handler.py +24 -2
- ads/aqua/finetuning/constants.py +5 -2
- ads/aqua/finetuning/entities.py +67 -17
- ads/aqua/finetuning/finetuning.py +69 -54
- ads/aqua/model/entities.py +3 -1
- ads/aqua/model/model.py +196 -98
- ads/aqua/modeldeployment/deployment.py +22 -10
- ads/cli.py +16 -8
- ads/common/auth.py +9 -9
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/{client_v02.py → v02/client.py} +23 -10
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +12 -11
- ads/model/__init__.py +11 -13
- ads/model/artifact.py +47 -8
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/generic_model.py +26 -24
- ads/model/model_metadata.py +8 -7
- ads/opctl/config/merger.py +13 -14
- ads/opctl/operator/common/operator_config.py +4 -4
- ads/opctl/operator/lowcode/common/transformations.py +50 -8
- ads/opctl/operator/lowcode/common/utils.py +22 -6
- ads/opctl/operator/lowcode/forecast/__main__.py +10 -0
- ads/opctl/operator/lowcode/forecast/const.py +2 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +19 -13
- ads/opctl/operator/lowcode/forecast/model/automlx.py +129 -36
- ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +61 -14
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +1 -1
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +10 -3
- ads/opctl/operator/lowcode/forecast/model/prophet.py +25 -18
- ads/opctl/operator/lowcode/forecast/operator_config.py +31 -0
- ads/opctl/operator/lowcode/forecast/schema.yaml +76 -0
- ads/opctl/operator/lowcode/forecast/utils.py +4 -3
- ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +233 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/score.py +238 -0
- ads/telemetry/base.py +18 -11
- ads/telemetry/client.py +33 -13
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/METADATA +9 -8
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/RECORD +74 -48
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/WHEEL +0 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/entry_points.txt +0 -0
@@ -28,6 +28,7 @@ from ads.opctl.operator.lowcode.common.utils import (
|
|
28
28
|
seconds_to_datetime,
|
29
29
|
write_data,
|
30
30
|
)
|
31
|
+
from ads.opctl.operator.lowcode.common.const import DataColumns
|
31
32
|
from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import TestData
|
32
33
|
from ads.opctl.operator.lowcode.forecast.utils import (
|
33
34
|
_build_metrics_df,
|
@@ -43,11 +44,12 @@ from ads.opctl.operator.lowcode.forecast.utils import (
|
|
43
44
|
|
44
45
|
from ..const import (
|
45
46
|
AUTO_SELECT,
|
47
|
+
BACKTEST_REPORT_NAME,
|
46
48
|
SUMMARY_METRICS_HORIZON_LIMIT,
|
47
49
|
SpeedAccuracyMode,
|
48
50
|
SupportedMetrics,
|
49
51
|
SupportedModels,
|
50
|
-
BACKTEST_REPORT_NAME
|
52
|
+
BACKTEST_REPORT_NAME,
|
51
53
|
)
|
52
54
|
from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
|
53
55
|
from .forecast_datasets import ForecastDatasets
|
@@ -69,7 +71,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
69
71
|
self.config: ForecastOperatorConfig = config
|
70
72
|
self.spec: ForecastOperatorSpec = config.spec
|
71
73
|
self.datasets: ForecastDatasets = datasets
|
72
|
-
|
74
|
+
self.target_cat_col = self.spec.target_category_columns
|
73
75
|
self.full_data_dict = datasets.get_data_by_series()
|
74
76
|
|
75
77
|
self.test_eval_metrics = None
|
@@ -124,6 +126,9 @@ class ForecastOperatorBaseModel(ABC):
|
|
124
126
|
|
125
127
|
if self.spec.generate_report or self.spec.generate_metrics:
|
126
128
|
self.eval_metrics = self.generate_train_metrics()
|
129
|
+
if not self.target_cat_col:
|
130
|
+
self.eval_metrics.rename({"Series 1": self.original_target_column},
|
131
|
+
axis=1, inplace=True)
|
127
132
|
|
128
133
|
if self.spec.test_data:
|
129
134
|
try:
|
@@ -134,6 +139,9 @@ class ForecastOperatorBaseModel(ABC):
|
|
134
139
|
) = self._test_evaluate_metrics(
|
135
140
|
elapsed_time=elapsed_time,
|
136
141
|
)
|
142
|
+
if not self.target_cat_col:
|
143
|
+
self.test_eval_metrics.rename({"Series 1": self.original_target_column},
|
144
|
+
axis=1, inplace=True)
|
137
145
|
except Exception:
|
138
146
|
logger.warn("Unable to generate Test Metrics.")
|
139
147
|
logger.debug(f"Full Traceback: {traceback.format_exc()}")
|
@@ -179,7 +187,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
179
187
|
first_5_rows_blocks = [
|
180
188
|
rc.DataTable(
|
181
189
|
df.head(5),
|
182
|
-
label=s_id,
|
190
|
+
label=s_id if self.target_cat_col else None,
|
183
191
|
index=True,
|
184
192
|
)
|
185
193
|
for s_id, df in self.full_data_dict.items()
|
@@ -188,7 +196,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
188
196
|
last_5_rows_blocks = [
|
189
197
|
rc.DataTable(
|
190
198
|
df.tail(5),
|
191
|
-
label=s_id,
|
199
|
+
label=s_id if self.target_cat_col else None,
|
192
200
|
index=True,
|
193
201
|
)
|
194
202
|
for s_id, df in self.full_data_dict.items()
|
@@ -197,7 +205,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
197
205
|
data_summary_blocks = [
|
198
206
|
rc.DataTable(
|
199
207
|
df.describe(),
|
200
|
-
label=s_id,
|
208
|
+
label=s_id if self.target_cat_col else None,
|
201
209
|
index=True,
|
202
210
|
)
|
203
211
|
for s_id, df in self.full_data_dict.items()
|
@@ -215,17 +223,17 @@ class ForecastOperatorBaseModel(ABC):
|
|
215
223
|
rc.Block(
|
216
224
|
first_10_title,
|
217
225
|
# series_subtext,
|
218
|
-
rc.Select(blocks=first_5_rows_blocks),
|
226
|
+
rc.Select(blocks=first_5_rows_blocks) if self.target_cat_col else first_5_rows_blocks[0],
|
219
227
|
),
|
220
228
|
rc.Block(
|
221
229
|
last_10_title,
|
222
230
|
# series_subtext,
|
223
|
-
rc.Select(blocks=last_5_rows_blocks),
|
231
|
+
rc.Select(blocks=last_5_rows_blocks) if self.target_cat_col else last_5_rows_blocks[0],
|
224
232
|
),
|
225
233
|
rc.Block(
|
226
234
|
summary_title,
|
227
235
|
# series_subtext,
|
228
|
-
rc.Select(blocks=data_summary_blocks),
|
236
|
+
rc.Select(blocks=data_summary_blocks) if self.target_cat_col else data_summary_blocks[0],
|
229
237
|
),
|
230
238
|
rc.Separator(),
|
231
239
|
)
|
@@ -259,7 +267,11 @@ class ForecastOperatorBaseModel(ABC):
|
|
259
267
|
output_dir = self.spec.output_directory.url
|
260
268
|
file_path = f"{output_dir}/{BACKTEST_REPORT_NAME}"
|
261
269
|
if self.spec.model == AUTO_SELECT:
|
262
|
-
backtest_sections.append(
|
270
|
+
backtest_sections.append(
|
271
|
+
rc.Heading(
|
272
|
+
"Auto-Select Backtesting and Performance Metrics", level=2
|
273
|
+
)
|
274
|
+
)
|
263
275
|
if not os.path.exists(file_path):
|
264
276
|
failure_msg = rc.Text(
|
265
277
|
"auto-select could not be executed. Please check the "
|
@@ -268,15 +280,23 @@ class ForecastOperatorBaseModel(ABC):
|
|
268
280
|
backtest_sections.append(failure_msg)
|
269
281
|
else:
|
270
282
|
backtest_stats = pd.read_csv(file_path)
|
271
|
-
model_metric_map = backtest_stats.drop(
|
272
|
-
|
283
|
+
model_metric_map = backtest_stats.drop(
|
284
|
+
columns=["metric", "backtest"]
|
285
|
+
)
|
286
|
+
average_dict = {
|
287
|
+
k: round(v, 4)
|
288
|
+
for k, v in model_metric_map.mean().to_dict().items()
|
289
|
+
}
|
273
290
|
best_model = min(average_dict, key=average_dict.get)
|
274
291
|
summary_text = rc.Text(
|
275
292
|
f"Overall, the average {self.spec.metric} scores for the models are {average_dict}, with"
|
276
|
-
f" {best_model} being identified as the top-performing model during backtesting."
|
293
|
+
f" {best_model} being identified as the top-performing model during backtesting."
|
294
|
+
)
|
277
295
|
backtest_table = rc.DataTable(backtest_stats, index=True)
|
278
296
|
liner_plot = get_auto_select_plot(backtest_stats)
|
279
|
-
backtest_sections.extend(
|
297
|
+
backtest_sections.extend(
|
298
|
+
[backtest_table, summary_text, liner_plot]
|
299
|
+
)
|
280
300
|
|
281
301
|
forecast_plots = []
|
282
302
|
if len(self.forecast_output.list_series_ids()) > 0:
|
@@ -288,6 +308,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
288
308
|
horizon=self.spec.horizon,
|
289
309
|
test_data=test_data,
|
290
310
|
ci_interval_width=self.spec.confidence_interval_width,
|
311
|
+
target_category_column=self.target_cat_col
|
291
312
|
)
|
292
313
|
if (
|
293
314
|
series_name is not None
|
@@ -301,7 +322,14 @@ class ForecastOperatorBaseModel(ABC):
|
|
301
322
|
forecast_plots = [forecast_text, forecast_sec]
|
302
323
|
|
303
324
|
yaml_appendix_title = rc.Heading("Reference: YAML File", level=2)
|
304
|
-
|
325
|
+
config_dict = self.config.to_dict()
|
326
|
+
# pop the data incase it isn't json serializable
|
327
|
+
config_dict["spec"]["historical_data"].pop("data")
|
328
|
+
if config_dict["spec"].get("additional_data"):
|
329
|
+
config_dict["spec"]["additional_data"].pop("data")
|
330
|
+
if config_dict["spec"].get("test_data"):
|
331
|
+
config_dict["spec"]["test_data"].pop("data")
|
332
|
+
yaml_appendix = rc.Yaml(config_dict)
|
305
333
|
report_sections = (
|
306
334
|
[summary]
|
307
335
|
+ backtest_sections
|
@@ -463,6 +491,7 @@ class ForecastOperatorBaseModel(ABC):
|
|
463
491
|
f2.write(f1.read())
|
464
492
|
|
465
493
|
# forecast csv report
|
494
|
+
result_df = result_df if self.target_cat_col else result_df.drop(DataColumns.Series, axis=1)
|
466
495
|
write_data(
|
467
496
|
data=result_df,
|
468
497
|
filename=os.path.join(unique_output_dir, self.spec.forecast_filename),
|
@@ -637,6 +666,13 @@ class ForecastOperatorBaseModel(ABC):
|
|
637
666
|
storage_options=storage_options,
|
638
667
|
)
|
639
668
|
|
669
|
+
def _validate_automlx_explanation_mode(self):
|
670
|
+
if self.spec.model != SupportedModels.AutoMLX and self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
|
671
|
+
raise ValueError(
|
672
|
+
"AUTOMLX explanation accuracy mode is only supported for AutoMLX models. "
|
673
|
+
"Please select mode other than AUTOMLX from the available explanations_accuracy_mode options"
|
674
|
+
)
|
675
|
+
|
640
676
|
@runtime_dependency(
|
641
677
|
module="shap",
|
642
678
|
err_msg=(
|
@@ -665,6 +701,9 @@ class ForecastOperatorBaseModel(ABC):
|
|
665
701
|
)
|
666
702
|
ratio = SpeedAccuracyMode.ratio[self.spec.explanations_accuracy_mode]
|
667
703
|
|
704
|
+
# validate the automlx mode is use for automlx model
|
705
|
+
self._validate_automlx_explanation_mode()
|
706
|
+
|
668
707
|
for s_id, data_i in self.datasets.get_data_by_series(
|
669
708
|
include_horizon=False
|
670
709
|
).items():
|
@@ -699,6 +738,14 @@ class ForecastOperatorBaseModel(ABC):
|
|
699
738
|
logger.warn(
|
700
739
|
"No explanations generated. Ensure that additional data has been provided."
|
701
740
|
)
|
741
|
+
elif (
|
742
|
+
self.spec.model == SupportedModels.AutoMLX
|
743
|
+
and self.spec.explanations_accuracy_mode
|
744
|
+
== SpeedAccuracyMode.AUTOMLX
|
745
|
+
):
|
746
|
+
logger.warning(
|
747
|
+
"Global explanations not available for AutoMLX models with inherent explainability"
|
748
|
+
)
|
702
749
|
else:
|
703
750
|
self.global_explanation[s_id] = dict(
|
704
751
|
zip(
|
@@ -360,7 +360,7 @@ class NeuralProphetOperatorModel(ForecastOperatorBaseModel):
|
|
360
360
|
pd.Series(
|
361
361
|
m.state_dict(),
|
362
362
|
index=m.state_dict().keys(),
|
363
|
-
name=s_id,
|
363
|
+
name=s_id if self.target_cat_col else self.original_target_column,
|
364
364
|
)
|
365
365
|
)
|
366
366
|
all_model_states = pd.concat(model_states, axis=1)
|
@@ -373,6 +373,13 @@ class NeuralProphetOperatorModel(ForecastOperatorBaseModel):
|
|
373
373
|
# If the key is present, call the "explain_model" method
|
374
374
|
self.explain_model()
|
375
375
|
|
376
|
+
if not self.target_cat_col:
|
377
|
+
self.formatted_global_explanation = self.formatted_global_explanation.rename(
|
378
|
+
{"Series 1": self.original_target_column},
|
379
|
+
axis=1,
|
380
|
+
)
|
381
|
+
self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
|
382
|
+
|
376
383
|
# Create a markdown section for the global explainability
|
377
384
|
global_explanation_section = rc.Block(
|
378
385
|
rc.Heading("Global Explainability", level=2),
|
@@ -385,14 +392,14 @@ class NeuralProphetOperatorModel(ForecastOperatorBaseModel):
|
|
385
392
|
blocks = [
|
386
393
|
rc.DataTable(
|
387
394
|
local_ex_df.drop("Series", axis=1),
|
388
|
-
label=s_id,
|
395
|
+
label=s_id if self.target_cat_col else None,
|
389
396
|
index=True,
|
390
397
|
)
|
391
398
|
for s_id, local_ex_df in self.local_explanation.items()
|
392
399
|
]
|
393
400
|
local_explanation_section = rc.Block(
|
394
401
|
rc.Heading("Local Explanation of Models", level=2),
|
395
|
-
rc.Select(blocks=blocks),
|
402
|
+
rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
|
396
403
|
)
|
397
404
|
|
398
405
|
# Append the global explanation text and section to the "all_sections" list
|
@@ -256,6 +256,7 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
|
|
256
256
|
self.outputs[s_id], include_legend=True
|
257
257
|
),
|
258
258
|
series_ids=series_ids,
|
259
|
+
target_category_column=self.target_cat_col
|
259
260
|
)
|
260
261
|
section_1 = rc.Block(
|
261
262
|
rc.Heading("Forecast Overview", level=2),
|
@@ -268,6 +269,7 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
|
|
268
269
|
sec2 = _select_plot_list(
|
269
270
|
lambda s_id: self.models[s_id].plot_components(self.outputs[s_id]),
|
270
271
|
series_ids=series_ids,
|
272
|
+
target_category_column=self.target_cat_col
|
271
273
|
)
|
272
274
|
section_2 = rc.Block(
|
273
275
|
rc.Heading("Forecast Broken Down by Trend Component", level=2), sec2
|
@@ -281,7 +283,9 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
|
|
281
283
|
sec3_figs[s_id].gca(), self.models[s_id], self.outputs[s_id]
|
282
284
|
)
|
283
285
|
sec3 = _select_plot_list(
|
284
|
-
lambda s_id: sec3_figs[s_id],
|
286
|
+
lambda s_id: sec3_figs[s_id],
|
287
|
+
series_ids=series_ids,
|
288
|
+
target_category_column=self.target_cat_col
|
285
289
|
)
|
286
290
|
section_3 = rc.Block(rc.Heading("Forecast Changepoints", level=2), sec3)
|
287
291
|
|
@@ -295,7 +299,7 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
|
|
295
299
|
pd.Series(
|
296
300
|
m.seasonalities,
|
297
301
|
index=pd.Index(m.seasonalities.keys(), dtype="object"),
|
298
|
-
name=s_id,
|
302
|
+
name=s_id if self.target_cat_col else self.original_target_column,
|
299
303
|
dtype="object",
|
300
304
|
)
|
301
305
|
)
|
@@ -316,15 +320,6 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
|
|
316
320
|
global_explanation_df / global_explanation_df.sum(axis=0) * 100
|
317
321
|
)
|
318
322
|
|
319
|
-
# Create a markdown section for the global explainability
|
320
|
-
global_explanation_section = rc.Block(
|
321
|
-
rc.Heading("Global Explanation of Models", level=2),
|
322
|
-
rc.Text(
|
323
|
-
"The following tables provide the feature attribution for the global explainability."
|
324
|
-
),
|
325
|
-
rc.DataTable(self.formatted_global_explanation, index=True),
|
326
|
-
)
|
327
|
-
|
328
323
|
aggregate_local_explanations = pd.DataFrame()
|
329
324
|
for s_id, local_ex_df in self.local_explanation.items():
|
330
325
|
local_ex_df_copy = local_ex_df.copy()
|
@@ -334,17 +329,33 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
|
|
334
329
|
)
|
335
330
|
self.formatted_local_explanation = aggregate_local_explanations
|
336
331
|
|
332
|
+
if not self.target_cat_col:
|
333
|
+
self.formatted_global_explanation = self.formatted_global_explanation.rename(
|
334
|
+
{"Series 1": self.original_target_column},
|
335
|
+
axis=1,
|
336
|
+
)
|
337
|
+
self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
|
338
|
+
|
339
|
+
# Create a markdown section for the global explainability
|
340
|
+
global_explanation_section = rc.Block(
|
341
|
+
rc.Heading("Global Explanation of Models", level=2),
|
342
|
+
rc.Text(
|
343
|
+
"The following tables provide the feature attribution for the global explainability."
|
344
|
+
),
|
345
|
+
rc.DataTable(self.formatted_global_explanation, index=True),
|
346
|
+
)
|
347
|
+
|
337
348
|
blocks = [
|
338
349
|
rc.DataTable(
|
339
350
|
local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
|
340
|
-
label=s_id,
|
351
|
+
label=s_id if self.target_cat_col else None,
|
341
352
|
index=True,
|
342
353
|
)
|
343
354
|
for s_id, local_ex_df in self.local_explanation.items()
|
344
355
|
]
|
345
356
|
local_explanation_section = rc.Block(
|
346
357
|
rc.Heading("Local Explanation of Models", level=2),
|
347
|
-
rc.Select(blocks=blocks),
|
358
|
+
rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
|
348
359
|
)
|
349
360
|
|
350
361
|
# Append the global explanation text and section to the "all_sections" list
|
@@ -358,11 +369,7 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
|
|
358
369
|
logger.debug(f"Full Traceback: {traceback.format_exc()}")
|
359
370
|
|
360
371
|
model_description = rc.Text(
|
361
|
-
"Prophet is a procedure for forecasting time series data based on an additive "
|
362
|
-
"model where non-linear trends are fit with yearly, weekly, and daily seasonality, "
|
363
|
-
"plus holiday effects. It works best with time series that have strong seasonal "
|
364
|
-
"effects and several seasons of historical data. Prophet is robust to missing "
|
365
|
-
"data and shifts in the trend, and typically handles outliers well."
|
372
|
+
"""Prophet is a procedure for forecasting time series data based on an additive model where non-linear trends are fit with yearly, weekly, and daily seasonality, plus holiday effects. It works best with time series that have strong seasonal effects and several seasons of historical data. Prophet is robust to missing data and shifts in the trend, and typically handles outliers well."""
|
366
373
|
)
|
367
374
|
other_sections = all_sections
|
368
375
|
|
@@ -18,6 +18,35 @@ from ads.opctl.operator.lowcode.common.utils import find_output_dirname
|
|
18
18
|
|
19
19
|
from .const import SpeedAccuracyMode, SupportedMetrics, SupportedModels
|
20
20
|
|
21
|
+
@dataclass
|
22
|
+
class AutoScaling(DataClassSerializable):
|
23
|
+
"""Class representing simple autoscaling policy"""
|
24
|
+
minimum_instance: int = 1
|
25
|
+
maximum_instance: int = None
|
26
|
+
cool_down_in_seconds: int = 600
|
27
|
+
scale_in_threshold: int = 10
|
28
|
+
scale_out_threshold: int = 80
|
29
|
+
scaling_metric: str = "CPU_UTILIZATION"
|
30
|
+
|
31
|
+
@dataclass(repr=True)
|
32
|
+
class ModelDeploymentServer(DataClassSerializable):
|
33
|
+
"""Class representing model deployment server specification for whatif-analysis."""
|
34
|
+
display_name: str = None
|
35
|
+
initial_shape: str = None
|
36
|
+
description: str = None
|
37
|
+
log_group: str = None
|
38
|
+
log_id: str = None
|
39
|
+
auto_scaling: AutoScaling = field(default_factory=AutoScaling)
|
40
|
+
|
41
|
+
|
42
|
+
@dataclass(repr=True)
|
43
|
+
class WhatIfAnalysis(DataClassSerializable):
|
44
|
+
"""Class representing operator specification for whatif-analysis."""
|
45
|
+
model_display_name: str = None
|
46
|
+
compartment_id: str = None
|
47
|
+
project_id: str = None
|
48
|
+
model_deployment: ModelDeploymentServer = field(default_factory=ModelDeploymentServer)
|
49
|
+
|
21
50
|
|
22
51
|
@dataclass(repr=True)
|
23
52
|
class TestData(InputData):
|
@@ -90,12 +119,14 @@ class ForecastOperatorSpec(DataClassSerializable):
|
|
90
119
|
confidence_interval_width: float = None
|
91
120
|
metric: str = None
|
92
121
|
tuning: Tuning = field(default_factory=Tuning)
|
122
|
+
what_if_analysis: WhatIfAnalysis = field(default_factory=WhatIfAnalysis)
|
93
123
|
|
94
124
|
def __post_init__(self):
|
95
125
|
"""Adjusts the specification details."""
|
96
126
|
self.output_directory = self.output_directory or OutputDirectory(
|
97
127
|
url=find_output_dirname(self.output_directory)
|
98
128
|
)
|
129
|
+
self.generate_model_pickle = True if self.generate_model_pickle or self.what_if_analysis else False
|
99
130
|
self.metric = (self.metric or "").lower() or SupportedMetrics.SMAPE.lower()
|
100
131
|
self.model = self.model or SupportedModels.Prophet
|
101
132
|
self.confidence_interval_width = self.confidence_interval_width or 0.80
|
@@ -37,6 +37,9 @@ spec:
|
|
37
37
|
nullable: true
|
38
38
|
required: false
|
39
39
|
type: dict
|
40
|
+
data:
|
41
|
+
nullable: true
|
42
|
+
required: false
|
40
43
|
format:
|
41
44
|
allowed:
|
42
45
|
- csv
|
@@ -48,6 +51,7 @@ spec:
|
|
48
51
|
- sql_query
|
49
52
|
- hdf
|
50
53
|
- tsv
|
54
|
+
- pandas
|
51
55
|
required: false
|
52
56
|
type: string
|
53
57
|
columns:
|
@@ -92,6 +96,9 @@ spec:
|
|
92
96
|
nullable: true
|
93
97
|
required: false
|
94
98
|
type: dict
|
99
|
+
data:
|
100
|
+
nullable: true
|
101
|
+
required: false
|
95
102
|
format:
|
96
103
|
allowed:
|
97
104
|
- csv
|
@@ -103,6 +110,7 @@ spec:
|
|
103
110
|
- sql_query
|
104
111
|
- hdf
|
105
112
|
- tsv
|
113
|
+
- pandas
|
106
114
|
required: false
|
107
115
|
type: string
|
108
116
|
columns:
|
@@ -146,6 +154,9 @@ spec:
|
|
146
154
|
nullable: true
|
147
155
|
required: false
|
148
156
|
type: dict
|
157
|
+
data:
|
158
|
+
nullable: true
|
159
|
+
required: false
|
149
160
|
format:
|
150
161
|
allowed:
|
151
162
|
- csv
|
@@ -157,6 +168,7 @@ spec:
|
|
157
168
|
- sql_query
|
158
169
|
- hdf
|
159
170
|
- tsv
|
171
|
+
- pandas
|
160
172
|
required: false
|
161
173
|
type: string
|
162
174
|
columns:
|
@@ -332,6 +344,7 @@ spec:
|
|
332
344
|
- HIGH_ACCURACY
|
333
345
|
- BALANCED
|
334
346
|
- FAST_APPROXIMATE
|
347
|
+
- AUTOMLX
|
335
348
|
|
336
349
|
generate_report:
|
337
350
|
type: boolean
|
@@ -340,6 +353,69 @@ spec:
|
|
340
353
|
meta:
|
341
354
|
description: "Report file generation can be enabled using this flag. Defaults to true."
|
342
355
|
|
356
|
+
what_if_analysis:
|
357
|
+
type: dict
|
358
|
+
required: false
|
359
|
+
schema:
|
360
|
+
model_deployment:
|
361
|
+
type: dict
|
362
|
+
required: false
|
363
|
+
meta: "If model_deployment id is not specified, a new model deployment is created; otherwise, the model is linked to the specified model deployment."
|
364
|
+
schema:
|
365
|
+
id:
|
366
|
+
type: string
|
367
|
+
required: false
|
368
|
+
display_name:
|
369
|
+
type: string
|
370
|
+
required: false
|
371
|
+
initial_shape:
|
372
|
+
type: string
|
373
|
+
required: false
|
374
|
+
description:
|
375
|
+
type: string
|
376
|
+
required: false
|
377
|
+
log_group:
|
378
|
+
type: string
|
379
|
+
required: true
|
380
|
+
log_id:
|
381
|
+
type: string
|
382
|
+
required: true
|
383
|
+
auto_scaling:
|
384
|
+
type: dict
|
385
|
+
required: false
|
386
|
+
schema:
|
387
|
+
minimum_instance:
|
388
|
+
type: integer
|
389
|
+
required: true
|
390
|
+
maximum_instance:
|
391
|
+
type: integer
|
392
|
+
required: true
|
393
|
+
scale_in_threshold:
|
394
|
+
type: integer
|
395
|
+
required: true
|
396
|
+
scale_out_threshold:
|
397
|
+
type: integer
|
398
|
+
required: true
|
399
|
+
scaling_metric:
|
400
|
+
type: string
|
401
|
+
required: true
|
402
|
+
cool_down_in_seconds:
|
403
|
+
type: integer
|
404
|
+
required: true
|
405
|
+
model_display_name:
|
406
|
+
type: string
|
407
|
+
required: true
|
408
|
+
project_id:
|
409
|
+
type: string
|
410
|
+
required: false
|
411
|
+
meta: "If not provided, The project OCID from config.PROJECT_OCID is used"
|
412
|
+
compartment_id:
|
413
|
+
type: string
|
414
|
+
required: false
|
415
|
+
meta: "If not provided, The compartment OCID from config.NB_SESSION_COMPARTMENT_OCID is used."
|
416
|
+
meta:
|
417
|
+
description: "When enabled, the models are saved to the model catalog. Defaults to false."
|
418
|
+
|
343
419
|
generate_metrics:
|
344
420
|
type: boolean
|
345
421
|
required: false
|
@@ -250,8 +250,8 @@ def evaluate_train_metrics(output):
|
|
250
250
|
return total_metrics
|
251
251
|
|
252
252
|
|
253
|
-
def _select_plot_list(fn, series_ids):
|
254
|
-
blocks = [rc.Widget(fn(s_id=s_id), label=s_id) for s_id in series_ids]
|
253
|
+
def _select_plot_list(fn, series_ids, target_category_column):
|
254
|
+
blocks = [rc.Widget(fn(s_id=s_id), label=s_id if target_category_column else None) for s_id in series_ids]
|
255
255
|
return rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
|
256
256
|
|
257
257
|
|
@@ -283,6 +283,7 @@ def get_forecast_plots(
|
|
283
283
|
horizon,
|
284
284
|
test_data=None,
|
285
285
|
ci_interval_width=0.95,
|
286
|
+
target_category_column=None
|
286
287
|
):
|
287
288
|
def plot_forecast_plotly(s_id):
|
288
289
|
fig = go.Figure()
|
@@ -379,7 +380,7 @@ def get_forecast_plots(
|
|
379
380
|
)
|
380
381
|
return fig
|
381
382
|
|
382
|
-
return _select_plot_list(plot_forecast_plotly, forecast_output.list_series_ids())
|
383
|
+
return _select_plot_list(plot_forecast_plotly, forecast_output.list_series_ids(), target_category_column)
|
383
384
|
|
384
385
|
|
385
386
|
def convert_target(target: str, target_col: str):
|