oracle-ads 2.10.0__py3-none-any.whl → 2.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117) hide show
  1. ads/aqua/__init__.py +12 -0
  2. ads/aqua/base.py +324 -0
  3. ads/aqua/cli.py +19 -0
  4. ads/aqua/config/deployment_config_defaults.json +9 -0
  5. ads/aqua/config/resource_limit_names.json +7 -0
  6. ads/aqua/constants.py +45 -0
  7. ads/aqua/data.py +40 -0
  8. ads/aqua/decorator.py +101 -0
  9. ads/aqua/deployment.py +643 -0
  10. ads/aqua/dummy_data/icon.txt +1 -0
  11. ads/aqua/dummy_data/oci_model_deployments.json +56 -0
  12. ads/aqua/dummy_data/oci_models.json +1 -0
  13. ads/aqua/dummy_data/readme.md +26 -0
  14. ads/aqua/evaluation.py +1751 -0
  15. ads/aqua/exception.py +82 -0
  16. ads/aqua/extension/__init__.py +40 -0
  17. ads/aqua/extension/base_handler.py +138 -0
  18. ads/aqua/extension/common_handler.py +21 -0
  19. ads/aqua/extension/deployment_handler.py +202 -0
  20. ads/aqua/extension/evaluation_handler.py +135 -0
  21. ads/aqua/extension/finetune_handler.py +66 -0
  22. ads/aqua/extension/model_handler.py +59 -0
  23. ads/aqua/extension/ui_handler.py +201 -0
  24. ads/aqua/extension/utils.py +23 -0
  25. ads/aqua/finetune.py +579 -0
  26. ads/aqua/job.py +29 -0
  27. ads/aqua/model.py +819 -0
  28. ads/aqua/training/__init__.py +4 -0
  29. ads/aqua/training/exceptions.py +459 -0
  30. ads/aqua/ui.py +453 -0
  31. ads/aqua/utils.py +715 -0
  32. ads/cli.py +37 -6
  33. ads/common/auth.py +7 -0
  34. ads/common/decorator/__init__.py +7 -3
  35. ads/common/decorator/require_nonempty_arg.py +65 -0
  36. ads/common/object_storage_details.py +166 -7
  37. ads/common/oci_client.py +18 -1
  38. ads/common/oci_logging.py +2 -2
  39. ads/common/oci_mixin.py +4 -5
  40. ads/common/serializer.py +34 -5
  41. ads/common/utils.py +75 -10
  42. ads/config.py +40 -1
  43. ads/dataset/correlation_plot.py +10 -12
  44. ads/jobs/ads_job.py +43 -25
  45. ads/jobs/builders/infrastructure/base.py +4 -2
  46. ads/jobs/builders/infrastructure/dsc_job.py +49 -39
  47. ads/jobs/builders/runtimes/base.py +71 -1
  48. ads/jobs/builders/runtimes/container_runtime.py +4 -4
  49. ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
  50. ads/jobs/templates/driver_pytorch.py +27 -10
  51. ads/model/artifact_downloader.py +84 -14
  52. ads/model/artifact_uploader.py +25 -23
  53. ads/model/datascience_model.py +388 -38
  54. ads/model/deployment/model_deployment.py +10 -2
  55. ads/model/generic_model.py +8 -0
  56. ads/model/model_file_description_schema.json +68 -0
  57. ads/model/model_metadata.py +1 -1
  58. ads/model/service/oci_datascience_model.py +34 -5
  59. ads/opctl/config/merger.py +2 -2
  60. ads/opctl/operator/__init__.py +3 -1
  61. ads/opctl/operator/cli.py +7 -1
  62. ads/opctl/operator/cmd.py +3 -3
  63. ads/opctl/operator/common/errors.py +2 -1
  64. ads/opctl/operator/common/operator_config.py +22 -3
  65. ads/opctl/operator/common/utils.py +16 -0
  66. ads/opctl/operator/lowcode/anomaly/MLoperator +15 -0
  67. ads/opctl/operator/lowcode/anomaly/README.md +209 -0
  68. ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
  69. ads/opctl/operator/lowcode/anomaly/__main__.py +104 -0
  70. ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
  71. ads/opctl/operator/lowcode/anomaly/const.py +88 -0
  72. ads/opctl/operator/lowcode/anomaly/environment.yaml +12 -0
  73. ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
  74. ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +147 -0
  75. ads/opctl/operator/lowcode/anomaly/model/automlx.py +89 -0
  76. ads/opctl/operator/lowcode/anomaly/model/autots.py +103 -0
  77. ads/opctl/operator/lowcode/anomaly/model/base_model.py +354 -0
  78. ads/opctl/operator/lowcode/anomaly/model/factory.py +67 -0
  79. ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
  80. ads/opctl/operator/lowcode/anomaly/operator_config.py +105 -0
  81. ads/opctl/operator/lowcode/anomaly/schema.yaml +359 -0
  82. ads/opctl/operator/lowcode/anomaly/utils.py +81 -0
  83. ads/opctl/operator/lowcode/common/__init__.py +5 -0
  84. ads/opctl/operator/lowcode/common/const.py +10 -0
  85. ads/opctl/operator/lowcode/common/data.py +96 -0
  86. ads/opctl/operator/lowcode/common/errors.py +41 -0
  87. ads/opctl/operator/lowcode/common/transformations.py +191 -0
  88. ads/opctl/operator/lowcode/common/utils.py +250 -0
  89. ads/opctl/operator/lowcode/forecast/README.md +3 -2
  90. ads/opctl/operator/lowcode/forecast/__main__.py +18 -2
  91. ads/opctl/operator/lowcode/forecast/cmd.py +8 -7
  92. ads/opctl/operator/lowcode/forecast/const.py +17 -1
  93. ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
  94. ads/opctl/operator/lowcode/forecast/model/arima.py +106 -117
  95. ads/opctl/operator/lowcode/forecast/model/automlx.py +204 -180
  96. ads/opctl/operator/lowcode/forecast/model/autots.py +144 -253
  97. ads/opctl/operator/lowcode/forecast/model/base_model.py +326 -259
  98. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +325 -176
  99. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +293 -237
  100. ads/opctl/operator/lowcode/forecast/model/prophet.py +191 -208
  101. ads/opctl/operator/lowcode/forecast/operator_config.py +24 -33
  102. ads/opctl/operator/lowcode/forecast/schema.yaml +116 -29
  103. ads/opctl/operator/lowcode/forecast/utils.py +186 -356
  104. ads/opctl/operator/lowcode/pii/model/guardrails.py +18 -15
  105. ads/opctl/operator/lowcode/pii/model/report.py +7 -7
  106. ads/opctl/operator/lowcode/pii/operator_config.py +1 -8
  107. ads/opctl/operator/lowcode/pii/utils.py +0 -82
  108. ads/opctl/operator/runtime/runtime.py +3 -2
  109. ads/telemetry/base.py +62 -0
  110. ads/telemetry/client.py +105 -0
  111. ads/telemetry/telemetry.py +6 -3
  112. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/METADATA +44 -7
  113. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/RECORD +116 -59
  114. ads/opctl/operator/lowcode/forecast/model/transformations.py +0 -125
  115. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/LICENSE.txt +0 -0
  116. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/WHEEL +0 -0
  117. {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/entry_points.txt +0 -0
ads/aqua/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+
6
+
7
+ import logging
8
+ import sys
9
+
10
# Logger shared by the ``ads.aqua`` package modules.
logger = logging.getLogger(__name__)
# Handler that writes Aqua log records to stdout. It was previously created
# but never attached to ``logger``, so it had no effect. Attach it only when
# the logger has no handlers yet, so re-importing/reloading does not stack
# duplicate handlers. Records also still propagate to ancestor loggers.
handler = logging.StreamHandler(sys.stdout)
if not logger.handlers:
    logger.addHandler(handler)
logger.setLevel(logging.INFO)
ads/aqua/base.py ADDED
@@ -0,0 +1,324 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+
6
+ import os
7
+ from typing import Dict, Union
8
+
9
+ import oci
10
+ from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails
11
+
12
+ from ads import set_auth
13
+ from ads.aqua import logger
14
+ from ads.aqua.data import Tags
15
+ from ads.aqua.exception import AquaRuntimeError, AquaValueError
16
+ from ads.aqua.utils import (
17
+ UNKNOWN,
18
+ _is_valid_mvs,
19
+ get_artifact_path,
20
+ get_base_model_from_tags,
21
+ is_valid_ocid,
22
+ load_config,
23
+ logger,
24
+ )
25
+ from ads.common import oci_client as oc
26
+ from ads.common.auth import default_signer
27
+ from ads.common.utils import extract_region
28
+ from ads.config import (
29
+ AQUA_TELEMETRY_BUCKET,
30
+ AQUA_TELEMETRY_BUCKET_NS,
31
+ OCI_ODSC_SERVICE_ENDPOINT,
32
+ OCI_RESOURCE_PRINCIPAL_VERSION,
33
+ )
34
+ from ads.model.datascience_model import DataScienceModel
35
+ from ads.model.deployment.model_deployment import ModelDeployment
36
+ from ads.model.model_metadata import (
37
+ ModelCustomMetadata,
38
+ ModelProvenanceMetadata,
39
+ ModelTaxonomyMetadata,
40
+ )
41
+ from ads.model.model_version_set import ModelVersionSet
42
+ from ads.telemetry import telemetry
43
+ from ads.telemetry.client import TelemetryClient
44
+
45
+
46
class AquaApp:
    """Base Aqua App to contain common components.

    Creates the OCI Data Science, Logging Management and Identity clients and
    provides shared helpers (listing resources, updating models, managing
    model version sets, reading model configs) used by Aqua applications.
    """

    @telemetry(name="aqua")
    def __init__(self) -> None:
        # Switch the global ADS auth to resource principal when the runtime
        # advertises a resource principal version.
        if OCI_RESOURCE_PRINCIPAL_VERSION:
            set_auth("resource_principal")
        self._auth = default_signer({"service_endpoint": OCI_ODSC_SERVICE_ENDPOINT})
        self.ds_client = oc.OCIClientFactory(**self._auth).data_science
        self.logging_client = oc.OCIClientFactory(**default_signer()).logging_management
        self.identity_client = oc.OCIClientFactory(**default_signer()).identity
        self.region = extract_region(self._auth)
        # Lazily created by the ``telemetry`` property.
        self._telemetry = None

    def list_resource(
        self,
        list_func_ref,
        **kwargs,
    ) -> list:
        """Generic method to list OCI Data Science resources.

        Parameters
        ----------
        list_func_ref : function
            A reference to the list operation which will be called.
        **kwargs :
            Additional keyword arguments to filter the resource.
            The kwargs are passed into OCI API.

        Returns
        -------
        list
            A list of OCI Data Science resources (all pages).
        """
        return oci.pagination.list_call_get_all_results(
            list_func_ref,
            **kwargs,
        ).data

    def update_model(self, model_id: str, update_model_details: UpdateModelDetails):
        """Updates model details.

        Parameters
        ----------
        model_id : str
            The id of target model.
        update_model_details: UpdateModelDetails
            The model details to be updated.
        """
        self.ds_client.update_model(
            model_id=model_id, update_model_details=update_model_details
        )

    def update_model_provenance(
        self,
        model_id: str,
        update_model_provenance_details: UpdateModelProvenanceDetails,
    ):
        """Updates model provenance details.

        Parameters
        ----------
        model_id : str
            The id of target model.
        update_model_provenance_details: UpdateModelProvenanceDetails
            The model provenance details to be updated.
        """
        self.ds_client.update_model_provenance(
            model_id=model_id,
            update_model_provenance_details=update_model_provenance_details,
        )

    # TODO: refactor model evaluation implementation to use it.
    @staticmethod
    def get_source(source_id: str) -> Union[ModelDeployment, DataScienceModel]:
        """Loads the model deployment or model identified by ``source_id``.

        Parameters
        ----------
        source_id: str
            OCID of a model deployment or of a model.

        Returns
        -------
        Union[ModelDeployment, DataScienceModel]
            The loaded model deployment or model.

        Raises
        ------
        AquaValueError
            If ``source_id`` is not a valid model or model deployment OCID.
        """
        if is_valid_ocid(source_id):
            # Check deployments first: "datasciencemodel" is a substring of
            # "datasciencemodeldeployment", so the reverse order would treat
            # deployment OCIDs as model OCIDs.
            if "datasciencemodeldeployment" in source_id:
                return ModelDeployment.from_id(source_id)
            elif "datasciencemodel" in source_id:
                return DataScienceModel.from_id(source_id)

        raise AquaValueError(
            f"Invalid source {source_id}. "
            "Specify either a model or model deployment id."
        )

    # TODO: refactor model evaluation implementation to use it.
    @staticmethod
    def create_model_version_set(
        model_version_set_id: str = None,
        model_version_set_name: str = None,
        description: str = None,
        compartment_id: str = None,
        project_id: str = None,
        **kwargs,
    ) -> tuple:
        """Gets or creates a ModelVersionSet from the given ID or name.

        Parameters
        ----------
        model_version_set_id: (str, optional):
            ModelVersionSet OCID. When given, the existing set is loaded and
            validated; the name argument is ignored.
        model_version_set_name: (str, optional):
            ModelVersionSet name. Used when no ID is given: an existing set
            with this name in ``compartment_id`` is reused, otherwise a new
            one is created.
        description: (str, optional):
            Description for a newly created ModelVersionSet.
        compartment_id: (str, optional):
            Compartment OCID.
        project_id: (str, optional):
            Project OCID for a newly created ModelVersionSet.
        **kwargs:
            Additional keyword arguments passed to ``ModelVersionSet.create``.

        Returns
        -------
        tuple: (model_version_set_id, model_version_set_name)

        Raises
        ------
        AquaValueError
            If an existing ModelVersionSet (found by ID or by name) is not
            tagged with ``Tags.AQUA_FINE_TUNING``.
        """
        # TODO: tag should be selected based on which operation (eval/FT) invoke this method
        # currently only used by fine-tuning flow.
        tag = Tags.AQUA_FINE_TUNING.value

        if model_version_set_id:
            model_version_set = ModelVersionSet.from_id(model_version_set_id)
            if not _is_valid_mvs(model_version_set, tag):
                raise AquaValueError(
                    f"Invalid model version set id. Please provide a model version set with `{tag}` in tags."
                )
            return (model_version_set_id, model_version_set.name)

        try:
            model_version_set = ModelVersionSet.from_name(
                name=model_version_set_name,
                compartment_id=compartment_id,
            )
        except Exception:
            # Lookup failed: treat the set as missing and create it.
            logger.debug(
                f"Model version set {model_version_set_name} doesn't exist. "
                "Creating new model version set."
            )
            mvs_freeform_tags = {
                tag: tag,
            }
            model_version_set = (
                ModelVersionSet()
                .with_compartment_id(compartment_id)
                .with_project_id(project_id)
                .with_name(model_version_set_name)
                .with_description(description)
                .with_freeform_tags(**mvs_freeform_tags)
                # TODO: decide what parameters will be needed
                # when refactor eval to use this method, we need to pass tag here.
                .create(**kwargs)
            )
            logger.debug(
                f"Successfully created model version set {model_version_set_name} with id {model_version_set.id}."
            )
        else:
            # Validate outside the lookup ``try`` so a tag mismatch is raised
            # to the caller instead of being swallowed and causing an attempt
            # to create a second set with the same name.
            if not _is_valid_mvs(model_version_set, tag):
                raise AquaValueError(
                    f"Invalid model version set name. Please provide a model version set with `{tag}` in tags."
                )

        return (model_version_set.id, model_version_set_name)

    # TODO: refactor model evaluation implementation to use it.
    @staticmethod
    def create_model_catalog(
        display_name: str,
        description: str,
        model_version_set_id: str,
        model_custom_metadata: Union[ModelCustomMetadata, Dict],
        model_taxonomy_metadata: Union[ModelTaxonomyMetadata, Dict],
        compartment_id: str,
        project_id: str,
        **kwargs,
    ) -> DataScienceModel:
        """Creates a DataScienceModel entry in the model catalog.

        Parameters
        ----------
        display_name: str
            Display name of the model.
        description: str
            Description of the model.
        model_version_set_id: str
            OCID of the ModelVersionSet the model is added to.
        model_custom_metadata: Union[ModelCustomMetadata, Dict]
            Custom metadata list for the model.
        model_taxonomy_metadata: Union[ModelTaxonomyMetadata, Dict]
            Defined (taxonomy) metadata list for the model.
        compartment_id: str
            Compartment OCID.
        project_id: str
            Project OCID.
        **kwargs:
            Additional keyword arguments passed to ``DataScienceModel.create``.

        Returns
        -------
        DataScienceModel
            The created model. Its provenance ``training_id`` is set to
            ``UNKNOWN``.
        """
        model = (
            DataScienceModel()
            .with_compartment_id(compartment_id)
            .with_project_id(project_id)
            .with_display_name(display_name)
            .with_description(description)
            .with_model_version_set_id(model_version_set_id)
            .with_custom_metadata_list(model_custom_metadata)
            .with_defined_metadata_list(model_taxonomy_metadata)
            .with_provenance_metadata(ModelProvenanceMetadata(training_id=UNKNOWN))
            # TODO: decide what parameters will be needed
            .create(
                **kwargs,
            )
        )
        return model

    def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
        """Checks if the artifact exists.

        Parameters
        ----------
        model_id : str
            The model OCID.
        **kwargs :
            Additional keyword arguments passed in head_model_artifact.

        Returns
        -------
        bool
            Whether the artifact exists.

        Raises
        ------
        oci.exceptions.ServiceError
            For any service error other than 404 (e.g. authorization or server
            errors), which must not be reported as "artifact missing".
        """
        try:
            response = self.ds_client.head_model_artifact(model_id=model_id, **kwargs)
            return response.status == 200
        except oci.exceptions.ServiceError as ex:
            if ex.status == 404:
                logger.info(f"Artifact not found in model {model_id}.")
                return False
            raise

    def get_config(self, model_id: str, config_file_name: str) -> Dict:
        """Gets the config for the given Aqua model.

        Parameters
        ----------
        model_id: str
            The OCID of the Aqua model.
        config_file_name: str
            name of the config file

        Returns
        -------
        Dict:
            A dict of allowed configs. Empty when the artifact path or the
            config file is unavailable.

        Raises
        ------
        AquaRuntimeError
            If the model is not tagged as an Aqua model.
        """
        oci_model = self.ds_client.get_model(model_id).data
        freeform_tags = oci_model.freeform_tags or {}
        oci_aqua = (
            Tags.AQUA_TAG.value in freeform_tags
            or Tags.AQUA_TAG.value.lower() in freeform_tags
        )

        if not oci_aqua:
            raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.")

        config = {}
        artifact_path = get_artifact_path(oci_model.custom_metadata_list)
        if not artifact_path:
            logger.error(
                f"Failed to get artifact path from custom metadata for the model: {model_id}"
            )
            return config

        config_path = f"{os.path.dirname(artifact_path)}/config/"
        try:
            config = load_config(
                config_path,
                config_file_name=config_file_name,
            )
        except Exception as ex:
            # Best effort: a missing or unreadable config is reported below as
            # an empty result rather than raised; keep the cause for debugging.
            logger.debug(
                f"Failed to load {config_file_name} from {config_path}: {ex}"
            )

        if not config:
            logger.error(
                f"{config_file_name} is not available for the model: {model_id}. Check if the custom metadata has the artifact path set."
            )

        return config

    @property
    def telemetry(self):
        """Lazily creates and returns the Aqua telemetry client."""
        if not self._telemetry:
            self._telemetry = TelemetryClient(
                bucket=AQUA_TELEMETRY_BUCKET, namespace=AQUA_TELEMETRY_BUCKET_NS
            )
        return self._telemetry
ads/aqua/cli.py ADDED
@@ -0,0 +1,19 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright (c) 2024 Oracle and/or its affiliates.
5
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
+
7
+ from ads.aqua.deployment import AquaDeploymentApp
8
+ from ads.aqua.finetune import AquaFineTuningApp
9
+ from ads.aqua.model import AquaModelApp
10
+ from ads.aqua.evaluation import AquaEvaluationApp
11
+
12
+
13
class AquaCommand:
    """Command groups exposed by the Aqua CLI.

    Every class attribute binds a command-group name to the Aqua application
    class that implements the commands of that group.
    """

    # Attribute order is kept as-is so any tooling that lists members in
    # definition order presents the same sequence.
    model = AquaModelApp  # model commands
    fine_tuning = AquaFineTuningApp  # fine-tuning commands
    deployment = AquaDeploymentApp  # deployment commands
    evaluation = AquaEvaluationApp  # evaluation commands
@@ -0,0 +1,9 @@
1
+ {
2
+ "shape": [
3
+ "VM.GPU.A10.1",
4
+ "VM.GPU.A10.2",
5
+ "BM.GPU.A10.4",
6
+ "BM.GPU4.8",
7
+ "BM.GPU.A100-v2.8"
8
+ ]
9
+ }
@@ -0,0 +1,7 @@
1
+ {
2
+ "BM.GPU.A10.4": "ds-gpu-a10-count",
3
+ "BM.GPU.A100-v2.8": "ds-gpu-a100-v2-count",
4
+ "BM.GPU4.8": "ds-gpu4-count",
5
+ "VM.GPU.A10.1": "ds-gpu-a10-count",
6
+ "VM.GPU.A10.2": "ds-gpu-a10-count"
7
+ }
ads/aqua/constants.py ADDED
@@ -0,0 +1,45 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+ """This module defines constants used in ads.aqua module."""
6
+ from enum import Enum
7
+
8
# Empty-string placeholder; presumably substituted where a value is unknown
# (its consumers live outside this module).
UNKNOWN_VALUE = str()
9
+
10
+
11
class RqsAdditionalDetails:
    """String keys (camelCase) used in additional-details payloads."""

    # Top-level key.
    METADATA = "metadata"

    # Nested keys.
    CREATED_BY = "createdBy"
    DESCRIPTION = "description"
    MODEL_VERSION_SET_ID = "modelVersionSetId"
    MODEL_VERSION_SET_NAME = "modelVersionSetName"
    PROJECT_ID = "projectId"
    VERSION_LABEL = "versionLabel"
19
+
20
+
21
class FineTuningDefinedMetadata(Enum):
    """Defined-metadata keys recorded for fine-tuning."""

    VAL_SET_SIZE = "val_set_size"  # validation split size
    TRAINING_DATA = "training_data"  # training data location/identifier
26
+
27
+
28
class FineTuningCustomMetadata(Enum):
    """Custom-metadata keys recorded for fine-tuning."""

    # Source model and outputs.
    FT_SOURCE = "fine_tune_source"
    FT_SOURCE_NAME = "fine_tune_source_name"
    FT_OUTPUT_PATH = "fine_tune_output_path"

    # Job bookkeeping.
    FT_JOB_ID = "fine_tune_job_id"
    FT_JOB_RUN_ID = "fine_tune_jobrun_id"

    # Metrics: final values, then per-epoch values.
    TRAINING_METRICS_FINAL = "train_metrics_final"
    VALIDATION_METRICS_FINAL = "val_metrics_final"
    TRAINING_METRICS_EPOCH = "train_metrics_epoch"
    VALIDATION_METRICS_EPOCH = "val_metrics_epoch"
40
+
41
+
42
# Metric-name constants; presumably keys for training/validation metrics
# (their consumers live outside this module).
TRAINING_METRICS_FINAL = "training_metrics_final"
VALIDATION_METRICS_FINAL = "validation_metrics_final"
TRAINING_METRICS = "training_metrics"
# Backward-compatible alias for the originally misspelled name; prefer
# ``TRAINING_METRICS`` in new code.
TRINING_METRICS = TRAINING_METRICS
VALIDATION_METRICS = "validation_metrics"
ads/aqua/data.py ADDED
@@ -0,0 +1,40 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+
6
+ from dataclasses import dataclass
7
+ from enum import Enum
8
+
9
+ from ads.common.serializer import DataClassSerializable
10
+
11
+
12
@dataclass(repr=False)
class AquaResourceIdentifier(DataClassSerializable):
    """Serializable (id, name, url) triple identifying a resource.

    All fields default to an empty string.
    """

    # Resource identifier.
    id: str = ""
    # Resource name.
    name: str = ""
    # Resource URL.
    url: str = ""
17
+
18
+
19
class Resource(Enum):
    """Resource path segments (plural, lower-case) for OCI resources."""

    # Jobs.
    JOB = "jobs"
    JOBRUN = "jobruns"
    # Models and deployments.
    MODEL = "models"
    MODEL_DEPLOYMENT = "modeldeployments"
    MODEL_VERSION_SET = "model-version-sets"
25
+
26
+
27
class DataScienceResource(Enum):
    """Data Science resource-type strings (lower-case, no separators)."""

    # Deployment is listed first, matching the original member order.
    MODEL_DEPLOYMENT = "datasciencemodeldeployment"
    MODEL = "datasciencemodel"
30
+
31
+
32
class Tags(Enum):
    """Tag keys used on Aqua resources."""

    # Descriptive tags.
    TASK = "task"
    LICENSE = "license"
    ORGANIZATION = "organization"

    # Aqua marker tags.
    AQUA_TAG = "OCI_AQUA"
    AQUA_SERVICE_MODEL_TAG = "aqua_service_model"
    AQUA_FINE_TUNED_MODEL_TAG = "aqua_fine_tuned_model"

    # Operation tags.
    AQUA_EVALUATION = "aqua_evaluation"
    AQUA_FINE_TUNING = "aqua_finetuning"
ads/aqua/decorator.py ADDED
@@ -0,0 +1,101 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+
6
+ """Decorator module."""
7
+
8
+ import sys
9
+ from functools import wraps
10
+
11
+ from oci.exceptions import (
12
+ ClientError,
13
+ CompositeOperationError,
14
+ ConnectTimeout,
15
+ MissingEndpointForNonRegionalServiceClientError,
16
+ MultipartUploadError,
17
+ RequestException,
18
+ ServiceError,
19
+ )
20
+
21
+ from ads.aqua.exception import AquaError
22
+ from ads.aqua.extension.base_handler import AquaAPIhandler
23
+
24
+
25
def handle_exceptions(func):
    """Writes errors raised during call to JSON.

    This decorator is designed to be used with methods in handler.py that
    interact with external services or perform operations which might
    fail. This decorator should be applied only to instance methods of
    classes within handler.py, as it is tailored to handle exceptions
    specific to the operations performed by these handlers.

    Exceptions are reported through ``self.write_error`` with these status
    codes:

    - ``ServiceError``: the error's own status (500 if missing)
    - ``ConnectTimeout``: 408
    - ``ClientError``, ``MissingEndpointForNonRegionalServiceClientError``,
      ``RequestException``: 400
    - ``MultipartUploadError``, ``CompositeOperationError``: 500
    - ``AquaError``: the error's own status
    - any other ``Exception``: 500

    Parameters
    ----------
    func (Callable): The function to be wrapped by the decorator.

    Returns
    -------
    Callable: A wrapper function that catches exceptions thrown by `func`.
        When an exception is handled, the wrapper returns ``None``.

    Examples
    --------

    >>> from ads.aqua.decorator import handle_exceptions

    >>> @handle_exceptions
    >>> def some_method(self, arg1, arg2):
    ...     # Method implementation...
    ...     pass

    """

    @wraps(func)
    def inner_function(self: AquaAPIhandler, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except ServiceError as error:
            self.write_error(
                status_code=error.status or 500,
                reason=error.message,
                service_payload=error.args[0] if error.args else None,
                exc_info=sys.exc_info(),
            )
        except ConnectTimeout as error:
            # Must come before the ``RequestException`` clause below:
            # ``oci.exceptions.ConnectTimeout`` subclasses ``RequestException``,
            # so listing it afterwards made this 408 branch unreachable.
            self.write_error(
                status_code=408,
                reason=f"{type(error).__name__}: {str(error)}",
                exc_info=sys.exc_info(),
            )
        except (
            ClientError,
            MissingEndpointForNonRegionalServiceClientError,
            RequestException,
        ) as error:
            self.write_error(
                status_code=400,
                reason=f"{type(error).__name__}: {str(error)}",
                exc_info=sys.exc_info(),
            )
        except (MultipartUploadError, CompositeOperationError) as error:
            self.write_error(
                status_code=500,
                reason=f"{type(error).__name__}: {str(error)}",
                exc_info=sys.exc_info(),
            )
        except AquaError as error:
            self.write_error(
                status_code=error.status,
                reason=error.reason,
                service_payload=error.service_payload,
                exc_info=sys.exc_info(),
            )
        except Exception as ex:
            self.write_error(
                status_code=500,
                reason=f"{type(ex).__name__}: {str(ex)}",
                exc_info=sys.exc_info(),
            )

    return inner_function