oracle-ads 2.10.0__py3-none-any.whl → 2.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117) hide show
  1. ads/aqua/__init__.py +12 -0
  2. ads/aqua/base.py +324 -0
  3. ads/aqua/cli.py +19 -0
  4. ads/aqua/config/deployment_config_defaults.json +9 -0
  5. ads/aqua/config/resource_limit_names.json +7 -0
  6. ads/aqua/constants.py +45 -0
  7. ads/aqua/data.py +40 -0
  8. ads/aqua/decorator.py +101 -0
  9. ads/aqua/deployment.py +643 -0
  10. ads/aqua/dummy_data/icon.txt +1 -0
  11. ads/aqua/dummy_data/oci_model_deployments.json +56 -0
  12. ads/aqua/dummy_data/oci_models.json +1 -0
  13. ads/aqua/dummy_data/readme.md +26 -0
  14. ads/aqua/evaluation.py +1751 -0
  15. ads/aqua/exception.py +82 -0
  16. ads/aqua/extension/__init__.py +40 -0
  17. ads/aqua/extension/base_handler.py +138 -0
  18. ads/aqua/extension/common_handler.py +21 -0
  19. ads/aqua/extension/deployment_handler.py +202 -0
  20. ads/aqua/extension/evaluation_handler.py +135 -0
  21. ads/aqua/extension/finetune_handler.py +66 -0
  22. ads/aqua/extension/model_handler.py +59 -0
  23. ads/aqua/extension/ui_handler.py +201 -0
  24. ads/aqua/extension/utils.py +23 -0
  25. ads/aqua/finetune.py +579 -0
  26. ads/aqua/job.py +29 -0
  27. ads/aqua/model.py +819 -0
  28. ads/aqua/training/__init__.py +4 -0
  29. ads/aqua/training/exceptions.py +459 -0
  30. ads/aqua/ui.py +453 -0
  31. ads/aqua/utils.py +715 -0
  32. ads/cli.py +37 -6
  33. ads/common/auth.py +7 -0
  34. ads/common/decorator/__init__.py +7 -3
  35. ads/common/decorator/require_nonempty_arg.py +65 -0
  36. ads/common/object_storage_details.py +166 -7
  37. ads/common/oci_client.py +18 -1
  38. ads/common/oci_logging.py +2 -2
  39. ads/common/oci_mixin.py +4 -5
  40. ads/common/serializer.py +34 -5
  41. ads/common/utils.py +75 -10
  42. ads/config.py +40 -1
  43. ads/dataset/correlation_plot.py +10 -12
  44. ads/jobs/ads_job.py +43 -25
  45. ads/jobs/builders/infrastructure/base.py +4 -2
  46. ads/jobs/builders/infrastructure/dsc_job.py +49 -39
  47. ads/jobs/builders/runtimes/base.py +71 -1
  48. ads/jobs/builders/runtimes/container_runtime.py +4 -4
  49. ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
  50. ads/jobs/templates/driver_pytorch.py +27 -10
  51. ads/model/artifact_downloader.py +84 -14
  52. ads/model/artifact_uploader.py +25 -23
  53. ads/model/datascience_model.py +388 -38
  54. ads/model/deployment/model_deployment.py +10 -2
  55. ads/model/generic_model.py +8 -0
  56. ads/model/model_file_description_schema.json +68 -0
  57. ads/model/model_metadata.py +1 -1
  58. ads/model/service/oci_datascience_model.py +34 -5
  59. ads/opctl/config/merger.py +2 -2
  60. ads/opctl/operator/__init__.py +3 -1
  61. ads/opctl/operator/cli.py +7 -1
  62. ads/opctl/operator/cmd.py +3 -3
  63. ads/opctl/operator/common/errors.py +2 -1
  64. ads/opctl/operator/common/operator_config.py +22 -3
  65. ads/opctl/operator/common/utils.py +16 -0
  66. ads/opctl/operator/lowcode/anomaly/MLoperator +15 -0
  67. ads/opctl/operator/lowcode/anomaly/README.md +209 -0
  68. ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
  69. ads/opctl/operator/lowcode/anomaly/__main__.py +104 -0
  70. ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
  71. ads/opctl/operator/lowcode/anomaly/const.py +88 -0
  72. ads/opctl/operator/lowcode/anomaly/environment.yaml +12 -0
  73. ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
  74. ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +147 -0
  75. ads/opctl/operator/lowcode/anomaly/model/automlx.py +89 -0
  76. ads/opctl/operator/lowcode/anomaly/model/autots.py +103 -0
  77. ads/opctl/operator/lowcode/anomaly/model/base_model.py +354 -0
  78. ads/opctl/operator/lowcode/anomaly/model/factory.py +67 -0
  79. ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
  80. ads/opctl/operator/lowcode/anomaly/operator_config.py +105 -0
  81. ads/opctl/operator/lowcode/anomaly/schema.yaml +359 -0
  82. ads/opctl/operator/lowcode/anomaly/utils.py +81 -0
  83. ads/opctl/operator/lowcode/common/__init__.py +5 -0
  84. ads/opctl/operator/lowcode/common/const.py +10 -0
  85. ads/opctl/operator/lowcode/common/data.py +96 -0
  86. ads/opctl/operator/lowcode/common/errors.py +41 -0
  87. ads/opctl/operator/lowcode/common/transformations.py +191 -0
  88. ads/opctl/operator/lowcode/common/utils.py +250 -0
  89. ads/opctl/operator/lowcode/forecast/README.md +3 -2
  90. ads/opctl/operator/lowcode/forecast/__main__.py +18 -2
  91. ads/opctl/operator/lowcode/forecast/cmd.py +8 -7
  92. ads/opctl/operator/lowcode/forecast/const.py +17 -1
  93. ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
  94. ads/opctl/operator/lowcode/forecast/model/arima.py +106 -117
  95. ads/opctl/operator/lowcode/forecast/model/automlx.py +204 -180
  96. ads/opctl/operator/lowcode/forecast/model/autots.py +144 -253
  97. ads/opctl/operator/lowcode/forecast/model/base_model.py +326 -259
  98. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +325 -176
  99. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +293 -237
  100. ads/opctl/operator/lowcode/forecast/model/prophet.py +191 -208
  101. ads/opctl/operator/lowcode/forecast/operator_config.py +24 -33
  102. ads/opctl/operator/lowcode/forecast/schema.yaml +116 -29
  103. ads/opctl/operator/lowcode/forecast/utils.py +186 -356
  104. ads/opctl/operator/lowcode/pii/model/guardrails.py +18 -15
  105. ads/opctl/operator/lowcode/pii/model/report.py +7 -7
  106. ads/opctl/operator/lowcode/pii/operator_config.py +1 -8
  107. ads/opctl/operator/lowcode/pii/utils.py +0 -82
  108. ads/opctl/operator/runtime/runtime.py +3 -2
  109. ads/telemetry/base.py +62 -0
  110. ads/telemetry/client.py +105 -0
  111. ads/telemetry/telemetry.py +6 -3
  112. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/METADATA +44 -7
  113. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/RECORD +116 -59
  114. ads/opctl/operator/lowcode/forecast/model/transformations.py +0 -125
  115. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/LICENSE.txt +0 -0
  116. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/WHEEL +0 -0
  117. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/entry_points.txt +0 -0
