oracle-ads 2.10.1__py3-none-any.whl → 2.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +12 -0
- ads/aqua/base.py +324 -0
- ads/aqua/cli.py +19 -0
- ads/aqua/config/deployment_config_defaults.json +9 -0
- ads/aqua/config/resource_limit_names.json +7 -0
- ads/aqua/constants.py +45 -0
- ads/aqua/data.py +40 -0
- ads/aqua/decorator.py +101 -0
- ads/aqua/deployment.py +643 -0
- ads/aqua/dummy_data/icon.txt +1 -0
- ads/aqua/dummy_data/oci_model_deployments.json +56 -0
- ads/aqua/dummy_data/oci_models.json +1 -0
- ads/aqua/dummy_data/readme.md +26 -0
- ads/aqua/evaluation.py +1751 -0
- ads/aqua/exception.py +82 -0
- ads/aqua/extension/__init__.py +40 -0
- ads/aqua/extension/base_handler.py +138 -0
- ads/aqua/extension/common_handler.py +21 -0
- ads/aqua/extension/deployment_handler.py +202 -0
- ads/aqua/extension/evaluation_handler.py +135 -0
- ads/aqua/extension/finetune_handler.py +66 -0
- ads/aqua/extension/model_handler.py +59 -0
- ads/aqua/extension/ui_handler.py +201 -0
- ads/aqua/extension/utils.py +23 -0
- ads/aqua/finetune.py +579 -0
- ads/aqua/job.py +29 -0
- ads/aqua/model.py +819 -0
- ads/aqua/training/__init__.py +4 -0
- ads/aqua/training/exceptions.py +459 -0
- ads/aqua/ui.py +453 -0
- ads/aqua/utils.py +715 -0
- ads/cli.py +37 -6
- ads/common/decorator/__init__.py +7 -3
- ads/common/decorator/require_nonempty_arg.py +65 -0
- ads/common/object_storage_details.py +166 -7
- ads/common/oci_client.py +18 -1
- ads/common/oci_logging.py +2 -2
- ads/common/oci_mixin.py +4 -5
- ads/common/serializer.py +34 -5
- ads/common/utils.py +75 -10
- ads/config.py +40 -1
- ads/jobs/ads_job.py +43 -25
- ads/jobs/builders/infrastructure/base.py +4 -2
- ads/jobs/builders/infrastructure/dsc_job.py +49 -39
- ads/jobs/builders/runtimes/base.py +71 -1
- ads/jobs/builders/runtimes/container_runtime.py +4 -4
- ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
- ads/jobs/templates/driver_pytorch.py +27 -10
- ads/model/artifact_downloader.py +84 -14
- ads/model/artifact_uploader.py +25 -23
- ads/model/datascience_model.py +388 -38
- ads/model/deployment/model_deployment.py +10 -2
- ads/model/generic_model.py +8 -0
- ads/model/model_file_description_schema.json +68 -0
- ads/model/model_metadata.py +1 -1
- ads/model/service/oci_datascience_model.py +34 -5
- ads/opctl/operator/lowcode/anomaly/README.md +2 -1
- ads/opctl/operator/lowcode/anomaly/__main__.py +10 -4
- ads/opctl/operator/lowcode/anomaly/environment.yaml +2 -1
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +12 -6
- ads/opctl/operator/lowcode/forecast/README.md +3 -2
- ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
- ads/opctl/operator/lowcode/forecast/model/automlx.py +12 -23
- ads/telemetry/base.py +62 -0
- ads/telemetry/client.py +105 -0
- ads/telemetry/telemetry.py +6 -3
- {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.1.dist-info}/METADATA +37 -7
- {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.1.dist-info}/RECORD +71 -36
- {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.1.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.1.dist-info}/WHEEL +0 -0
- {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.1.dist-info}/entry_points.txt +0 -0
ads/aqua/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
|
7
|
+
import logging
|
8
|
+
import sys
|
9
|
+
|
10
|
+
# Package-level logger shared by the ads.aqua modules.
logger = logging.getLogger(__name__)
# Handler that writes Aqua log records to standard output.
handler = logging.StreamHandler(sys.stdout)
# The handler was previously created but never attached, so it had no effect.
# Attach it only when the logger has no handlers yet, so re-importing the module
# does not stack duplicate handlers.
if not logger.handlers:
    logger.addHandler(handler)
