oracle-ads 2.10.0__py3-none-any.whl → 2.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. ads/aqua/__init__.py +12 -0
  2. ads/aqua/base.py +324 -0
  3. ads/aqua/cli.py +19 -0
  4. ads/aqua/config/deployment_config_defaults.json +9 -0
  5. ads/aqua/config/resource_limit_names.json +7 -0
  6. ads/aqua/constants.py +45 -0
  7. ads/aqua/data.py +40 -0
  8. ads/aqua/decorator.py +101 -0
  9. ads/aqua/deployment.py +643 -0
  10. ads/aqua/dummy_data/icon.txt +1 -0
  11. ads/aqua/dummy_data/oci_model_deployments.json +56 -0
  12. ads/aqua/dummy_data/oci_models.json +1 -0
  13. ads/aqua/dummy_data/readme.md +26 -0
  14. ads/aqua/evaluation.py +1751 -0
  15. ads/aqua/exception.py +82 -0
  16. ads/aqua/extension/__init__.py +40 -0
  17. ads/aqua/extension/base_handler.py +138 -0
  18. ads/aqua/extension/common_handler.py +21 -0
  19. ads/aqua/extension/deployment_handler.py +202 -0
  20. ads/aqua/extension/evaluation_handler.py +135 -0
  21. ads/aqua/extension/finetune_handler.py +66 -0
  22. ads/aqua/extension/model_handler.py +59 -0
  23. ads/aqua/extension/ui_handler.py +201 -0
  24. ads/aqua/extension/utils.py +23 -0
  25. ads/aqua/finetune.py +579 -0
  26. ads/aqua/job.py +29 -0
  27. ads/aqua/model.py +819 -0
  28. ads/aqua/training/__init__.py +4 -0
  29. ads/aqua/training/exceptions.py +459 -0
  30. ads/aqua/ui.py +453 -0
  31. ads/aqua/utils.py +715 -0
  32. ads/cli.py +37 -6
  33. ads/common/auth.py +7 -0
  34. ads/common/decorator/__init__.py +7 -3
  35. ads/common/decorator/require_nonempty_arg.py +65 -0
  36. ads/common/object_storage_details.py +166 -7
  37. ads/common/oci_client.py +18 -1
  38. ads/common/oci_logging.py +2 -2
  39. ads/common/oci_mixin.py +4 -5
  40. ads/common/serializer.py +34 -5
  41. ads/common/utils.py +75 -10
  42. ads/config.py +40 -1
  43. ads/dataset/correlation_plot.py +10 -12
  44. ads/jobs/ads_job.py +43 -25
  45. ads/jobs/builders/infrastructure/base.py +4 -2
  46. ads/jobs/builders/infrastructure/dsc_job.py +49 -39
  47. ads/jobs/builders/runtimes/base.py +71 -1
  48. ads/jobs/builders/runtimes/container_runtime.py +4 -4
  49. ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
  50. ads/jobs/templates/driver_pytorch.py +27 -10
  51. ads/model/artifact_downloader.py +84 -14
  52. ads/model/artifact_uploader.py +25 -23
  53. ads/model/datascience_model.py +388 -38
  54. ads/model/deployment/model_deployment.py +10 -2
  55. ads/model/generic_model.py +8 -0
  56. ads/model/model_file_description_schema.json +68 -0
  57. ads/model/model_metadata.py +1 -1
  58. ads/model/service/oci_datascience_model.py +34 -5
  59. ads/opctl/config/merger.py +2 -2
  60. ads/opctl/operator/__init__.py +3 -1
  61. ads/opctl/operator/cli.py +7 -1
  62. ads/opctl/operator/cmd.py +3 -3
  63. ads/opctl/operator/common/errors.py +2 -1
  64. ads/opctl/operator/common/operator_config.py +22 -3
  65. ads/opctl/operator/common/utils.py +16 -0
  66. ads/opctl/operator/lowcode/anomaly/MLoperator +15 -0
  67. ads/opctl/operator/lowcode/anomaly/README.md +209 -0
  68. ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
  69. ads/opctl/operator/lowcode/anomaly/__main__.py +104 -0
  70. ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
  71. ads/opctl/operator/lowcode/anomaly/const.py +88 -0
  72. ads/opctl/operator/lowcode/anomaly/environment.yaml +12 -0
  73. ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
  74. ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +147 -0
  75. ads/opctl/operator/lowcode/anomaly/model/automlx.py +89 -0
  76. ads/opctl/operator/lowcode/anomaly/model/autots.py +103 -0
  77. ads/opctl/operator/lowcode/anomaly/model/base_model.py +354 -0
  78. ads/opctl/operator/lowcode/anomaly/model/factory.py +67 -0
  79. ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
  80. ads/opctl/operator/lowcode/anomaly/operator_config.py +105 -0
  81. ads/opctl/operator/lowcode/anomaly/schema.yaml +359 -0
  82. ads/opctl/operator/lowcode/anomaly/utils.py +81 -0
  83. ads/opctl/operator/lowcode/common/__init__.py +5 -0
  84. ads/opctl/operator/lowcode/common/const.py +10 -0
  85. ads/opctl/operator/lowcode/common/data.py +96 -0
  86. ads/opctl/operator/lowcode/common/errors.py +41 -0
  87. ads/opctl/operator/lowcode/common/transformations.py +191 -0
  88. ads/opctl/operator/lowcode/common/utils.py +250 -0
  89. ads/opctl/operator/lowcode/forecast/README.md +3 -2
  90. ads/opctl/operator/lowcode/forecast/__main__.py +18 -2
  91. ads/opctl/operator/lowcode/forecast/cmd.py +8 -7
  92. ads/opctl/operator/lowcode/forecast/const.py +17 -1
  93. ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
  94. ads/opctl/operator/lowcode/forecast/model/arima.py +106 -117
  95. ads/opctl/operator/lowcode/forecast/model/automlx.py +204 -180
  96. ads/opctl/operator/lowcode/forecast/model/autots.py +144 -253
  97. ads/opctl/operator/lowcode/forecast/model/base_model.py +326 -259
  98. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +325 -176
  99. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +293 -237
  100. ads/opctl/operator/lowcode/forecast/model/prophet.py +191 -208
  101. ads/opctl/operator/lowcode/forecast/operator_config.py +24 -33
  102. ads/opctl/operator/lowcode/forecast/schema.yaml +116 -29
  103. ads/opctl/operator/lowcode/forecast/utils.py +186 -356
  104. ads/opctl/operator/lowcode/pii/model/guardrails.py +18 -15
  105. ads/opctl/operator/lowcode/pii/model/report.py +7 -7
  106. ads/opctl/operator/lowcode/pii/operator_config.py +1 -8
  107. ads/opctl/operator/lowcode/pii/utils.py +0 -82
  108. ads/opctl/operator/runtime/runtime.py +3 -2
  109. ads/telemetry/base.py +62 -0
  110. ads/telemetry/client.py +105 -0
  111. ads/telemetry/telemetry.py +6 -3
  112. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/METADATA +44 -7
  113. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/RECORD +116 -59
  114. ads/opctl/operator/lowcode/forecast/model/transformations.py +0 -125
  115. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/LICENSE.txt +0 -0
  116. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/WHEEL +0 -0
  117. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/entry_points.txt +0 -0
ads/opctl/operator/lowcode/common/transformations.py
@@ -0,0 +1,191 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*--
+
+ # Copyright (c) 2023 Oracle and/or its affiliates.
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+ from ads.opctl import logger
+ from ads.opctl.operator.lowcode.common.errors import (
+     InvalidParameterError,
+     DataMismatchError,
+ )
+ from ads.opctl.operator.lowcode.common.const import DataColumns
+ from ads.opctl.operator.lowcode.common.utils import merge_category_columns
+ import pandas as pd
+ from abc import ABC
+
+
+ class Transformations(ABC):
+     """A class which implements the transformations for the forecast operator."""
+
+     def __init__(self, dataset_info, name="historical_data"):
+         """
+         Initializes the transformation.
+
+         Parameters
+         ----------
+         dataset_info : ForecastOperatorConfig
+         """
+         self.name = name
+         self.has_artificial_series = False
+         self.dataset_info = dataset_info
+         self.target_category_columns = dataset_info.target_category_columns
+         self.target_column_name = dataset_info.target_column
+         self.dt_column_name = dataset_info.datetime_column.name
+         self.dt_column_format = dataset_info.datetime_column.format
+         self.preprocessing = dataset_info.preprocessing
+
+     def run(self, data):
+         """
+         Runs all of the transformations in a particular order.
+
+         Returns
+         -------
+         A new Pandas DataFrame with treated / transformed target values. Specifically:
+         - data will be in a MultiIndex with Datetime always first (level 0)
+         - whether 0, 1 or 2+, all target_category_columns will be merged into a single index column: Series
+         - all datetime columns will be formatted as such
+         - all data will be imputed (unless preprocessing is disabled)
+         - all trailing whitespace will be removed
+         - the data will be sorted by Datetime then Series
+         """
+         clean_df = self._remove_trailing_whitespace(data)
+         if self.name == "historical_data":
+             self._check_historical_dataset(clean_df)
+         clean_df = self._set_series_id_column(clean_df)
+         clean_df = self._format_datetime_col(clean_df)
+         clean_df = self._set_multi_index(clean_df)
+
+         if self.name == "historical_data":
+             try:
+                 clean_df = self._missing_value_imputation_hist(clean_df)
+             except Exception as e:
+                 logger.debug(f"Missing value imputation failed with {e.args}")
+             if self.preprocessing:
+                 try:
+                     clean_df = self._outlier_treatment(clean_df)
+                 except Exception as e:
+                     logger.debug(f"Outlier Treatment failed with {e.args}")
+             else:
+                 logger.debug("Skipping outlier treatment as preprocessing is disabled")
+         elif self.name == "additional_data":
+             clean_df = self._missing_value_imputation_add(clean_df)
+         return clean_df
+
+     def _remove_trailing_whitespace(self, df):
+         return df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)
+
+     def _set_series_id_column(self, df):
+         if not self.target_category_columns:
+             df[DataColumns.Series] = "Series 1"
+             self.has_artificial_series = True
+         else:
+             df[DataColumns.Series] = merge_category_columns(
+                 df, self.target_category_columns
+             )
+             df = df.drop(self.target_category_columns, axis=1)
+         return df
+
+     def _format_datetime_col(self, df):
+         try:
+             df[self.dt_column_name] = pd.to_datetime(
+                 df[self.dt_column_name], format=self.dt_column_format
+             )
+         except:
+             raise InvalidParameterError(
+                 f"Unable to determine the datetime type for column: {self.dt_column_name} in dataset: {self.name}. Please specify the format explicitly. (For example, add 'format: %d/%m/%Y' underneath 'name: {self.dt_column_name}' in the datetime_column section of the yaml file if you haven't already.) For reference, here is the first datetime given: {df[self.dt_column_name].values[0]}"
+             )
+         return df
+
+     def _set_multi_index(self, df):
+         """
+         Sorts the data by date for each series.
+
+         Parameters
+         ----------
+         df : The Pandas DataFrame.
+
+         Returns
+         -------
+         A new Pandas DataFrame with sorted dates for each series.
+         """
+         df = df.set_index([self.dt_column_name, DataColumns.Series])
+         return df.sort_values([self.dt_column_name, DataColumns.Series], ascending=True)
+
+     def _missing_value_imputation_hist(self, df):
+         """
+         Fills missing values in the dataframe using linear interpolation.
+
+         Parameters
+         ----------
+         df : The Pandas DataFrame.
+
+         Returns
+         -------
+         A new Pandas DataFrame without missing values.
+         """
+         # missing value imputation using linear interpolation
+         df[self.target_column_name] = (
+             df[self.target_column_name]
+             .groupby(DataColumns.Series)
+             .transform(lambda x: x.interpolate(limit_direction="both"))
+         )
+         return df
+
+     def _missing_value_imputation_add(self, df):
+         """
+         Fills missing values in the additional data with 0.
+
+         Parameters
+         ----------
+         df : The Pandas DataFrame.
+
+         Returns
+         -------
+         A new Pandas DataFrame without missing values.
+         """
+         # find columns that are all NA and replace with 0
+         for col in df.columns:
+             # find the next int not present in the column's values
+             i = 0
+             vals = df[col].unique()
+             while i in vals:
+                 i = i + 1
+             df[col] = df[col].fillna(0)
+         return df
+
+     def _outlier_treatment(self, df):
+         """
+         Finds outliers by z-score and replaces them with the series mean.
+
+         Parameters
+         ----------
+         df : The Pandas DataFrame.
+
+         Returns
+         -------
+         A new Pandas DataFrame with treated outliers.
+         """
+         df["z_score"] = (
+             df[self.target_column_name]
+             .groupby(DataColumns.Series)
+             .transform(lambda x: (x - x.mean()) / x.std())
+         )
+         outliers_mask = df["z_score"].abs() > 3
+         df.loc[outliers_mask, self.target_column_name] = (
+             df[self.target_column_name]
+             .groupby(DataColumns.Series)
+             .transform(lambda x: x.mean())
+         )
+         return df.drop("z_score", axis=1)
+
+     def _check_historical_dataset(self, df):
+         expected_names = [self.target_column_name, self.dt_column_name] + (
+             self.target_category_columns if self.target_category_columns else []
+         )
+         if set(df.columns) != set(expected_names):
+             raise DataMismatchError(
+                 f"Expected {self.name} to have columns: {expected_names}, but instead found column names: {df.columns}. Is the {self.name} path correct?"
+             )
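The pipeline above is driven by the operator's config objects, but the reshaping itself is plain pandas. A minimal sketch of the same steps on a toy frame (the column names `Date`, `Store`, `Sales` are hypothetical, not from the diff) shows what `run` produces for historical data:

```python
import pandas as pd

# Hypothetical two-store history with one missing target value.
df = pd.DataFrame(
    {
        "Date": ["2024-01-01", "2024-01-02", "2024-01-01", "2024-01-02"],
        "Store": ["A", "A", "B", "B"],
        "Sales": [10.0, None, 7.0, 9.0],
    }
)

# _set_series_id_column: category columns collapse into one "Series" column.
df["Series"] = df["Store"]
df = df.drop("Store", axis=1)

# _format_datetime_col + _set_multi_index: parse dates, then build a sorted
# (Date, Series) MultiIndex with Datetime as level 0.
df["Date"] = pd.to_datetime(df["Date"], format="%Y-%m-%d")
df = df.set_index(["Date", "Series"]).sort_values(["Date", "Series"])

# _missing_value_imputation_hist: per-series linear interpolation.
df["Sales"] = (
    df["Sales"]
    .groupby("Series")
    .transform(lambda x: x.interpolate(limit_direction="both"))
)
print(df)  # Store A's missing 2024-01-02 value is filled to 10.0
```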
ads/opctl/operator/lowcode/common/utils.py
@@ -0,0 +1,250 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*--
+
+ # Copyright (c) 2024 Oracle and/or its affiliates.
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+ import argparse
+ import logging
+ import os
+ import sys
+ import time
+ from string import Template
+ from typing import Any, Dict, List, Tuple
+ import pandas as pd
+ from ads.opctl import logger
+ import oracledb
+
+ import fsspec
+ import yaml
+ from typing import Union
+
+ from ads.opctl import logger
+ from ads.opctl.operator.lowcode.common.errors import (
+     InputDataError,
+     InvalidParameterError,
+     PermissionsError,
+     DataMismatchError,
+ )
+ from ads.opctl.operator.common.operator_config import OutputDirectory
+ from ads.common.object_storage_details import ObjectStorageDetails
+
+
+ def call_pandas_fsspec(pd_fn, filename, storage_options, **kwargs):
+     if fsspec.utils.get_protocol(filename) == "file":
+         return pd_fn(filename, **kwargs)
+     elif fsspec.utils.get_protocol(filename) in ["http", "https"]:
+         return pd_fn(filename, **kwargs)
+
+     storage_options = storage_options or (
+         default_signer() if ObjectStorageDetails.is_oci_path(filename) else {}
+     )
+
+     return pd_fn(filename, storage_options=storage_options, **kwargs)
+
+
+ def load_data(data_spec, storage_options=None, **kwargs):
+     if data_spec is None:
+         raise InvalidParameterError(f"No details provided for this data source.")
+     filename = data_spec.url
+     format = data_spec.format
+     columns = data_spec.columns
+     connect_args = data_spec.connect_args
+     sql = data_spec.sql
+     table_name = data_spec.table_name
+     limit = data_spec.limit
+
+     storage_options = storage_options or (
+         default_signer() if ObjectStorageDetails.is_oci_path(filename) else {}
+     )
+
+     if filename is not None:
+         if not format:
+             _, format = os.path.splitext(filename)
+             format = format[1:]
+         if format in ["json", "clipboard", "excel", "csv", "feather", "hdf"]:
+             read_fn = getattr(pd, f"read_{format}")
+             data = call_pandas_fsspec(
+                 read_fn, filename, storage_options=storage_options
+             )
+         elif format in ["tsv"]:
+             data = call_pandas_fsspec(
+                 pd.read_csv, filename, storage_options=storage_options, sep="\t"
+             )
+         else:
+             raise InvalidParameterError(
+                 f"The format {format} is not currently supported for reading data. Please reformat the data source: {filename} ."
+             )
+     elif connect_args is not None:
+         con = oracledb.connect(**connect_args)
+         if table_name is not None:
+             data = pd.read_sql_table(table_name, con)
+         elif sql is not None:
+             data = pd.read_sql(sql, con)
+         else:
+             raise InvalidParameterError(
+                 f"Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`."
+             )
+     else:
+         raise InvalidParameterError(
+             f"No filename/url provided, and no connect_args provided. Please specify one of these if you want to read data from a file or a database respectively."
+         )
+     if columns:
+         # keep only these columns, done after load because only CSV supports stream filtering
+         data = data[columns]
+     if limit:
+         data = data[:limit]
+     return data
+
+
+ def write_data(data, filename, format, storage_options, index=False, **kwargs):
+     if not format:
+         _, format = os.path.splitext(filename)
+         format = format[1:]
+     if format in ["json", "clipboard", "excel", "csv", "feather", "hdf"]:
+         write_fn = getattr(data, f"to_{format}")
+         return call_pandas_fsspec(
+             write_fn, filename, index=index, storage_options=storage_options, **kwargs
+         )
+     raise OperatorYamlContentError(
+         f"The format {format} is not currently supported for writing data. Please change the format parameter for the data output: {filename} ."
+     )
+
+
+ def merge_category_columns(data, target_category_columns):
+     result = data.apply(
+         lambda x: "__".join([str(x[col]) for col in target_category_columns]), axis=1
+     )
+     return result if not result.empty else pd.Series([], dtype=str)
+
+
+ def merged_category_column_name(target_category_columns: Union[List, None]):
+     if not target_category_columns or len(target_category_columns) == 0:
+         return None
+     return "__".join([str(x) for x in target_category_columns])
+
+
+ def datetime_to_seconds(s: pd.Series):
+     """
+     Converts a datetime column into an integer number of seconds.
+     This method has many uses, most notably for enabling libraries
+     like shap to read datetime columns.
+
+     Parameters
+     ----------
+     s: pd.Series
+         A Series of type datetime
+
+     Returns
+     -------
+     pd.Series of type int
+     """
+     return s.apply(lambda x: x.timestamp())
+
+
+ def seconds_to_datetime(s: pd.Series, dt_format=None):
+     """
+     Inverse of `datetime_to_seconds`.
+
+     Parameters
+     ----------
+     s: pd.Series
+         A Series of type int
+
+     Returns
+     -------
+     pd.Series of type datetime
+     """
+     s = pd.to_datetime(s, unit="s")
+     if dt_format is not None:
+         return pd.to_datetime(s, format=dt_format)
+     return s
+
+
+ def default_signer(**kwargs):
+     os.environ["EXTRA_USER_AGENT_INFO"] = "Operator"
+     from ads.common.auth import default_signer
+
+     return default_signer(**kwargs)
+
+
+ def get_frequency_in_seconds(s: pd.Series, sample_size=100, ignore_duplicates=True):
+     """
+     Returns the frequency of the data in seconds.
+
+     Parameters
+     ----------
+     s: pd.Series
+         Datetime column
+     ignore_duplicates: bool
+         If True, duplicates will be dropped before computing the frequency.
+
+     Returns
+     -------
+     int
+         Minimum difference in seconds
+     """
+     s1 = pd.Series(s).drop_duplicates() if ignore_duplicates else s
+     return s1.tail(20).diff().min().total_seconds()
+
+
+ def get_frequency_of_datetime(dt_col: pd.Series, ignore_duplicates=True):
+     """
+     Returns the string frequency of the data.
+
+     Parameters
+     ----------
+     dt_col: pd.Series
+         Datetime column
+     ignore_duplicates: bool
+         If True, duplicates will be dropped before computing the frequency.
+
+     Returns
+     -------
+     str
+         Pandas Datetime Frequency
+     """
+     s = pd.Series(dt_col).drop_duplicates() if ignore_duplicates else dt_col
+     return pd.infer_freq(s)
+
+
+ def human_time_friendly(seconds):
+     TIME_DURATION_UNITS = (
+         ("week", 60 * 60 * 24 * 7),
+         ("day", 60 * 60 * 24),
+         ("hour", 60 * 60),
+         ("min", 60),
+     )
+     if seconds == 0:
+         return "inf"
+     accumulator = []
+     for unit, div in TIME_DURATION_UNITS:
+         amount, seconds = divmod(float(seconds), div)
+         if amount > 0:
+             accumulator.append(
+                 "{} {}{}".format(int(amount), unit, "" if amount == 1 else "s")
+             )
+     accumulator.append("{} secs".format(round(seconds, 2)))
+     return ", ".join(accumulator)
+
+
+ def find_output_dirname(output_dir: OutputDirectory):
+     if output_dir:
+         return output_dir.url
+     output_dir = "results"
+
+     # If the directory exists, find the next unique directory name by appending an incrementing suffix
+     counter = 1
+     unique_output_dir = f"{output_dir}"
+     while os.path.exists(unique_output_dir):
+         unique_output_dir = f"{output_dir}_{counter}"
+         counter += 1
+     logger.warn(
+         "Since the output directory was not specified, the output will be saved to {} directory.".format(
+             unique_output_dir
+         )
+     )
+     return unique_output_dir
+
+
+ def set_log_level(pkg_name: str, level: int):
+     pkg_logger = logging.getLogger(pkg_name)
+     pkg_logger.addHandler(logging.NullHandler())
+     pkg_logger.propagate = False
+     pkg_logger.setLevel(level)
+
+
+ # Disable printing to stdout
+ def disable_print():
+     sys.stdout = open(os.devnull, "w")
+
+
+ # Restore printing to stdout
+ def enable_print():
+     sys.stdout = sys.__stdout__
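Several of these helpers are self-contained pandas utilities, so their behavior can be checked directly. A small usage sketch, with expected outputs derived from the definitions above:

```python
import pandas as pd

# datetime_to_seconds / seconds_to_datetime round-trip, per the definitions above.
ts = pd.Series(pd.to_datetime(["2024-01-01 00:00:00", "2024-01-01 00:01:40"]))
secs = ts.apply(lambda x: x.timestamp())  # datetime_to_seconds
back = pd.to_datetime(secs, unit="s")     # seconds_to_datetime (no dt_format)
assert (back == ts).all()

# get_frequency_of_datetime defers to pandas' frequency inference.
print(pd.infer_freq(pd.Series(pd.date_range("2024-01-01", periods=5, freq="D"))))  # D

# human_time_friendly walks the week/day/hour/min divisors, then appends seconds:
# human_time_friendly(90061) == "1 day, 1 hour, 1 min, 1.0 secs"
```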
ads/opctl/operator/lowcode/forecast/README.md
@@ -38,8 +38,9 @@ To run forecasting locally, create and activate a new conda environment (`ads-fo
  - datapane
  - cerberus
  - sktime
- - optuna==2.9.0
- - oracle-automlx==23.2.3
+ - optuna==3.1.0
+ - oracle-automlx==23.4.1
+ - oracle-automlx[forecasting]==23.4.1
  - oracle-ads>=2.9.0
  ```
 
ads/opctl/operator/lowcode/forecast/__main__.py
@@ -1,7 +1,7 @@
  #!/usr/bin/env python
  # -*- coding: utf-8 -*--
 
- # Copyright (c) 2023 Oracle and/or its affiliates.
+ # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
  import json
@@ -24,7 +24,23 @@ def operate(operator_config: ForecastOperatorConfig) -> None:
      from .model.factory import ForecastOperatorModelFactory
 
      datasets = ForecastDatasets(operator_config)
-     ForecastOperatorModelFactory.get_model(operator_config, datasets).generate_report()
+     try:
+         ForecastOperatorModelFactory.get_model(
+             operator_config, datasets
+         ).generate_report()
+     except Exception as e:
+         if operator_config.spec.model == "auto":
+             logger.debug(
+                 f"Failed to forecast with error {e.args}. Trying again with model `prophet`."
+             )
+             operator_config.spec.model = "prophet"
+             operator_config.spec.model_kwargs = dict()
+             datasets = ForecastDatasets(operator_config)
+             ForecastOperatorModelFactory.get_model(
+                 operator_config, datasets
+             ).generate_report()
+         else:
+             raise
 
 
  def verify(spec: Dict, **kwargs: Dict) -> bool:
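Stripped of the operator plumbing, the new `operate` behavior is a one-shot fallback: an `auto` run that fails is retried once with prophet on freshly rebuilt datasets, while an explicitly chosen model still fails loudly. A runnable sketch with hypothetical stand-ins (`build_datasets`, `run_model`) for the factory calls:

```python
# Hypothetical stand-ins for ForecastDatasets / ForecastOperatorModelFactory.
def build_datasets(config):
    return {"series": []}

def run_model(config, datasets):
    if config["model"] == "auto":
        raise RuntimeError("auto-selected model failed")
    print(f"ran {config['model']}")

def forecast_with_fallback(config):
    datasets = build_datasets(config)
    try:
        run_model(config, datasets)
    except Exception:
        if config["model"] != "auto":
            raise  # explicit model choices propagate their errors
        # Fall back to the prophet baseline; datasets are rebuilt because
        # the failed attempt may have mutated them.
        config["model"], config["model_kwargs"] = "prophet", {}
        run_model(config, build_datasets(config))

forecast_with_fallback({"model": "auto", "model_kwargs": {}})  # prints: ran prophet
```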
ads/opctl/operator/lowcode/forecast/cmd.py
@@ -32,13 +32,14 @@ def init(**kwargs: Dict) -> str:
      str
          The YAML specification generated based on the schema.
      """
-     logger.info("==== Forecasting related options ====")
- 
-     model_type = click.prompt(
-         "Provide a model type:",
-         type=click.Choice(SupportedModels.values()),
-         default=SupportedModels.Auto,
-     )
+     # logger.info("==== Forecasting related options ====")
+ 
+     # model_type = click.prompt(
+     #     "Provide a model type:",
+     #     type=click.Choice(SupportedModels.values()),
+     #     default=SupportedModels.Auto,
+     # )
+     model_type = "auto"
 
      return YamlGenerator(
          schema=_load_yaml_from_uri(__file__.replace("cmd.py", "schema.yaml"))
ads/opctl/operator/lowcode/forecast/const.py
@@ -5,6 +5,7 @@
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
  from ads.common.extended_enum import ExtendedEnumMeta
+ from ads.opctl.operator.lowcode.common.const import DataColumns
 
 
  class SupportedModels(str, metaclass=ExtendedEnumMeta):
@@ -18,6 +19,20 @@ class SupportedModels(str, metaclass=ExtendedEnumMeta):
      Auto = "auto"
 
 
+ class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
+     """
+     Enum representing different modes based on time taken and accuracy for explainability.
+     """
+ 
+     HIGH_ACCURACY = "HIGH_ACCURACY"
+     BALANCED = "BALANCED"
+     FAST_APPROXIMATE = "FAST_APPROXIMATE"
+     ratio = dict()
+     ratio[HIGH_ACCURACY] = 1  # 100% of the data used for generating explanations
+     ratio[BALANCED] = 0.5  # 50% of the data used for generating explanations
+     ratio[FAST_APPROXIMATE] = 0  # constant
+ 
+ 
  class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
      """Supported forecast metrics."""
 
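The `ratio` dict reads as the fraction of rows used when generating explanations. A hedged sketch of how a caller might turn a mode into a sample size (the `explanation_rows` helper is illustrative only, not part of the diff):

```python
# Illustrative only: translate a SpeedAccuracyMode ratio into a row count.
RATIO = {"HIGH_ACCURACY": 1, "BALANCED": 0.5, "FAST_APPROXIMATE": 0}

def explanation_rows(n_rows: int, mode: str, floor: int = 1) -> int:
    if RATIO[mode] == 0:
        return floor  # FAST_APPROXIMATE: a small constant, independent of n_rows
    return max(floor, int(n_rows * RATIO[mode]))

print(explanation_rows(1000, "HIGH_ACCURACY"))     # 1000 -> 100% of the data
print(explanation_rows(1000, "BALANCED"))          # 500  -> 50% of the data
print(explanation_rows(1000, "FAST_APPROXIMATE"))  # 1    -> constant
```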
@@ -49,7 +64,7 @@ class ForecastOutputColumns(str, metaclass=ExtendedEnumMeta):
      """The column names for the forecast.csv output file"""
 
      DATE = "Date"
-     SERIES = "Series"
+     SERIES = DataColumns.Series
      INPUT_VALUE = "input_value"
      FITTED_VALUE = "fitted_value"
      FORECAST_VALUE = "forecast_value"
@@ -70,3 +85,4 @@ MAX_COLUMNS_AUTOMLX = 15
  DEFAULT_TRIALS = 10
  SUMMARY_METRICS_HORIZON_LIMIT = 10
  PROPHET_INTERNAL_DATE_COL = "ds"
+ RENDER_LIMIT = 5000
ads/opctl/operator/lowcode/forecast/environment.yaml
@@ -15,5 +15,6 @@ dependencies:
  - sktime
  - shap
  - autots[additional]
- - optuna==2.9.0
- - oracle-automlx==23.2.3
+ - optuna==3.1.0
+ - oracle-automlx==23.4.1
+ - oracle-automlx[forecasting]==23.4.1