@@ -4,14 +4,121 @@
4
4
  # Copyright (c) 2023 Oracle and/or its affiliates.
5
5
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
6
 
7
+ import time
7
8
  import pandas as pd
9
+ from pandas.api.types import is_datetime64_any_dtype, is_string_dtype, is_numeric_dtype
10
+
8
11
  from ..operator_config import ForecastOperatorConfig
9
- from .. import utils
10
- from .transformations import Transformations
11
12
  from ads.opctl import logger
12
- import pandas as pd
13
13
  from ..const import ForecastOutputColumns, PROPHET_INTERNAL_DATE_COL
14
- from pandas.api.types import is_datetime64_any_dtype, is_string_dtype, is_numeric_dtype
14
+ from ads.common.object_storage_details import ObjectStorageDetails
15
+ from ads.opctl.operator.lowcode.common.utils import (
16
+ get_frequency_in_seconds,
17
+ get_frequency_of_datetime,
18
+ )
19
+ from ads.opctl.operator.lowcode.common.data import AbstractData
20
+ from ads.opctl.operator.lowcode.forecast.utils import (
21
+ default_signer,
22
+ )
23
+ from ads.opctl.operator.lowcode.common.errors import (
24
+ InputDataError,
25
+ InvalidParameterError,
26
+ PermissionsError,
27
+ DataMismatchError,
28
+ )
29
+ from ..const import SupportedModels
30
+ from abc import ABC, abstractmethod
31
+
32
+
33
+ class HistoricalData(AbstractData):
34
+ def __init__(self, spec: dict):
35
+ super().__init__(spec=spec, name="historical_data")
36
+
37
+ def _ingest_data(self, spec):
38
+ try:
39
+ self.freq = get_frequency_of_datetime(self.data.index.get_level_values(0))
40
+ except TypeError as e:
41
+ logger.warn(
42
+ f"Error determining frequency: {e.args}. Setting Frequency to None"
43
+ )
44
+ logger.debug(f"Full traceback: {e}")
45
+ self.freq = None
46
+ self._verify_dt_col(spec)
47
+ super()._ingest_data(spec)
48
+
49
+ def _verify_dt_col(self, spec):
50
+ # Check frequency is compatible with model type
51
+ self.freq_in_secs = get_frequency_in_seconds(
52
+ self.data.index.get_level_values(0)
53
+ )
54
+ if spec.model == SupportedModels.AutoMLX:
55
+ if abs(self.freq_in_secs) < 3600:
56
+ message = (
57
+ "{} requires data with a frequency of at least one hour. Please try using a different model,"
58
+ " or select the 'auto' option.".format(SupportedModels.AutoMLX)
59
+ )
60
+ raise InvalidParameterError(message)
61
+
62
+
63
+ class AdditionalData(AbstractData):
64
+ def __init__(self, spec, historical_data):
65
+ if spec.additional_data is not None:
66
+ super().__init__(spec=spec, name="additional_data")
67
+ add_dates = self.data.index.get_level_values(0).unique().tolist()
68
+ add_dates.sort()
69
+ if historical_data.get_max_time() > add_dates[-spec.horizon]:
70
+ raise DataMismatchError(
71
+ f"The Historical Data ends on {historical_data.get_max_time()}. The additional data horizon starts on {add_dates[-spec.horizon]}. The horizon should have exactly {spec.horizon} dates after the Historical at a frequency of {historical_data.freq}"
72
+ )
73
+ elif historical_data.get_max_time() != add_dates[-(spec.horizon + 1)]:
74
+ raise DataMismatchError(
75
+ f"The Additional Data must be present for all historical data and the entire horizon. The Historical Data ends on {historical_data.get_max_time()}. The additional data horizon starts after {add_dates[-(spec.horizon+1)]}. These should be the same date."
76
+ )
77
+ else:
78
+ self.name = "additional_data"
79
+ self.data = None
80
+ self._data_dict = dict()
81
+ self.create_horizon(spec, historical_data)
82
+
83
+ def create_horizon(self, spec, historical_data):
84
+ logger.debug(f"No additional data provided. Constructing horizon.")
85
+ future_dates = pd.Series(
86
+ pd.date_range(
87
+ start=historical_data.get_max_time(),
88
+ periods=spec.horizon + 1,
89
+ freq=historical_data.freq,
90
+ ),
91
+ name=spec.datetime_column.name,
92
+ )
93
+ add_dfs = []
94
+ for s_id in historical_data.list_series_ids():
95
+ df_i = historical_data.get_data_for_series(s_id)[spec.datetime_column.name]
96
+ df_i = pd.DataFrame(pd.concat([df_i, future_dates[1:]]))
97
+ df_i[ForecastOutputColumns.SERIES] = s_id
98
+ df_i = df_i.set_index(
99
+ [spec.datetime_column.name, ForecastOutputColumns.SERIES]
100
+ )
101
+ add_dfs.append(df_i)
102
+ data = pd.concat(add_dfs, axis=1)
103
+ self.data = data.sort_values(
104
+ [spec.datetime_column.name, ForecastOutputColumns.SERIES], ascending=True
105
+ )
106
+ self.additional_regressors = []
107
+
108
+ def _ingest_data(self, spec):
109
+ self.additional_regressors = list(self.data.columns)
110
+ if not self.additional_regressors:
111
+ logger.warn(
112
+ f"No additional variables found in the additional_data. Only columns found: {self.data.columns}. Skipping for now."
113
+ )
114
+ # Check that datetime column matches historical datetime column
115
+
116
+
117
+ class TestData(AbstractData):
118
+ def __init__(self, spec):
119
+ super().__init__(spec=spec, name="test_data")
120
+ self.dt_column_name = spec.datetime_column.name
121
+ self.target_name = spec.target_column
15
122
 