logger.setLevel(logging.INFO)
|
ads/aqua/base.py
ADDED
@@ -0,0 +1,324 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
import os
|
7
|
+
from typing import Dict, Union
|
8
|
+
|
9
|
+
import oci
|
10
|
+
from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails
|
11
|
+
|
12
|
+
from ads import set_auth
|
13
|
+
from ads.aqua import logger
|
14
|
+
from ads.aqua.data import Tags
|
15
|
+
from ads.aqua.exception import AquaRuntimeError, AquaValueError
|
16
|
+
from ads.aqua.utils import (
|
17
|
+
UNKNOWN,
|
18
|
+
_is_valid_mvs,
|
19
|
+
get_artifact_path,
|
20
|
+
get_base_model_from_tags,
|
21
|
+
is_valid_ocid,
|
22
|
+
load_config,
|
23
|
+
logger,
|
24
|
+
)
|
25
|
+
from ads.common import oci_client as oc
|
26
|
+
from ads.common.auth import default_signer
|
27
|
+
from ads.common.utils import extract_region
|
28
|
+
from ads.config import (
|
29
|
+
AQUA_TELEMETRY_BUCKET,
|
30
|
+
AQUA_TELEMETRY_BUCKET_NS,
|
31
|
+
OCI_ODSC_SERVICE_ENDPOINT,
|
32
|
+
OCI_RESOURCE_PRINCIPAL_VERSION,
|
33
|
+
)
|
34
|
+
from ads.model.datascience_model import DataScienceModel
|
35
|
+
from ads.model.deployment.model_deployment import ModelDeployment
|
36
|
+
from ads.model.model_metadata import (
|
37
|
+
ModelCustomMetadata,
|
38
|
+
ModelProvenanceMetadata,
|
39
|
+
ModelTaxonomyMetadata,
|
40
|
+
)
|
41
|
+
from ads.model.model_version_set import ModelVersionSet
|
42
|
+
from ads.telemetry import telemetry
|
43
|
+
from ads.telemetry.client import TelemetryClient
|
44
|
+
|
45
|
+
|
46
|
+
class AquaApp:
    """Base Aqua App to contain common components.

    Builds the OCI clients (Data Science, Logging Management, Identity) shared by
    the concrete Aqua apps and provides helpers for listing resources, updating
    models, resolving model version sets, creating catalog models and reading
    model config files.
    """

    @telemetry(name="aqua")
    def __init__(self) -> None:
        # Under a resource principal, switch ADS auth before any signer is built.
        if OCI_RESOURCE_PRINCIPAL_VERSION:
            set_auth("resource_principal")
        # Only the Data Science client is pointed at OCI_ODSC_SERVICE_ENDPOINT.
        self._auth = default_signer({"service_endpoint": OCI_ODSC_SERVICE_ENDPOINT})
        self.ds_client = oc.OCIClientFactory(**self._auth).data_science
        self.logging_client = oc.OCIClientFactory(**default_signer()).logging_management
        self.identity_client = oc.OCIClientFactory(**default_signer()).identity
        self.region = extract_region(self._auth)
        # Created lazily by the `telemetry` property.
        self._telemetry = None

    def list_resource(
        self,
        list_func_ref,
        **kwargs,
    ) -> list:
        """Generic method to list OCI Data Science resources.

        Parameters
        ----------
        list_func_ref : function
            A reference to the list operation which will be called.
        **kwargs :
            Additional keyword arguments to filter the resource.
            The kwargs are passed into OCI API.

        Returns
        -------
        list
            A list of OCI Data Science resources, collected across all pages.
        """
        return oci.pagination.list_call_get_all_results(
            list_func_ref,
            **kwargs,
        ).data

    def update_model(self, model_id: str, update_model_details: UpdateModelDetails):
        """Updates model details.

        Parameters
        ----------
        model_id : str
            The id of target model.
        update_model_details: UpdateModelDetails
            The model details to be updated.
        """
        self.ds_client.update_model(
            model_id=model_id, update_model_details=update_model_details
        )

    def update_model_provenance(
        self,
        model_id: str,
        update_model_provenance_details: UpdateModelProvenanceDetails,
    ):
        """Updates model provenance details.

        Parameters
        ----------
        model_id : str
            The id of target model.
        update_model_provenance_details: UpdateModelProvenanceDetails
            The model provenance details to be updated.
        """
        self.ds_client.update_model_provenance(
            model_id=model_id,
            update_model_provenance_details=update_model_provenance_details,
        )

    # TODO: refactor model evaluation implementation to use it.
    @staticmethod
    def get_source(source_id: str) -> Union[ModelDeployment, DataScienceModel]:
        """Loads the model deployment or model identified by ``source_id``.

        Parameters
        ----------
        source_id: str
            OCID of a model deployment or of a model.

        Returns
        -------
        Union[ModelDeployment, DataScienceModel]
            The loaded model deployment or model.

        Raises
        ------
        AquaValueError
            If ``source_id`` is not a valid model or model deployment OCID.
        """
        if is_valid_ocid(source_id):
            # Deployment must be checked first: "datasciencemodeldeployment"
            # contains "datasciencemodel" as a substring.
            if "datasciencemodeldeployment" in source_id:
                return ModelDeployment.from_id(source_id)
            elif "datasciencemodel" in source_id:
                return DataScienceModel.from_id(source_id)

        raise AquaValueError(
            f"Invalid source {source_id}. "
            "Specify either a model or model deployment id."
        )

    # TODO: refactor model evaluation implementation to use it.
    @staticmethod
    def create_model_version_set(
        model_version_set_id: str = None,
        model_version_set_name: str = None,
        description: str = None,
        compartment_id: str = None,
        project_id: str = None,
        **kwargs,
    ) -> tuple:
        """Resolves a ModelVersionSet by ID or name, creating it by name if needed.

        Parameters
        ----------
        model_version_set_id: (str, optional):
            ModelVersionSet OCID. When given, the set is loaded by ID and must
            already carry the fine-tuning tag.
        model_version_set_name: (str, optional):
            ModelVersionSet Name. Used when no ID is given: an existing set with
            this name is reused, otherwise a new one is created.
        description: (str, optional):
            Description for a newly created ModelVersionSet.
        compartment_id: (str, optional):
            Compartment OCID.
        project_id: (str, optional):
            Project OCID.
        **kwargs:
            Passed to ``ModelVersionSet.create`` when a new set is created.

        Returns
        -------
        tuple: (model_version_set_id, model_version_set_name)

        Raises
        ------
        AquaValueError
            If the resolved (pre-existing) model version set does not carry the
            expected tag.
        """
        # TODO: tag should be selected based on which operation (eval/FT) invokes
        # this method; currently only used by the fine-tuning flow.
        tag = Tags.AQUA_FINE_TUNING.value

        if model_version_set_id:
            model_version_set = ModelVersionSet.from_id(model_version_set_id)
            if not _is_valid_mvs(model_version_set, tag):
                raise AquaValueError(
                    f"Invalid model version set id. Please provide a model version set with `{tag}` in tags."
                )
            return (model_version_set_id, model_version_set.name)

        try:
            model_version_set = ModelVersionSet.from_name(
                name=model_version_set_name,
                compartment_id=compartment_id,
            )
        except Exception:
            # The lookup failed: treat the set as missing and create it.
            logger.debug(
                f"Model version set {model_version_set_name} doesn't exist. "
                "Creating new model version set."
            )
            mvs_freeform_tags = {
                tag: tag,
            }
            model_version_set = (
                ModelVersionSet()
                .with_compartment_id(compartment_id)
                .with_project_id(project_id)
                .with_name(model_version_set_name)
                .with_description(description)
                .with_freeform_tags(**mvs_freeform_tags)
                # TODO: decide what parameters will be needed
                # when refactor eval to use this method, we need to pass tag here.
                .create(**kwargs)
            )
            logger.debug(
                f"Successfully created model version set {model_version_set_name} with id {model_version_set.id}."
            )
            return (model_version_set.id, model_version_set_name)

        # Validation is intentionally outside the try: previously a bare `except`
        # swallowed this error and attempted to create a duplicate set.
        if not _is_valid_mvs(model_version_set, tag):
            raise AquaValueError(
                f"Invalid model version set name. Please provide a model version set with `{tag}` in tags."
            )
        return (model_version_set.id, model_version_set_name)

    # TODO: refactor model evaluation implementation to use it.
    @staticmethod
    def create_model_catalog(
        display_name: str,
        description: str,
        model_version_set_id: str,
        model_custom_metadata: Union[ModelCustomMetadata, Dict],
        model_taxonomy_metadata: Union[ModelTaxonomyMetadata, Dict],
        compartment_id: str,
        project_id: str,
        **kwargs,
    ) -> DataScienceModel:
        """Creates a DataScienceModel in the model catalog.

        The provenance training id is set to ``UNKNOWN``.

        Parameters
        ----------
        display_name: str
            Display name of the model.
        description: str
            Description of the model.
        model_version_set_id: str
            OCID of the model version set the model belongs to.
        model_custom_metadata: Union[ModelCustomMetadata, Dict]
            Custom metadata list for the model.
        model_taxonomy_metadata: Union[ModelTaxonomyMetadata, Dict]
            Defined (taxonomy) metadata list for the model.
        compartment_id: str
            Compartment OCID.
        project_id: str
            Project OCID.
        **kwargs:
            Passed to ``DataScienceModel.create``.

        Returns
        -------
        DataScienceModel
            The created model.
        """
        model = (
            DataScienceModel()
            .with_compartment_id(compartment_id)
            .with_project_id(project_id)
            .with_display_name(display_name)
            .with_description(description)
            .with_model_version_set_id(model_version_set_id)
            .with_custom_metadata_list(model_custom_metadata)
            .with_defined_metadata_list(model_taxonomy_metadata)
            .with_provenance_metadata(ModelProvenanceMetadata(training_id=UNKNOWN))
            # TODO: decide what parameters will be needed
            .create(
                **kwargs,
            )
        )
        return model

    def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
        """Checks if the artifact exists.

        Parameters
        ----------
        model_id : str
            The model OCID.
        **kwargs :
            Additional keyword arguments passed in head_model_artifact.

        Returns
        -------
        bool
            True only when the HEAD request returns status 200. Any
            ``ServiceError`` (404 or otherwise) yields False.
        """
        try:
            response = self.ds_client.head_model_artifact(model_id=model_id, **kwargs)
        except oci.exceptions.ServiceError as ex:
            if ex.status == 404:
                logger.info(f"Artifact not found in model {model_id}.")
            else:
                # Previously this path fell through and returned None; report it
                # and keep the documented bool return type.
                logger.warning(
                    f"Unable to check the artifact of model {model_id}: {ex}"
                )
            return False
        return response.status == 200

    def get_config(self, model_id: str, config_file_name: str) -> Dict:
        """Gets the config for the given Aqua model.

        Parameters
        ----------
        model_id: str
            The OCID of the Aqua model.
        config_file_name: str
            name of the config file

        Returns
        -------
        Dict:
            A dict of allowed configs. Empty when the artifact path is missing or
            the config file cannot be loaded (both cases are logged).

        Raises
        ------
        AquaRuntimeError
            If the model is not tagged as an Aqua model.
        """
        oci_model = self.ds_client.get_model(model_id).data
        freeform_tags = oci_model.freeform_tags or {}
        # Accept the Aqua tag in either its canonical or lower-case spelling.
        oci_aqua = (
            Tags.AQUA_TAG.value in freeform_tags
            or Tags.AQUA_TAG.value.lower() in freeform_tags
        )
        if not oci_aqua:
            raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.")

        config = {}
        artifact_path = get_artifact_path(oci_model.custom_metadata_list)
        if not artifact_path:
            logger.error(
                f"Failed to get artifact path from custom metadata for the model: {model_id}"
            )
            return config

        config_path = f"{os.path.dirname(artifact_path)}/config/"
        try:
            config = load_config(
                config_path,
                config_file_name=config_file_name,
            )
        except Exception as ex:
            # Best effort: a missing/unreadable config is reported below rather
            # than raised, but the cause is no longer silently discarded.
            logger.debug(
                f"Error loading {config_file_name} from {config_path}: {ex}"
            )

        if not config:
            logger.error(
                f"{config_file_name} is not available for the model: {model_id}. Check if the custom metadata has the artifact path set."
            )
        return config

    # NOTE: keep this property after `__init__`: within the class body it shadows
    # the module-level `telemetry` decorator used on `__init__`.
    @property
    def telemetry(self):
        """Lazily created TelemetryClient for the Aqua telemetry bucket."""
        if not self._telemetry:
            self._telemetry = TelemetryClient(
                bucket=AQUA_TELEMETRY_BUCKET, namespace=AQUA_TELEMETRY_BUCKET_NS
            )
        return self._telemetry
