oracle-ads 2.10.0__py3-none-any.whl → 2.11.0__py3-none-any.whl
This diff reflects the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- ads/aqua/__init__.py +12 -0
- ads/aqua/base.py +324 -0
- ads/aqua/cli.py +19 -0
- ads/aqua/config/deployment_config_defaults.json +9 -0
- ads/aqua/config/resource_limit_names.json +7 -0
- ads/aqua/constants.py +45 -0
- ads/aqua/data.py +40 -0
- ads/aqua/decorator.py +101 -0
- ads/aqua/deployment.py +643 -0
- ads/aqua/dummy_data/icon.txt +1 -0
- ads/aqua/dummy_data/oci_model_deployments.json +56 -0
- ads/aqua/dummy_data/oci_models.json +1 -0
- ads/aqua/dummy_data/readme.md +26 -0
- ads/aqua/evaluation.py +1751 -0
- ads/aqua/exception.py +82 -0
- ads/aqua/extension/__init__.py +40 -0
- ads/aqua/extension/base_handler.py +138 -0
- ads/aqua/extension/common_handler.py +21 -0
- ads/aqua/extension/deployment_handler.py +202 -0
- ads/aqua/extension/evaluation_handler.py +135 -0
- ads/aqua/extension/finetune_handler.py +66 -0
- ads/aqua/extension/model_handler.py +59 -0
- ads/aqua/extension/ui_handler.py +201 -0
- ads/aqua/extension/utils.py +23 -0
- ads/aqua/finetune.py +579 -0
- ads/aqua/job.py +29 -0
- ads/aqua/model.py +819 -0
- ads/aqua/training/__init__.py +4 -0
- ads/aqua/training/exceptions.py +459 -0
- ads/aqua/ui.py +453 -0
- ads/aqua/utils.py +715 -0
- ads/cli.py +37 -6
- ads/common/auth.py +7 -0
- ads/common/decorator/__init__.py +7 -3
- ads/common/decorator/require_nonempty_arg.py +65 -0
- ads/common/object_storage_details.py +166 -7
- ads/common/oci_client.py +18 -1
- ads/common/oci_logging.py +2 -2
- ads/common/oci_mixin.py +4 -5
- ads/common/serializer.py +34 -5
- ads/common/utils.py +75 -10
- ads/config.py +40 -1
- ads/dataset/correlation_plot.py +10 -12
- ads/jobs/ads_job.py +43 -25
- ads/jobs/builders/infrastructure/base.py +4 -2
- ads/jobs/builders/infrastructure/dsc_job.py +49 -39
- ads/jobs/builders/runtimes/base.py +71 -1
- ads/jobs/builders/runtimes/container_runtime.py +4 -4
- ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
- ads/jobs/templates/driver_pytorch.py +27 -10
- ads/model/artifact_downloader.py +84 -14
- ads/model/artifact_uploader.py +25 -23
- ads/model/datascience_model.py +388 -38
- ads/model/deployment/model_deployment.py +10 -2
- ads/model/generic_model.py +8 -0
- ads/model/model_file_description_schema.json +68 -0
- ads/model/model_metadata.py +1 -1
- ads/model/service/oci_datascience_model.py +34 -5
- ads/opctl/config/merger.py +2 -2
- ads/opctl/operator/__init__.py +3 -1
- ads/opctl/operator/cli.py +7 -1
- ads/opctl/operator/cmd.py +3 -3
- ads/opctl/operator/common/errors.py +2 -1
- ads/opctl/operator/common/operator_config.py +22 -3
- ads/opctl/operator/common/utils.py +16 -0
- ads/opctl/operator/lowcode/anomaly/MLoperator +15 -0
- ads/opctl/operator/lowcode/anomaly/README.md +209 -0
- ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/__main__.py +104 -0
- ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
- ads/opctl/operator/lowcode/anomaly/const.py +88 -0
- ads/opctl/operator/lowcode/anomaly/environment.yaml +12 -0
- ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +147 -0
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +89 -0
- ads/opctl/operator/lowcode/anomaly/model/autots.py +103 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +354 -0
- ads/opctl/operator/lowcode/anomaly/model/factory.py +67 -0
- ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
- ads/opctl/operator/lowcode/anomaly/operator_config.py +105 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +359 -0
- ads/opctl/operator/lowcode/anomaly/utils.py +81 -0
- ads/opctl/operator/lowcode/common/__init__.py +5 -0
- ads/opctl/operator/lowcode/common/const.py +10 -0
- ads/opctl/operator/lowcode/common/data.py +96 -0
- ads/opctl/operator/lowcode/common/errors.py +41 -0
- ads/opctl/operator/lowcode/common/transformations.py +191 -0
- ads/opctl/operator/lowcode/common/utils.py +250 -0
- ads/opctl/operator/lowcode/forecast/README.md +3 -2
- ads/opctl/operator/lowcode/forecast/__main__.py +18 -2
- ads/opctl/operator/lowcode/forecast/cmd.py +8 -7
- ads/opctl/operator/lowcode/forecast/const.py +17 -1
- ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
- ads/opctl/operator/lowcode/forecast/model/arima.py +106 -117
- ads/opctl/operator/lowcode/forecast/model/automlx.py +204 -180
- ads/opctl/operator/lowcode/forecast/model/autots.py +144 -253
- ads/opctl/operator/lowcode/forecast/model/base_model.py +326 -259
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +325 -176
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +293 -237
- ads/opctl/operator/lowcode/forecast/model/prophet.py +191 -208
- ads/opctl/operator/lowcode/forecast/operator_config.py +24 -33
- ads/opctl/operator/lowcode/forecast/schema.yaml +116 -29
- ads/opctl/operator/lowcode/forecast/utils.py +186 -356
- ads/opctl/operator/lowcode/pii/model/guardrails.py +18 -15
- ads/opctl/operator/lowcode/pii/model/report.py +7 -7
- ads/opctl/operator/lowcode/pii/operator_config.py +1 -8
- ads/opctl/operator/lowcode/pii/utils.py +0 -82
- ads/opctl/operator/runtime/runtime.py +3 -2
- ads/telemetry/base.py +62 -0
- ads/telemetry/client.py +105 -0
- ads/telemetry/telemetry.py +6 -3
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/METADATA +44 -7
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/RECORD +116 -59
- ads/opctl/operator/lowcode/forecast/model/transformations.py +0 -125
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/WHEEL +0 -0
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/entry_points.txt +0 -0
ads/opctl/operator/lowcode/anomaly/schema.yaml (new file, +359 lines)

```diff
@@ -0,0 +1,359 @@
+kind:
+  allowed:
+    - operator
+  required: true
+  type: string
+  default: operator
+  meta:
+    description: "Which service are you trying to use? Common kinds: `operator`, `job`"
+
+version:
+  allowed:
+    - "v1"
+  required: true
+  type: string
+  default: v1
+  meta:
+    description: "Operators may change yaml file schemas from version to version, as well as implementation details. Double check the version to ensure compatibility."
+
+type:
+  required: true
+  type: string
+  default: anomaly
+  meta:
+    description: "Type should always be `anomaly` when using a anomaly detection operator"
+
+spec:
+  required: true
+  schema:
+    input_data:
+      required: true
+      type: dict
+      default: {"url": "data.csv"}
+      meta:
+        description: "The payload that the detector should evaluate."
+      schema:
+        connect_args:
+          nullable: true
+          required: false
+          type: dict
+        format:
+          allowed:
+            - csv
+            - json
+            - clipboard
+            - excel
+            - feather
+            - sql_table
+            - sql_query
+            - hdf
+            - tsv
+          required: false
+          type: string
+        columns:
+          required: false
+          type: list
+          schema:
+            type: string
+        filters:
+          required: false
+          type: list
+          schema:
+            type: string
+        options:
+          nullable: true
+          required: false
+          type: dict
+        sql:
+          required: false
+          type: string
+        table_name:
+          required: false
+          type: string
+        url:
+          required: false
+          type: string
+          meta:
+            description: "The url can be local, or remote. For example: `oci://<bucket>@<namespace>/data.csv`"
+        limit:
+          required: false
+          type: integer
+
+    validation_data:
+      required: false
+      type: dict
+      meta:
+        description: "Data that has already been labeled as anomalous or not."
+      schema:
+        connect_args:
+          nullable: true
+          required: false
+          type: dict
+        format:
+          allowed:
+            - csv
+            - json
+            - clipboard
+            - excel
+            - feather
+            - sql_table
+            - sql_query
+            - hdf
+            - tsv
+          required: false
+          type: string
+        columns:
+          required: false
+          type: list
+          schema:
+            type: string
+        filters:
+          required: false
+          type: list
+          schema:
+            type: string
+        options:
+          nullable: true
+          required: false
+          type: dict
+        sql:
+          required: false
+          type: string
+        table_name:
+          required: false
+          type: string
+        url:
+          required: false
+          type: string
+          meta:
+            description: "The url can be local, or remote. For example: `oci://<bucket>@<namespace>/data.csv`"
+        limit:
+          required: false
+          type: integer
+
+    datetime_column:
+      type: dict
+      required: true
+      schema:
+        name:
+          type: string
+          required: true
+          default: Date
+        format:
+          type: string
+          required: false
+
+    test_data:
+      required: false
+      meta:
+        description: "Optional, only if evaluation is needed."
+      schema:
+        connect_args:
+          nullable: true
+          required: false
+          type: dict
+        format:
+          allowed:
+            - csv
+            - json
+            - clipboard
+            - excel
+            - feather
+            - sql_table
+            - sql_query
+            - hdf
+            - tsv
+          required: false
+          type: string
+        columns:
+          required: false
+          type: list
+          schema:
+            type: string
+        filters:
+          required: false
+          type: list
+          schema:
+            type: string
+        options:
+          nullable: true
+          required: false
+          type: dict
+        sql:
+          required: false
+          type: string
+        table_name:
+          required: false
+          type: string
+        url:
+          required: false
+          type: string
+          meta:
+            description: "The url can be local, or remote. For example: `oci://<bucket>@<namespace>/data.csv`"
+        limit:
+          required: false
+          type: integer
+      type: dict
+
+    output_directory:
+      required: false
+      schema:
+        connect_args:
+          nullable: true
+          required: false
+          type: dict
+        format:
+          allowed:
+            - csv
+            - json
+            - clipboard
+            - excel
+            - feather
+            - sql_table
+            - sql_query
+            - hdf
+            - tsv
+          required: false
+          type: string
+        columns:
+          required: false
+          type: list
+          schema:
+            type: string
+        filters:
+          required: false
+          type: list
+          schema:
+            type: string
+        options:
+          nullable: true
+          required: false
+          type: dict
+        sql:
+          required: false
+          type: string
+        table_name:
+          required: false
+          type: string
+        url:
+          required: false
+          type: string
+          meta:
+            description: "The url can be local, or remote. For example: `oci://<bucket>@<namespace>/data.csv`"
+        limit:
+          required: false
+          type: integer
+      type: dict
+
+    report_filename:
+      required: false
+      type: string
+      default: report.html
+      meta:
+        description: "Placed into output_directory location. Defaults to report.html"
+    report_title:
+      required: false
+      type: string
+    report_theme:
+      required: false
+      type: string
+      default: light
+      allowed:
+        - light
+        - dark
+
+    metrics_filename:
+      required: false
+      type: string
+      default: metrics.csv
+      meta:
+        description: "Placed into output_directory location. Defaults to metrics.csv"
+
+    test_metrics_filename:
+      required: false
+      type: string
+      default: test_metrics.csv
+      meta:
+        description: "Placed into output_directory location. Defaults to test_metrics.csv"
+
+    inliers_filename:
+      required: false
+      type: string
+      default: inliers.csv
+      meta:
+        description: "Placed into output_directory location. Defaults to inliers.csv"
+
+    outliers_filename:
+      required: false
+      type: string
+      default: outliers.csv
+      meta:
+        description: "Placed into output_directory location. Defaults to outliers.csv"
+
+    target_column:
+      type: string
+      required: true
+      default: target
+      meta:
+        description: "Identifier column for the series in the dataset"
+
+    target_category_columns:
+      type: list
+      required: false
+      schema:
+        type: string
+      default: ["Series ID"]
+      meta:
+        description: "When provided, target_category_columns [list] indexes the data into multiple related datasets for anomaly detection"
+
+    preprocessing:
+      type: boolean
+      required: false
+      default: true
+      meta:
+        description: "preprocessing and feature engineering can be disabled using this flag, Defaults to true"
+
+    generate_report:
+      type: boolean
+      required: false
+      default: true
+      meta:
+        description: "Report file generation can be enabled using this flag. Defaults to true."
+
+    generate_metrics:
+      type: boolean
+      required: false
+      default: true
+      meta:
+        description: "Metrics files generation can be enabled using this flag. Defaults to true."
+
+    generate_inliers:
+      type: boolean
+      required: false
+      default: false
+      meta:
+        description: "Generates inliers.csv"
+
+    model:
+      type: string
+      required: false
+      default: automlx
+      allowed:
+        - automlx
+        - autots
+        - auto
+      meta:
+        description: "The model to be used for anomaly detection"
+
+    contamination:
+      required: false
+      default: 0.1
+      type: float
+      meta:
+        description: "Fraction of training dataset corresponding to anomalies (between 0.0 and 0.5)"
+
+    model_kwargs:
+      type: dict
+      required: false
+
+  type: dict
```
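The schema above is a Cerberus-style validation schema: `kind`/`version`/`type` pin the document type, and everything under `spec.schema` describes the fields a user's anomaly YAML may carry (input data source, datetime column, model choice, contamination, output filenames). As a rough illustration of how a user config lines up with it, the sketch below validates a minimal spec against a trimmed subset of the rules using the `cerberus` package directly; the field names come from the schema above, but the validator choice and the sample values are assumptions, not how ADS wires this up internally.

```python
# Minimal sketch (assumes `pip install cerberus pyyaml`): validate a small anomaly
# operator config against a trimmed subset of the schema shown above. ADS ships its
# own validation layer; calling cerberus.Validator directly here is an assumption.
import yaml
from cerberus import Validator

SCHEMA_SUBSET = {
    "kind": {"type": "string", "required": True, "allowed": ["operator"]},
    "type": {"type": "string", "required": True, "allowed": ["anomaly"]},
    "version": {"type": "string", "required": True, "allowed": ["v1"]},
    "spec": {
        "type": "dict",
        "required": True,
        "schema": {
            "input_data": {"type": "dict", "required": True},
            "datetime_column": {"type": "dict", "required": True},
            "target_column": {"type": "string", "required": True},
            "model": {"type": "string", "allowed": ["automlx", "autots", "auto"]},
            "contamination": {"type": "float"},
        },
    },
}

EXAMPLE_YAML = """
kind: operator
type: anomaly
version: v1
spec:
  input_data:
    url: oci://bucket@namespace/data.csv   # placeholder location
  datetime_column:
    name: Date
  target_column: target
  model: automlx
  contamination: 0.1
"""

v = Validator(SCHEMA_SUBSET)
doc = yaml.safe_load(EXAMPLE_YAML)
print("valid:", v.validate(doc))  # True
print("errors:", v.errors)        # {}
```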
ads/opctl/operator/lowcode/anomaly/utils.py (new file, +81 lines)

```diff
@@ -0,0 +1,81 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*--
+
+# Copyright (c) 2023 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+import os
+import pandas as pd
+import fsspec
+from .operator_config import AnomalyOperatorSpec
+from .const import SupportedMetrics, SupportedModels
+from ads.opctl import logger
+
+
+def _build_metrics_df(y_true, y_pred, column_name):
+    from sklearn.metrics import (
+        recall_score,
+        precision_score,
+        accuracy_score,
+        f1_score,
+        confusion_matrix,
+        roc_auc_score,
+        precision_recall_curve,
+        auc,
+        matthews_corrcoef,
+    )
+
+    metrics = dict()
+    metrics[SupportedMetrics.RECALL] = recall_score(y_true, y_pred)
+    metrics[SupportedMetrics.PRECISION] = precision_score(y_true, y_pred)
+    metrics[SupportedMetrics.ACCURACY] = accuracy_score(y_true, y_pred)
+    metrics[SupportedMetrics.F1_SCORE] = f1_score(y_true, y_pred)
+    tn, *fn_fp_tp = confusion_matrix(y_true, y_pred).ravel()
+    fp, fn, tp = fn_fp_tp if fn_fp_tp else (0, 0, 0)
+    metrics[SupportedMetrics.FP] = fp
+    metrics[SupportedMetrics.FN] = fn
+    metrics[SupportedMetrics.TP] = tp
+    metrics[SupportedMetrics.TN] = tn
+    try:
+        # Throws exception if y_true has only one class
+        metrics[SupportedMetrics.ROC_AUC] = roc_auc_score(y_true, y_pred)
+    except Exception as e:
+        logger.warn(f"An exception occurred: {e}")
+        metrics[SupportedMetrics.ROC_AUC] = None
+    precision, recall, thresholds = precision_recall_curve(y_true, y_pred)
+    metrics[SupportedMetrics.PRC_AUC] = auc(recall, precision)
+    metrics[SupportedMetrics.MCC] = matthews_corrcoef(y_true, y_pred)
+    return pd.DataFrame.from_dict(metrics, orient="index", columns=[column_name])
+
+
+def get_frequency_of_datetime(data: pd.DataFrame, dataset_info: AnomalyOperatorSpec):
+    """
+    Function finds the inferred freq from date time column
+
+    Parameters
+    ------------
+    data: pd.DataFrame
+        primary dataset
+    dataset_info: AnomalyOperatorSpec
+
+    Returns
+    --------
+    None
+
+    """
+    date_column = dataset_info.datetime_column.name
+    datetimes = pd.to_datetime(
+        data[date_column].drop_duplicates(), format=dataset_info.datetime_column.format
+    )
+    freq = pd.DatetimeIndex(datetimes).inferred_freq
+    return freq
+
+
+def default_signer(**kwargs):
+    os.environ["EXTRA_USER_AGENT_INFO"] = "Anomaly-Detection-Operator"
+    from ads.common.auth import default_signer
+
+    return default_signer(**kwargs)
+
+def select_auto_model(datasets, operator_config):
+    return SupportedModels.AutoMLX
```
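One detail worth calling out in `_build_metrics_df`: `confusion_matrix(...).ravel()` yields four values only when both classes appear in the labels; with a single class it collapses to a 1x1 matrix, which is why the star-unpacking falls back to `(0, 0, 0)` and why `roc_auc_score` is wrapped in a try/except. A standalone check of that scikit-learn behaviour (nothing here is ADS-specific):

```python
from sklearn.metrics import confusion_matrix

def unpack(y_true, y_pred):
    # Same pattern as _build_metrics_df: tn always exists, the rest may not.
    tn, *fn_fp_tp = confusion_matrix(y_true, y_pred).ravel()
    fp, fn, tp = fn_fp_tp if fn_fp_tp else (0, 0, 0)
    return tn, fp, fn, tp

# Both classes present: ravel() yields (tn, fp, fn, tp).
print(unpack([0, 0, 1, 1], [0, 1, 1, 1]))  # (1, 1, 0, 2)

# Only one class present: confusion_matrix is 1x1, so fp/fn/tp default to 0.
print(unpack([0, 0, 0], [0, 0, 0]))        # (3, 0, 0, 0)
```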
ads/opctl/operator/lowcode/common/const.py (new file, +10 lines)

```diff
@@ -0,0 +1,10 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*--
+
+# Copyright (c) 2024 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+
+class DataColumns:
+    Series = "Series"
+    Date = "Date"
```
ads/opctl/operator/lowcode/common/data.py (new file, +96 lines)

```diff
@@ -0,0 +1,96 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*--
+
+# Copyright (c) 2024 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+import time
+from .transformations import Transformations
+from ads.opctl import logger
+from ads.opctl.operator.lowcode.common.const import DataColumns
+from ads.opctl.operator.lowcode.common.utils import load_data
+from ads.opctl.operator.lowcode.common.errors import (
+    InputDataError,
+    InvalidParameterError,
+    PermissionsError,
+    DataMismatchError,
+)
+from abc import ABC
+
+
+class AbstractData(ABC):
+    def __init__(self, spec: dict, name="input_data"):
+        self.Transformations = Transformations
+        self.data = None
+        self._data_dict = dict()
+        self.name = name
+        self.load_transform_ingest_data(spec)
+
+    def get_dict_by_series(self):
+        if not self._data_dict:
+            for s_id in self.list_series_ids():
+                try:
+                    self._data_dict[s_id] = self.data.xs(
+                        s_id, level=DataColumns.Series
+                    ).reset_index()
+                except KeyError as ke:
+                    logger.debug(
+                        f"Unable to extract series: {s_id} from data: {self.data}. This may occur due to significant missing data. Error message: {ke.args}"
+                    )
+                    pass
+        return self._data_dict
+
+    def get_data_for_series(self, series_id):
+        data_dict = self.get_dict_by_series()
+        try:
+            return data_dict[series_id]
+        except:
+            raise InvalidParameterError(
+                f"Unable to retrieve series {series_id} from {self.name}. Available series ids are: {self.list_series_ids()}"
+            )
+
+    def _load_data(self, data_spec, **kwargs):
+        loading_start_time = time.time()
+        try:
+            raw_data = load_data(data_spec)
+        except InvalidParameterError as e:
+            e.args = e.args + (f"Invalid Parameter: {self.name}",)
+            raise e
+        loading_end_time = time.time()
+        logger.info(
+            f"{self.name} loaded in {loading_end_time - loading_start_time} seconds",
+        )
+        return raw_data
+
+    def _transform_data(self, spec, raw_data, **kwargs):
+        transformation_start_time = time.time()
+        self._data_transformer = self.Transformations(spec, name=self.name)
+        data = self._data_transformer.run(raw_data)
+        transformation_end_time = time.time()
+        logger.info(
+            f"{self.name} transformations completed in {transformation_end_time - transformation_start_time} seconds"
+        )
+        return data
+
+    def load_transform_ingest_data(self, spec):
+        raw_data = self._load_data(getattr(spec, self.name))
+        self.data = self._transform_data(spec, raw_data)
+        self._ingest_data(spec)
+
+    def _ingest_data(self, spec):
+        pass
+
+    def get_data_long(self):
+        return self.data.reset_index(drop=False)
+
+    def get_min_time(self):
+        return self.data.index.get_level_values(0).min()
+
+    def get_max_time(self):
+        return self.data.index.get_level_values(0).max()
+
+    def list_series_ids(self):
+        return self.data.index.get_level_values(1).unique().tolist()
+
+    def get_num_rows(self):
+        return self.data.shape[0]
```
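`AbstractData` leaves ingestion to subclasses, but its accessors (`get_min_time`, `get_max_time`, `list_series_ids`, `get_dict_by_series`) all presuppose that `self.data` ends up indexed by a two-level (datetime, series) MultiIndex whose second level is named `DataColumns.Series`. A toy pandas frame with that shape makes the contract concrete; the column names and values below are illustrative only.

```python
import pandas as pd

# Illustrative frame in the layout AbstractData's accessors expect:
# level 0 = datetime, level 1 = series id named "Series" (DataColumns.Series).
df = pd.DataFrame(
    {
        "Date": pd.to_datetime(["2024-01-01", "2024-01-02"] * 2),
        "Series": ["A", "A", "B", "B"],
        "target": [1.0, 2.0, 3.0, 4.0],
    }
).set_index(["Date", "Series"])

print(df.index.get_level_values(0).min())              # get_min_time() -> 2024-01-01
print(df.index.get_level_values(1).unique().tolist())  # list_series_ids() -> ['A', 'B']
print(df.xs("A", level="Series").reset_index())        # per-series slice, as in get_dict_by_series()
```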
ads/opctl/operator/lowcode/common/errors.py (new file, +41 lines)

```diff
@@ -0,0 +1,41 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*--
+
+# Copyright (c) 2024 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+from ads.opctl.operator import __operators__
+from ads.opctl.operator.common.errors import InvalidParameterError
+
+
+class DataMismatchError(Exception):
+    """Exception raised when there is an issue with the schema."""
+
+    def __init__(self, error: str):
+        super().__init__(
+            "Invalid operator specification. Check the YAML structure and ensure it "
+            "complies with the required schema for the operator. \n"
+            f"{error}"
+        )
+
+
+class InputDataError(Exception):
+    """Exception raised when there is an issue with the input data."""
+
+    def __init__(self, error: str):
+        super().__init__(
+            "Invalid operator specification. Check the YAML structure and ensure it "
+            "complies with the required schema for the operator. \n"
+            f"{error}"
+        )
+
+
+class PermissionsError(Exception):
+    """Exception raised when there is an issue with the input data."""
+
+    def __init__(self, error: str):
+        super().__init__(
+            "Invalid operator specification. Check the YAML structure and ensure it "
+            "complies with the required schema for the operator. \n"
+            f"{error}"
+        )
```