16
123
 
17
124
  class ForecastDatasets:
@@ -23,223 +130,259 @@ class ForecastDatasets:
23
130
  config: ForecastOperatorConfig
24
131
  The forecast operator configuration.
25
132
  """
26
- self.original_user_data = None
27
- self.original_total_data = None
28
- self.original_additional_data = None
29
- self.full_data_dict = None
30
- self.target_columns = None
31
- self.categories = None
32
- self.datetime_col = PROPHET_INTERNAL_DATE_COL
33
- self.datetime_format = config.spec.datetime_column.format
133
+ self.historical_data: HistoricalData = None
134
+ self.additional_data: AdditionalData = None
135
+
136
+ self._horizon = config.spec.horizon
137
+ self._datetime_column_name = config.spec.datetime_column.name
34
138
  self._load_data(config.spec)
35
139
 
36
140
  def _load_data(self, spec):
37
141
  """Loads forecasting input data."""
142
+ self.historical_data = HistoricalData(spec)
143
+ self.additional_data = AdditionalData(spec, self.historical_data)
38
144
 
39
- raw_data = utils._load_data(
40
- filename=spec.historical_data.url,
41
- format=spec.historical_data.format,
42
- columns=spec.historical_data.columns,
145
+ if spec.generate_explanations:
146
+ if spec.additional_data is None:
147
+ logger.warn(
148
+ f"Unable to generate explanations as there is no additional data passed in. Either set generate_explanations to False, or pass in additional data."
149
+ )
150
+ spec.generate_explanations = False
151
+
152
+ def get_all_data_long(self, include_horizon=True):
153
+ how = "outer" if include_horizon else "left"
154
+ return pd.merge(
155
+ self.historical_data.data,
156
+ self.additional_data.data,
157
+ how=how,
158
+ on=[self._datetime_column_name, ForecastOutputColumns.SERIES],
159
+ ).reset_index()
160
+
161
+ def get_data_multi_indexed(self):
162
+ return pd.concat(
163
+ [
164
+ self.historical_data.data,
165
+ self.additional_data.data,
166
+ ],
167
+ axis=1,
43
168
  )
44
- self.original_user_data = raw_data.copy()
45
- data_transformer = Transformations(raw_data, spec)
46
- data = data_transformer.run()
47
- try:
48
- spec.freq = utils.get_frequency_of_datetime(data, spec)
49
- except TypeError as e:
50
- logger.warn(
51
- f"Error determining frequency: {e.args}. Setting Frequency to None"
52
- )
53
- logger.debug(f"Full traceback: {e}")
54
- spec.freq = None
55
169
 
56
- self.original_total_data = data
57
- additional_data = None
170
+ def get_data_by_series(self, include_horizon=True):
171
+ total_dict = dict()
172
+ hist_data = self.historical_data.get_dict_by_series()
173
+ add_data = self.additional_data.get_dict_by_series()
174
+ how = "outer" if include_horizon else "left"
175
+ for s_id in self.list_series_ids():
176
+ # Note: ensure no duplicate column names
177
+ total_dict[s_id] = pd.merge(
178
+ hist_data[s_id],
179
+ add_data[s_id],
180
+ how=how,
181
+ on=[self._datetime_column_name],
182
+ )
183
+ return total_dict
58
184
 
185
+ def get_data_at_series(self, s_id, include_horizon=True):
186
+ all_data = self.get_data_by_series(include_horizon=include_horizon)
59
187
  try:
60
- data[spec.datetime_column.name] = pd.to_datetime(
61
- data[spec.datetime_column.name], format=self.datetime_format
62
- )
188
+ return all_data[s_id]
63
189
  except:
64
- raise ValueError(
65
- f"Unable to determine the datetime type for column: {spec.datetime_column.name}. Please specify the format explicitly."
190
+ raise InvalidParameterError(
191
+ f"Unable to retrieve series id: {s_id} from data. Available series ids are: {self.list_series_ids()}"
66
192
  )
67
193
 
68
- if spec.additional_data is not None:
69
- additional_data = utils._load_data(
70
- filename=spec.additional_data.url,
71
- format=spec.additional_data.format,
72
- columns=spec.additional_data.columns,
73
- )
74
- additional_data = data_transformer._sort_by_datetime_col(additional_data)
75
- try:
76
- additional_data[spec.datetime_column.name] = pd.to_datetime(
77
- additional_data[spec.datetime_column.name],
78
- format=self.datetime_format,
79
- )
80
- except:
81
- raise ValueError(
82
- f"Unable to determine the datetime type for column: {spec.datetime_column.name}. Please specify the format explicitly."
83
- )
194
+ def get_horizon_at_series(self, s_id):
195
+ return self.get_data_at_series(s_id)[-self._horizon :]
84
196
 
85
- self.original_additional_data = additional_data.copy()
86
- self.original_total_data = pd.concat([data, additional_data], axis=1)
87
- else:
88
- # Need to add the horizon to the data for compatibility
89
- additional_data_small = data[
90
- [spec.datetime_column.name] + spec.target_category_columns
91
- ].set_index(spec.datetime_column.name)
92
- if is_datetime64_any_dtype(additional_data_small.index):
93
- horizon_index = pd.date_range(
94
- start=additional_data_small.index.values[-1],
95
- freq=spec.freq,
96
- periods=spec.horizon + 1,
97
- )[1:]
98
- elif is_numeric_dtype(additional_data_small.index):
99
- # If datetime column is just ints
100
- assert (
101
- len(additional_data_small.index.values) > 1
102
- ), "Dataset is too small to infer frequency. Please pass in the horizon explicitly through the additional data."
103
- start = additional_data_small.index.values[-1]
104
- step = (
105
- additional_data_small.index.values[-1]
106
- - additional_data_small.index.values[-2]
107
- )
108
- horizon_index = pd.RangeIndex(
109
- start, start + step * (spec.horizon + 1), step=step
110
- )[1:]
111
- else:
112
- raise ValueError(
113
- f"Unable to determine the datetime type for column: {spec.datetime_column.name}. Please specify the format explicitly."
114
- )
197
+ def has_artificial_series(self):
198
+ return self.historical_data._data_transformer.has_artificial_series
115
199
 
116
- additional_data = pd.DataFrame()
117
-
118
- for cat_col in spec.target_category_columns:
119
- for cat in additional_data_small[cat_col].unique():
120
- add_data_i = additional_data_small[
121
- additional_data_small[cat_col] == cat
122
- ]
123
- horizon_df_i = pd.DataFrame([], index=horizon_index)
124
- horizon_df_i[cat_col] = cat
125
- additional_data = pd.concat(
126
- [additional_data, add_data_i, horizon_df_i]
127
- )
128
- additional_data = additional_data.reset_index().rename(
129
- {"index": spec.datetime_column.name}, axis=1
130
- )
200
+ def get_earliest_timestamp(self):
201
+ return self.historical_data.get_min_time()
131
202
 
132
- self.original_total_data = pd.concat([data, additional_data], axis=1)
133
-
134
- (
135
- self.full_data_dict,
136
- self.target_columns,
137
- self.categories,
138
- ) = utils._build_indexed_datasets(
139
- data=data,
140
- target_column=spec.target_column,
141
- datetime_column=spec.datetime_column.name,
142
- horizon=spec.horizon,
143
- target_category_columns=spec.target_category_columns,
144
- additional_data=additional_data,
145
- )
146
- if spec.generate_explanations:
147
- if spec.additional_data is None:
148
- logger.warn(
149
- f"Unable to generate explanations as there is no additional data passed in. Either set generate_explanations to False, or pass in additional data."
150
- )
151
- spec.generate_explanations = False
203
+ def get_latest_timestamp(self):
204
+ return self.historical_data.get_max_time()
205
+
206
+ def get_additional_data_column_names(self):
207
+ return self.additional_data.additional_regressors
208
+
209
+ def get_datetime_frequency(self):
210
+ return self.historical_data.freq
211
+
212
+ def get_datetime_frequency_in_seconds(self):
213
+ return self.historical_data.freq_in_secs
214
+
215
+ def get_num_rows(self):
216
+ return self.historical_data.get_num_rows()
217
+
218
+ def list_series_ids(self, sorted=True):
219
+ series_ids = self.historical_data.list_series_ids()
220
+ if sorted:
221
+ try:
222
+ series_ids.sort()
223
+ except:
224
+ pass
225
+ return series_ids
152
226
 
153
227
  def format_wide(self):
154
228
  data_merged = pd.concat(
155
229
  [
156
- v[v[k].notna()].set_index(self.datetime_col)
157
- for k, v in self.full_data_dict.items()
230
+ v[v[k].notna()].set_index(self._datetime_column_name)
231
+ for k, v in self.get_data_by_series().items()
158
232
  ],
159
233
  axis=1,
160
234
  ).reset_index()
161
235
  return data_merged
162
236
 
163
237
  def get_longest_datetime_column(self):
164
- return pd.to_datetime(
165
- self.format_wide()[self.datetime_col], format=self.datetime_format
166
- )
238
+ return self.format_wide()[self._datetime_column_name]
167
239
 
168
240
 
169
241
  class ForecastOutput:
170
- def __init__(self, confidence_interval_width: float):
242
+ def __init__(
243
+ self,
244
+ confidence_interval_width: float,
245
+ horizon: int,
246
+ target_column: str,
247
+ dt_column: str,
248
+ ):
171
249
  """Forecast Output contains all of the details required to generate the forecast.csv output file.