|
ads/aqua/cli.py
ADDED
@@ -0,0 +1,19 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
|
4
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
5
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
|
+
|
7
|
+
from ads.aqua.deployment import AquaDeploymentApp
|
8
|
+
from ads.aqua.finetune import AquaFineTuningApp
|
9
|
+
from ads.aqua.model import AquaModelApp
|
10
|
+
from ads.aqua.evaluation import AquaEvaluationApp
|
11
|
+
|
12
|
+
|
13
|
+
class AquaCommand:
    """Contains the command groups for project Aqua."""

    # Each class attribute names a command group and points at the app class
    # (not an instance) that implements it.
    model = AquaModelApp  # model commands
    fine_tuning = AquaFineTuningApp  # fine-tuning commands
    deployment = AquaDeploymentApp  # deployment commands
    evaluation = AquaEvaluationApp  # evaluation commands
|
ads/aqua/constants.py
ADDED
@@ -0,0 +1,45 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
"""This module defines constants used in ads.aqua module."""
|
6
|
+
from enum import Enum
|
7
|
+
|
8
|
+
# Placeholder used when a value is not known: an empty string.
UNKNOWN_VALUE: str = ""
|
9
|
+
|
10
|
+
|
11
|
+
class RqsAdditionalDetails:
    """camelCase key names grouped as plain class-level string constants.

    NOTE(review): presumably the keys of an ``additionalDetails`` mapping in
    resource-query results — confirm against callers.
    """

    # Generic fields
    METADATA = "metadata"
    CREATED_BY = "createdBy"
    DESCRIPTION = "description"
    # Model version set fields
    MODEL_VERSION_SET_ID = "modelVersionSetId"
    MODEL_VERSION_SET_NAME = "modelVersionSetName"
    # Ownership / versioning fields
    PROJECT_ID = "projectId"
    VERSION_LABEL = "versionLabel"
|
19
|
+
|
20
|
+
|
21
|
+
class FineTuningDefinedMetadata(Enum):
    """Defined-metadata keys written for fine-tuning.

    Members are iterated in declaration order: VAL_SET_SIZE, TRAINING_DATA.
    """

    VAL_SET_SIZE = "val_set_size"  # validation-set size key
    TRAINING_DATA = "training_data"  # training-data key
|
26
|
+
|
27
|
+
|
28
|
+
class FineTuningCustomMetadata(Enum):
    """Custom-metadata keys written for fine-tuning.

    Grouped below into source/output/job keys and metric keys; declaration
    order (and therefore iteration order) is unchanged.
    """

    # Where the fine-tuning came from and where it writes output
    FT_SOURCE = "fine_tune_source"
    FT_SOURCE_NAME = "fine_tune_source_name"
    FT_OUTPUT_PATH = "fine_tune_output_path"
    # The job / job run that performed the fine-tuning
    FT_JOB_ID = "fine_tune_job_id"
    FT_JOB_RUN_ID = "fine_tune_jobrun_id"
    # Final and per-epoch metric keys
    TRAINING_METRICS_FINAL = "train_metrics_final"
    VALIDATION_METRICS_FINAL = "val_metrics_final"
    TRAINING_METRICS_EPOCH = "train_metrics_epoch"
    VALIDATION_METRICS_EPOCH = "val_metrics_epoch"