172
250
 
173
- Methods
174
- ----------
175
-
251
+ init
252
+ -------
253
+ confidence_interval_width: float value from OperatorSpec
254
+ horizon: int length of horizon
255
+ target_column: str the name of the original target column
256
+ dt_column: the name of the original datetime column
176
257
  """
177
- self.category_map = dict()
178
- self.category_to_target = dict()
179
- self.confidence_interval_width = confidence_interval_width
180
- self.upper_bound_name = None
181
- self.lower_bound_name = None
258
+ self.series_id_map = dict()
259
+ self._set_ci_column_names(confidence_interval_width)
260
+ self.horizon = horizon
261
+ self.target_column_name = target_column
262
+ self.dt_column_name = dt_column
182
263
 
183
- def add_category(
264
+ def add_series_id(
184
265
  self,
185
- category: str,
186
- target_category_column: str,
266
+ series_id: str,
187
267
  forecast: pd.DataFrame,
188
268
  overwrite: bool = False,
189
269
  ):
190
- if not overwrite and category in self.category_map.keys():
270
+ if not overwrite and series_id in self.series_id_map.keys():
191
271
  raise ValueError(
192
- f"Attempting to update ForecastOutput for category {category} when this already exists. Set overwrite to True."
272
+ f"Attempting to update ForecastOutput for series_id {series_id} when this already exists. Set overwrite to True."
193
273
  )
194
274
  forecast = self._check_forecast_format(forecast)
195
- forecast = self._set_ci_column_names(forecast)
196
- self.category_map[category] = forecast
197
- self.category_to_target[category] = target_category_column
275
+ self.series_id_map[series_id] = forecast
198
276
 
199
- def get_category(self, category): # change to by_category ?
200
- return self.category_map[category]
277
+ def init_series_output(self, series_id, data_at_series):
278
+ output_i = pd.DataFrame()
201
279
 
202
- def get_target_category(self, target_category_column):
203
- target_category_columns = self.list_target_category_columns()
204
- category = self.list_categories()[
205
- list(self.category_to_target.values()).index(target_category_column)
206
- ]
207
- return self.category_map[category]
280
+ output_i["Date"] = data_at_series[self.dt_column_name]
281
+ output_i["Series"] = series_id
282
+ output_i["input_value"] = data_at_series[self.target_column_name]
208
283
 
209
- def list_categories(self):
210
- return list(self.category_map.keys())
284
+ output_i["fitted_value"] = float("nan")
285
+ output_i["forecast_value"] = float("nan")
286
+ output_i[self.lower_bound_name] = float("nan")
287
+ output_i[self.upper_bound_name] = float("nan")
288
+ self.series_id_map[series_id] = output_i
211
289
 
212
- def list_target_category_columns(self):
213
- return list(self.category_to_target.values())
290
+ def populate_series_output(
291
+ self, series_id, fit_val, forecast_val, upper_bound, lower_bound
292
+ ):
293
+ """
294
+ This method should be run after init_series_output has been run on this series_id
214
295
 