|
40
|
+
|
41
|
+
|
42
|
+
# Module-level metric key names (distinct from the FineTuningCustomMetadata
# enum members, whose values use the "train_"/"val_" prefixes).
TRAINING_METRICS_FINAL = "training_metrics_final"
VALIDATION_METRICS_FINAL = "validation_metrics_final"
TRAINING_METRICS = "training_metrics"
# Backward-compatible alias for the original misspelled name; prefer TRAINING_METRICS.
TRINING_METRICS = TRAINING_METRICS
VALIDATION_METRICS = "validation_metrics"
|
ads/aqua/data.py
ADDED
@@ -0,0 +1,40 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
from dataclasses import dataclass
|
7
|
+
from enum import Enum
|
8
|
+
|
9
|
+
from ads.common.serializer import DataClassSerializable
|
10
|
+
|
11
|
+
|
12
|
+
@dataclass(repr=False)
class AquaResourceIdentifier(DataClassSerializable):
    """Lightweight reference to a resource: its id, name and url.

    Every field defaults to an empty string. ``repr=False`` means dataclass does
    not generate ``__repr__``; it is inherited instead.
    """

    # Resource identifier (presumably an OCID — confirm with callers).
    id: str = ""
    # Human-readable name of the resource.
    name: str = ""
    # Link to the resource.
    url: str = ""
|
17
|
+
|
18
|
+
|
19
|
+
class Resource(Enum):
    """Plural resource names for Data Science resources.

    NOTE(review): the values look like URL path segments (e.g.
    "model-version-sets") — presumably used when building resource links.
    """

    JOB = "jobs"
    JOBRUN = "jobruns"
    MODEL = "models"
    MODEL_DEPLOYMENT = "modeldeployments"
    MODEL_VERSION_SET = "model-version-sets"
|
25
|
+
|
26
|
+
|
27
|
+
class DataScienceResource(Enum):
    """Resource-type substrings for Data Science OCIDs.

    Note that the MODEL_DEPLOYMENT value contains the MODEL value as a
    substring, so substring checks must test MODEL_DEPLOYMENT first.
    """

    MODEL_DEPLOYMENT = "datasciencemodeldeployment"
    MODEL = "datasciencemodel"
|
30
|
+
|
31
|
+
|
32
|
+
class Tags(Enum):
    """Tag names used by Aqua.

    ``AQUA_TAG`` is looked up in a model's freeform tags to decide whether it is
    an Aqua model; the remaining members mark model attributes and the kind of
    Aqua resource.
    """

    # Model attribute tags
    TASK = "task"
    LICENSE = "license"
    ORGANIZATION = "organization"
    # Aqua markers
    AQUA_TAG = "OCI_AQUA"
    AQUA_SERVICE_MODEL_TAG = "aqua_service_model"
    AQUA_FINE_TUNED_MODEL_TAG = "aqua_fine_tuned_model"
    # Operation-specific tags
    AQUA_EVALUATION = "aqua_evaluation"
    AQUA_FINE_TUNING = "aqua_finetuning"