215
- def format_long(self):
216
- return pd.concat(list(self.category_map.values()))
296
+ Parameters:
297
+ -----------
298
+ series_id: [str, int] the series being forecasted
299
+ fit_val: numpy.array of length input_value - horizon
300
+ forecast_val: numpy.array of length horizon containing the forecasted values
301
+ upper_bound: numpy.array of length horizon containing the upper_bound values
302
+ lower_bound: numpy.array of length horizon containing the lower_bound values
217
303
 
218
- def _set_ci_column_names(self, forecast_i):
219
- yhat_lower_percentage = (100 - self.confidence_interval_width * 100) // 2
220
- self.upper_bound_name = "p" + str(int(100 - yhat_lower_percentage))
221
- self.lower_bound_name = "p" + str(int(yhat_lower_percentage))
222
- return forecast_i.rename(
223
- {
224
- ForecastOutputColumns.UPPER_BOUND: self.upper_bound_name,
225
- ForecastOutputColumns.LOWER_BOUND: self.lower_bound_name,
226
- },
227
- axis=1,
228
- )
304
+ Returns:
305
+ --------
306
+ None
307
+ """
308
+ try:
309
+ output_i = self.series_id_map[series_id]
310
+ except KeyError:
311
+ raise ValueError(
312
+ f"Attempting to update output for series: {series_id}, however no series output has been initialized."
313
+ )
229
314
 
230
- def format_wide(self):
231
- dataset_time_indexed = {
232
- k: v.set_index(ForecastOutputColumns.DATE)
233
- for k, v in self.category_map.items()
234
- }
235
- datasets_category_appended = [
236
- v.rename(lambda x: str(x) + f"_{k}", axis=1)
237
- for k, v in dataset_time_indexed.items()
238
- ]
239
- return pd.concat(datasets_category_appended, axis=1)
315
+ if (output_i.shape[0] - self.horizon) == len(fit_val):
316
+ output_i["fitted_value"].iloc[
317
+ : -self.horizon
318
+ ] = fit_val # Note: may need to do len(output_i) - (len(fit_val) + horizon) : -horizon
319
+ elif (output_i.shape[0] - self.horizon) > len(fit_val):
320
+ logger.debug(
321
+ f"Fitted Values were only generated on a subset ({len(fit_val)}/{(output_i.shape[0] - self.horizon)}) of the data for Series: {series_id}."
322
+ )
323
+ start_idx = output_i.shape[0] - self.horizon - len(fit_val)
324
+ output_i["fitted_value"].iloc[start_idx : -self.horizon] = fit_val
325
+ else:
326
+ output_i["fitted_value"].iloc[start_idx : -self.horizon] = fit_val[
327
+ -(output_i.shape[0] - self.horizon) :
328
+ ]
240
329
 
241
- def get_longest_datetime_column(self):
242
- return self.format_wide().index
330
+ if len(forecast_val) != self.horizon:
331
+ raise ValueError(
332
+ f"Attempting to set forecast along horizon ({self.horizon}) for series: {series_id}, however forecast is only length {len(forecast_val)}"
333
+ )
334
+ output_i["forecast_value"].iloc[-self.horizon :] = forecast_val
335
+
336
+ if len(upper_bound) != self.horizon:
337
+ raise ValueError(
338
+ f"Attempting to set upper_bound along horizon ({self.horizon}) for series: {series_id}, however upper_bound is only length {len(upper_bound)}"
339
+ )
340
+ output_i[self.upper_bound_name].iloc[-self.horizon :] = upper_bound
341
+
342
+ if len(lower_bound) != self.horizon:
343
+ raise ValueError(
344
+ f"Attempting to set lower_bound along horizon ({self.horizon}) for series: {series_id}, however lower_bound is only length {len(lower_bound)}"
345
+ )
346
+ output_i[self.lower_bound_name].iloc[-self.horizon :] = lower_bound
347
+
348
+ self.series_id_map[series_id] = output_i
349
+ self.verify_series_output(series_id)
350
+
351
+ def verify_series_output(self, series_id):
352
+ forecast = self.series_id_map[series_id]
353
+ self._check_forecast_format(forecast)
354
+
355
+ def get_horizon_by_series(self, series_id):
356
+ return self.series_id_map[series_id][-self.horizon :]
357
+
358
+ def get_horizon_long(self):
359
+ df = pd.DataFrame()
360
+ for s_id in self.list_series_ids():
361
+ df = pd.concat([df, self.get_horizon_by_series(s_id)])
362
+ return df.reset_index(drop=True)
363
+
364
+ def get_forecast(self, series_id):
365
+ try:
366
+ return self.series_id_map[series_id]
367
+ except KeyError as ke:
368
+ logger.debug(
369
+ f"No Forecast found for series_id: {series_id}. Returning empty DataFrame."
370
+ )
371
+ return pd.DataFrame()
372
+
373
+ def list_series_ids(self, sorted=True):
374
+ series_ids = list(self.series_id_map.keys())
375
+ if sorted:
376
+ try:
377
+ series_ids.sort()
378
+ except:
379
+ pass
380
+ return series_ids
381
+
382
+ def _set_ci_column_names(self, confidence_interval_width):
383
+ yhat_lower_percentage = (100 - confidence_interval_width * 100) // 2
384
+ self.upper_bound_name = "p" + str(int(100 - yhat_lower_percentage))
385
+ self.lower_bound_name = "p" + str(int(yhat_lower_percentage))
243
386
 
244
387
  def _check_forecast_format(self, forecast):
245
388
  assert isinstance(forecast, pd.DataFrame)
@@ -251,8 +394,8 @@ class ForecastOutput:
251
394
  assert ForecastOutputColumns.INPUT_VALUE in forecast.columns
252
395
  assert ForecastOutputColumns.FITTED_VALUE in forecast.columns
253
396
  assert ForecastOutputColumns.FORECAST_VALUE in forecast.columns
254
- assert ForecastOutputColumns.UPPER_BOUND in forecast.columns
255
- assert ForecastOutputColumns.LOWER_BOUND in forecast.columns
397
+ assert self.upper_bound_name in forecast.columns
398
+ assert self.lower_bound_name in forecast.columns
256
399
  assert not forecast.empty
257
400
  # forecast.columns = pd.Index([
258
401
  # ForecastOutputColumns.DATE,
@@ -264,3 +407,9 @@ class ForecastOutput:
264
407
  # ForecastOutputColumns.LOWER_BOUND,
265
408
  # ])
266
409
  return forecast
410
+
411
+ def get_forecast_long(self):
412
+ output = pd.DataFrame()
413
+ for df in self.series_id_map.values():
414
+ output = pd.concat([output, df])
415
+ return output.reset_index(drop=True)