|
ads/aqua/decorator.py
ADDED
@@ -0,0 +1,101 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
"""Decorator module."""
|
7
|
+
|
8
|
+
import sys
|
9
|
+
from functools import wraps
|
10
|
+
|
11
|
+
from oci.exceptions import (
|
12
|
+
ClientError,
|
13
|
+
CompositeOperationError,
|
14
|
+
ConnectTimeout,
|
15
|
+
MissingEndpointForNonRegionalServiceClientError,
|
16
|
+
MultipartUploadError,
|
17
|
+
RequestException,
|
18
|
+
ServiceError,
|
19
|
+
)
|
20
|
+
|
21
|
+
from ads.aqua.exception import AquaError
|
22
|
+
from ads.aqua.extension.base_handler import AquaAPIhandler
|
23
|
+
|
24
|
+
|
25
|
+
def handle_exceptions(func):
    """Writes errors raised during call to JSON.

    This decorator is designed to be used with methods in handler.py that
    interact with external services or perform operations which might
    fail. This decorator should be applied only to instance methods of
    classes within handler.py, as it is tailored to handle exceptions
    specific to the operations performed by these handlers.

    Status codes written via ``self.write_error``:

    - ``ServiceError``: the error's own status (500 if missing).
    - ``ConnectTimeout``: 408.
    - ``ClientError`` / ``MissingEndpointForNonRegionalServiceClientError`` /
      ``RequestException``: 400.
    - ``MultipartUploadError`` / ``CompositeOperationError``: 500.
    - ``AquaError``: the error's ``status``.
    - any other ``Exception``: 500.

    Parameters
    ----------
    func (Callable): The function to be wrapped by the decorator.

    Returns
    -------
    Callable: A wrapper function that catches exceptions thrown by `func`.
        When an exception is handled, the wrapper returns None.

    Examples
    --------

    >>> from ads.aqua.decorator import handle_exceptions

    >>> @handle_exceptions
    >>> def some_method(self, arg1, arg2):
    ...     # Method implementation...
    ...     pass

    """

    @wraps(func)
    def inner_function(self: AquaAPIhandler, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except ServiceError as error:
            self.write_error(
                status_code=error.status or 500,
                reason=error.message,
                service_payload=error.args[0] if error.args else None,
                exc_info=sys.exc_info(),
            )
        except ConnectTimeout as error:
            # Checked before RequestException: in `requests`, ConnectTimeout
            # derives from RequestException, so listing it afterwards could make
            # this 408 branch unreachable (timeouts would be reported as 400).
            self.write_error(
                status_code=408,
                reason=f"{type(error).__name__}: {str(error)}",
                exc_info=sys.exc_info(),
            )
        except (
            ClientError,
            MissingEndpointForNonRegionalServiceClientError,
            RequestException,
        ) as error:
            self.write_error(
                status_code=400,
                reason=f"{type(error).__name__}: {str(error)}",
                exc_info=sys.exc_info(),
            )
        except (MultipartUploadError, CompositeOperationError) as error:
            self.write_error(
                status_code=500,
                reason=f"{type(error).__name__}: {str(error)}",
                exc_info=sys.exc_info(),
            )
        except AquaError as error:
            self.write_error(
                status_code=error.status,
                reason=error.reason,
                service_payload=error.service_payload,
                exc_info=sys.exc_info(),
            )
        except Exception as ex:
            self.write_error(
                status_code=500,
                reason=f"{type(ex).__name__}: {str(ex)}",
                exc_info=sys.exc_info(),
            )

    return inner_function
|