oracle-ads 2.11.7__py3-none-any.whl → 2.11.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +24 -14
- ads/aqua/base.py +0 -2
- ads/aqua/cli.py +50 -2
- ads/aqua/decorator.py +8 -0
- ads/aqua/deployment.py +37 -34
- ads/aqua/evaluation.py +106 -49
- ads/aqua/extension/base_handler.py +18 -10
- ads/aqua/extension/common_handler.py +21 -2
- ads/aqua/extension/deployment_handler.py +1 -4
- ads/aqua/extension/evaluation_handler.py +1 -2
- ads/aqua/extension/finetune_handler.py +0 -1
- ads/aqua/extension/ui_handler.py +1 -12
- ads/aqua/extension/utils.py +4 -4
- ads/aqua/finetune.py +24 -11
- ads/aqua/model.py +2 -4
- ads/aqua/utils.py +39 -23
- ads/cli.py +19 -1
- ads/common/serializer.py +5 -4
- ads/common/utils.py +6 -2
- ads/config.py +1 -0
- ads/llm/serializers/runnable_parallel.py +7 -1
- ads/opctl/operator/lowcode/anomaly/README.md +1 -1
- ads/opctl/operator/lowcode/anomaly/environment.yaml +1 -1
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +15 -10
- ads/opctl/operator/lowcode/anomaly/model/autots.py +9 -10
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +34 -37
- ads/opctl/operator/lowcode/anomaly/model/tods.py +4 -4
- ads/opctl/operator/lowcode/anomaly/schema.yaml +1 -1
- ads/opctl/operator/lowcode/forecast/README.md +1 -1
- ads/opctl/operator/lowcode/forecast/environment.yaml +4 -4
- ads/opctl/operator/lowcode/forecast/model/arima.py +19 -21
- ads/opctl/operator/lowcode/forecast/model/automlx.py +36 -42
- ads/opctl/operator/lowcode/forecast/model/autots.py +41 -25
- ads/opctl/operator/lowcode/forecast/model/base_model.py +93 -107
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +51 -45
- ads/opctl/operator/lowcode/forecast/model/prophet.py +32 -27
- ads/opctl/operator/lowcode/forecast/schema.yaml +2 -2
- ads/opctl/operator/lowcode/forecast/utils.py +4 -4
- ads/opctl/operator/lowcode/pii/README.md +1 -1
- ads/opctl/operator/lowcode/pii/environment.yaml +1 -1
- ads/opctl/operator/lowcode/pii/model/report.py +71 -70
- {oracle_ads-2.11.7.dist-info → oracle_ads-2.11.8.dist-info}/METADATA +5 -5
- {oracle_ads-2.11.7.dist-info → oracle_ads-2.11.8.dist-info}/RECORD +46 -46
- {oracle_ads-2.11.7.dist-info → oracle_ads-2.11.8.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.11.7.dist-info → oracle_ads-2.11.8.dist-info}/WHEEL +0 -0
- {oracle_ads-2.11.7.dist-info → oracle_ads-2.11.8.dist-info}/entry_points.txt +0 -0
@@ -8,14 +8,16 @@ import json
|
|
8
8
|
import traceback
|
9
9
|
import uuid
|
10
10
|
from dataclasses import asdict, is_dataclass
|
11
|
+
from http.client import responses
|
11
12
|
from typing import Any
|
12
13
|
|
13
14
|
from notebook.base.handlers import APIHandler
|
14
|
-
from tornado.web import HTTPError, Application
|
15
15
|
from tornado import httputil
|
16
|
-
from
|
17
|
-
|
16
|
+
from tornado.web import Application, HTTPError
|
17
|
+
|
18
18
|
from ads.aqua import logger
|
19
|
+
from ads.config import AQUA_TELEMETRY_BUCKET, AQUA_TELEMETRY_BUCKET_NS
|
20
|
+
from ads.telemetry.client import TelemetryClient
|
19
21
|
|
20
22
|
|
21
23
|
class AquaAPIhandler(APIHandler):
|
@@ -66,12 +68,15 @@ class AquaAPIhandler(APIHandler):
|
|
66
68
|
|
67
69
|
def write_error(self, status_code, **kwargs):
|
68
70
|
"""AquaAPIhandler errors are JSON, not human pages."""
|
69
|
-
|
70
71
|
self.set_header("Content-Type", "application/json")
|
71
72
|
reason = kwargs.get("reason")
|
72
73
|
self.set_status(status_code, reason=reason)
|
73
74
|
service_payload = kwargs.get("service_payload", {})
|
74
|
-
|
75
|
+
default_msg = responses.get(status_code, "Unknown HTTP Error")
|
76
|
+
message = self.get_default_error_messages(
|
77
|
+
service_payload, str(status_code), kwargs.get("message", default_msg)
|
78
|
+
)
|
79
|
+
|
75
80
|
reply = {
|
76
81
|
"status": status_code,
|
77
82
|
"message": message,
|
@@ -84,7 +89,7 @@ class AquaAPIhandler(APIHandler):
|
|
84
89
|
e = exc_info[1]
|
85
90
|
if isinstance(e, HTTPError):
|
86
91
|
reply["message"] = e.log_message or message
|
87
|
-
reply["reason"] = e.reason
|
92
|
+
reply["reason"] = e.reason if e.reason else reply["reason"]
|
88
93
|
reply["request_id"] = str(uuid.uuid4())
|
89
94
|
else:
|
90
95
|
reply["request_id"] = str(uuid.uuid4())
|
@@ -102,7 +107,11 @@ class AquaAPIhandler(APIHandler):
|
|
102
107
|
self.finish(json.dumps(reply))
|
103
108
|
|
104
109
|
@staticmethod
|
105
|
-
def get_default_error_messages(
|
110
|
+
def get_default_error_messages(
|
111
|
+
service_payload: dict,
|
112
|
+
status_code: str,
|
113
|
+
default_msg: str = "Unknown HTTP Error.",
|
114
|
+
):
|
106
115
|
"""Method that maps the error messages based on the operation performed or the status codes encountered."""
|
107
116
|
|
108
117
|
messages = {
|
@@ -110,7 +119,6 @@ class AquaAPIhandler(APIHandler):
|
|
110
119
|
"403": "We're having trouble processing your request with the information provided.",
|
111
120
|
"404": "Authorization Failed: The resource you're looking for isn't accessible.",
|
112
121
|
"408": "Server is taking too long to response, please try again.",
|
113
|
-
"500": "An error occurred while creating the resource.",
|
114
122
|
"create": "Authorization Failed: Could not create resource.",
|
115
123
|
"get": "Authorization Failed: The resource you're looking for isn't accessible.",
|
116
124
|
}
|
@@ -119,7 +127,7 @@ class AquaAPIhandler(APIHandler):
|
|
119
127
|
operation_name = service_payload["operation_name"]
|
120
128
|
if operation_name:
|
121
129
|
if operation_name.startswith("create"):
|
122
|
-
return messages["create"]
|
130
|
+
return messages["create"] + f" Operation Name: {operation_name}."
|
123
131
|
elif operation_name.startswith("list") or operation_name.startswith(
|
124
132
|
"get"
|
125
133
|
):
|
@@ -128,7 +136,7 @@ class AquaAPIhandler(APIHandler):
|
|
128
136
|
if status_code in messages:
|
129
137
|
return messages[status_code]
|
130
138
|
else:
|
131
|
-
return
|
139
|
+
return default_msg
|
132
140
|
|
133
141
|
|
134
142
|
# todo: remove after error handler is implemented
|
@@ -6,14 +6,17 @@
|
|
6
6
|
|
7
7
|
from importlib import metadata
|
8
8
|
|
9
|
-
from ads.aqua.extension.base_handler import AquaAPIhandler
|
10
9
|
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
|
10
|
+
from ads.aqua.decorator import handle_exceptions
|
11
11
|
from ads.aqua.exception import AquaResourceAccessError
|
12
|
+
from ads.aqua.extension.base_handler import AquaAPIhandler
|
13
|
+
from ads.aqua.utils import known_realm, fetch_service_compartment
|
12
14
|
|
13
15
|
|
14
16
|
class ADSVersionHandler(AquaAPIhandler):
|
15
17
|
"""The handler to get the current version of the ADS."""
|
16
18
|
|
19
|
+
@handle_exceptions
|
17
20
|
def get(self):
|
18
21
|
self.finish({"data": metadata.version("oracle_ads")})
|
19
22
|
|
@@ -21,9 +24,25 @@ class ADSVersionHandler(AquaAPIhandler):
|
|
21
24
|
class CompatibilityCheckHandler(AquaAPIhandler):
|
22
25
|
"""The handler to check if the extension is compatible."""
|
23
26
|
|
27
|
+
@handle_exceptions
|
24
28
|
def get(self):
|
25
|
-
|
29
|
+
"""This method provides the availability status of Aqua. If ODSC_MODEL_COMPARTMENT_OCID environment variable
|
30
|
+
is set, then status `ok` is returned. For regions where Aqua is available but the environment variable is not
|
31
|
+
set due to accesses/policies, we return the `compatible` status to indicate that the extension can be enabled
|
32
|
+
for the selected notebook session.
|
33
|
+
|
34
|
+
Returns
|
35
|
+
-------
|
36
|
+
status dict:
|
37
|
+
ok or compatible
|
38
|
+
Raises:
|
39
|
+
AquaResourceAccessError: raised when aqua is not accessible in the given session/region.
|
40
|
+
|
41
|
+
"""
|
42
|
+
if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
|
26
43
|
return self.finish(dict(status="ok"))
|
44
|
+
elif known_realm():
|
45
|
+
return self.finish(dict(status="compatible"))
|
27
46
|
else:
|
28
47
|
raise AquaResourceAccessError(
|
29
48
|
f"The AI Quick actions extension is not compatible in the given region."
|
@@ -7,10 +7,10 @@ from urllib.parse import urlparse
|
|
7
7
|
|
8
8
|
from tornado.web import HTTPError
|
9
9
|
|
10
|
+
from ads.aqua.decorator import handle_exceptions
|
10
11
|
from ads.aqua.deployment import AquaDeploymentApp, MDInferenceResponse, ModelParams
|
11
12
|
from ads.aqua.extension.base_handler import AquaAPIhandler, Errors
|
12
13
|
from ads.config import COMPARTMENT_OCID, PROJECT_OCID
|
13
|
-
from ads.aqua.decorator import handle_exceptions
|
14
14
|
|
15
15
|
|
16
16
|
class AquaDeploymentHandler(AquaAPIhandler):
|
@@ -110,12 +110,10 @@ class AquaDeploymentHandler(AquaAPIhandler):
|
|
110
110
|
)
|
111
111
|
)
|
112
112
|
|
113
|
-
@handle_exceptions
|
114
113
|
def read(self, id):
|
115
114
|
"""Read the information of an Aqua model deployment."""
|
116
115
|
return self.finish(AquaDeploymentApp().get(model_deployment_id=id))
|
117
116
|
|
118
|
-
@handle_exceptions
|
119
117
|
def list(self):
|
120
118
|
"""List Aqua models."""
|
121
119
|
# If default is not specified,
|
@@ -129,7 +127,6 @@ class AquaDeploymentHandler(AquaAPIhandler):
|
|
129
127
|
)
|
130
128
|
)
|
131
129
|
|
132
|
-
@handle_exceptions
|
133
130
|
def get_deployment_config(self, model_id):
|
134
131
|
"""Gets the deployment config for Aqua model."""
|
135
132
|
return self.finish(AquaDeploymentApp().get_deployment_config(model_id=model_id))
|
@@ -5,11 +5,10 @@
|
|
5
5
|
|
6
6
|
from urllib.parse import urlparse
|
7
7
|
|
8
|
-
from
|
8
|
+
from tornado.web import HTTPError
|
9
9
|
|
10
10
|
from ads.aqua.decorator import handle_exceptions
|
11
11
|
from ads.aqua.evaluation import AquaEvaluationApp, CreateAquaEvaluationDetails
|
12
|
-
from ads.aqua.exception import AquaError
|
13
12
|
from ads.aqua.extension.base_handler import AquaAPIhandler, Errors
|
14
13
|
from ads.aqua.extension.utils import validate_function_parameters
|
15
14
|
from ads.config import COMPARTMENT_OCID
|
@@ -54,7 +54,6 @@ class AquaFineTuneHandler(AquaAPIhandler):
|
|
54
54
|
|
55
55
|
self.finish(AquaFineTuningApp().create(CreateFineTuningDetails(**input_data)))
|
56
56
|
|
57
|
-
@handle_exceptions
|
58
57
|
def get_finetuning_config(self, model_id):
|
59
58
|
"""Gets the finetuning config for Aqua model."""
|
60
59
|
return self.finish(AquaFineTuningApp().get_finetuning_config(model_id=model_id))
|
ads/aqua/extension/ui_handler.py
CHANGED
@@ -34,6 +34,7 @@ class AquaUIHandler(AquaAPIhandler):
|
|
34
34
|
HTTPError: For various failure scenarios such as invalid input format, missing data, etc.
|
35
35
|
"""
|
36
36
|
|
37
|
+
@handle_exceptions
|
37
38
|
def get(self, id=""):
|
38
39
|
"""Handle GET request."""
|
39
40
|
url_parse = urlparse(self.request.path)
|
@@ -76,7 +77,6 @@ class AquaUIHandler(AquaAPIhandler):
|
|
76
77
|
else:
|
77
78
|
raise HTTPError(400, f"The request {self.request.path} is invalid.")
|
78
79
|
|
79
|
-
@handle_exceptions
|
80
80
|
def list_log_groups(self, **kwargs):
|
81
81
|
"""Lists all log groups for the specified compartment or tenancy."""
|
82
82
|
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
|
@@ -84,22 +84,18 @@ class AquaUIHandler(AquaAPIhandler):
|
|
84
84
|
AquaUIApp().list_log_groups(compartment_id=compartment_id, **kwargs)
|
85
85
|
)
|
86
86
|
|
87
|
-
@handle_exceptions
|
88
87
|
def list_logs(self, log_group_id: str, **kwargs):
|
89
88
|
"""Lists the specified log group's log objects."""
|
90
89
|
return self.finish(AquaUIApp().list_logs(log_group_id=log_group_id, **kwargs))
|
91
90
|
|
92
|
-
@handle_exceptions
|
93
91
|
def list_compartments(self):
|
94
92
|
"""Lists the compartments in a compartment specified by ODSC_MODEL_COMPARTMENT_OCID env variable."""
|
95
93
|
return self.finish(AquaUIApp().list_compartments())
|
96
94
|
|
97
|
-
@handle_exceptions
|
98
95
|
def get_default_compartment(self):
|
99
96
|
"""Returns user compartment ocid."""
|
100
97
|
return self.finish(AquaUIApp().get_default_compartment())
|
101
98
|
|
102
|
-
@handle_exceptions
|
103
99
|
def list_model_version_sets(self, **kwargs):
|
104
100
|
"""Lists all model version sets for the specified compartment or tenancy."""
|
105
101
|
|
@@ -112,7 +108,6 @@ class AquaUIHandler(AquaAPIhandler):
|
|
112
108
|
)
|
113
109
|
)
|
114
110
|
|
115
|
-
@handle_exceptions
|
116
111
|
def list_experiments(self, **kwargs):
|
117
112
|
"""Lists all experiments for the specified compartment or tenancy."""
|
118
113
|
|
@@ -125,7 +120,6 @@ class AquaUIHandler(AquaAPIhandler):
|
|
125
120
|
)
|
126
121
|
)
|
127
122
|
|
128
|
-
@handle_exceptions
|
129
123
|
def list_buckets(self, **kwargs):
|
130
124
|
"""Lists all model version sets for the specified compartment or tenancy."""
|
131
125
|
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
|
@@ -138,7 +132,6 @@ class AquaUIHandler(AquaAPIhandler):
|
|
138
132
|
)
|
139
133
|
)
|
140
134
|
|
141
|
-
@handle_exceptions
|
142
135
|
def list_job_shapes(self, **kwargs):
|
143
136
|
"""Lists job shapes available in the specified compartment."""
|
144
137
|
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
|
@@ -146,7 +139,6 @@ class AquaUIHandler(AquaAPIhandler):
|
|
146
139
|
AquaUIApp().list_job_shapes(compartment_id=compartment_id, **kwargs)
|
147
140
|
)
|
148
141
|
|
149
|
-
@handle_exceptions
|
150
142
|
def list_vcn(self, **kwargs):
|
151
143
|
"""Lists the virtual cloud networks (VCNs) in the specified compartment."""
|
152
144
|
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
|
@@ -154,7 +146,6 @@ class AquaUIHandler(AquaAPIhandler):
|
|
154
146
|
AquaUIApp().list_vcn(compartment_id=compartment_id, **kwargs)
|
155
147
|
)
|
156
148
|
|
157
|
-
@handle_exceptions
|
158
149
|
def list_subnets(self, **kwargs):
|
159
150
|
"""Lists the subnets in the specified VCN and the specified compartment."""
|
160
151
|
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
|
@@ -165,7 +156,6 @@ class AquaUIHandler(AquaAPIhandler):
|
|
165
156
|
)
|
166
157
|
)
|
167
158
|
|
168
|
-
@handle_exceptions
|
169
159
|
def get_shape_availability(self, **kwargs):
|
170
160
|
"""For a given compartmentId, resource limit name, and scope, returns the number of available resources associated
|
171
161
|
with the given limit."""
|
@@ -178,7 +168,6 @@ class AquaUIHandler(AquaAPIhandler):
|
|
178
168
|
)
|
179
169
|
)
|
180
170
|
|
181
|
-
@handle_exceptions
|
182
171
|
def is_bucket_versioned(self):
|
183
172
|
"""For a given compartmentId, resource limit name, and scope, returns the number of available resources associated
|
184
173
|
with the given limit."""
|
ads/aqua/extension/utils.py
CHANGED
@@ -4,16 +4,16 @@
|
|
4
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
5
|
from dataclasses import fields
|
6
6
|
from typing import Dict, Optional
|
7
|
-
|
7
|
+
|
8
|
+
from tornado.web import HTTPError
|
8
9
|
|
9
10
|
from ads.aqua.extension.base_handler import Errors
|
10
11
|
|
11
12
|
|
12
13
|
def validate_function_parameters(data_class, input_data: Dict):
|
13
|
-
"""Validates if the required parameters are provided in input data."""
|
14
|
+
"""Validates if the required parameters are provided in input data."""
|
14
15
|
required_parameters = [
|
15
|
-
field.name for field in fields(data_class)
|
16
|
-
if field.type != Optional[field.type]
|
16
|
+
field.name for field in fields(data_class) if field.type != Optional[field.type]
|
17
17
|
]
|
18
18
|
|
19
19
|
for required_parameter in required_parameters:
|
ads/aqua/finetune.py
CHANGED
@@ -15,6 +15,7 @@ from oci.data_science.models import (
|
|
15
15
|
UpdateModelProvenanceDetails,
|
16
16
|
)
|
17
17
|
|
18
|
+
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
|
18
19
|
from ads.aqua.base import AquaApp
|
19
20
|
from ads.aqua.data import AquaResourceIdentifier, Resource, Tags
|
20
21
|
from ads.aqua.exception import AquaFileExistsError, AquaValueError
|
@@ -28,7 +29,6 @@ from ads.aqua.utils import (
|
|
28
29
|
UNKNOWN,
|
29
30
|
UNKNOWN_DICT,
|
30
31
|
get_container_image,
|
31
|
-
logger,
|
32
32
|
upload_local_to_os,
|
33
33
|
)
|
34
34
|
from ads.common.auth import default_signer
|
@@ -69,6 +69,7 @@ class FineTuneCustomMetadata(Enum):
|
|
69
69
|
class AquaFineTuningParams(DataClassSerializable):
|
70
70
|
epochs: int = None
|
71
71
|
learning_rate: float = None
|
72
|
+
sample_packing: str = "True"
|
72
73
|
|
73
74
|
|
74
75
|
@dataclass(repr=False)
|
@@ -122,6 +123,8 @@ class CreateFineTuningDetails(DataClassSerializable):
|
|
122
123
|
The log group id for fine tuning job infrastructure.
|
123
124
|
log_id: (str, optional). Defaults to `None`.
|
124
125
|
The log id for fine tuning job infrastructure.
|
126
|
+
force_overwrite: (bool, optional). Defaults to `False`.
|
127
|
+
Whether to force overwrite the existing file in object storage.
|
125
128
|
"""
|
126
129
|
|
127
130
|
ft_source_id: str
|
@@ -142,6 +145,7 @@ class CreateFineTuningDetails(DataClassSerializable):
|
|
142
145
|
subnet_id: Optional[str] = None
|
143
146
|
log_id: Optional[str] = None
|
144
147
|
log_group_id: Optional[str] = None
|
148
|
+
force_overwrite: Optional[bool] = False
|
145
149
|
|
146
150
|
|
147
151
|
class AquaFineTuningApp(AquaApp):
|
@@ -192,12 +196,11 @@ class AquaFineTuningApp(AquaApp):
|
|
192
196
|
)
|
193
197
|
|
194
198
|
source = self.get_source(create_fine_tuning_details.ft_source_id)
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
# )
|
199
|
+
if source.compartment_id != ODSC_MODEL_COMPARTMENT_OCID:
|
200
|
+
raise AquaValueError(
|
201
|
+
f"Fine tuning is only supported for Aqua service models in {ODSC_MODEL_COMPARTMENT_OCID}. "
|
202
|
+
"Use a valid Aqua service model id instead."
|
203
|
+
)
|
201
204
|
|
202
205
|
target_compartment = (
|
203
206
|
create_fine_tuning_details.compartment_id or COMPARTMENT_OCID
|
@@ -273,12 +276,12 @@ class AquaFineTuningApp(AquaApp):
|
|
273
276
|
src_uri=ft_dataset_path,
|
274
277
|
dst_uri=dst_uri,
|
275
278
|
auth=default_signer(),
|
276
|
-
force_overwrite=
|
279
|
+
force_overwrite=create_fine_tuning_details.force_overwrite,
|
277
280
|
)
|
278
281
|
except FileExistsError:
|
279
282
|
raise AquaFileExistsError(
|
280
283
|
f"Dataset {dataset_file} already exists in {create_fine_tuning_details.report_path}. "
|
281
|
-
"Please use a new dataset file name or
|
284
|
+
"Please use a new dataset file name, report path or set `force_overwrite` as True."
|
282
285
|
)
|
283
286
|
logger.debug(
|
284
287
|
f"Uploaded local file {ft_dataset_path} to object storage {dst_uri}."
|
@@ -460,16 +463,26 @@ class AquaFineTuningApp(AquaApp):
|
|
460
463
|
telemetry_kwargs = (
|
461
464
|
{"ocid": ft_job.id[-6:]} if ft_job and len(ft_job.id) > 6 else {}
|
462
465
|
)
|
466
|
+
# track shapes that were used for fine-tune creation
|
463
467
|
self.telemetry.record_event_async(
|
464
|
-
category=f"aqua/service/
|
468
|
+
category=f"aqua/service/finetune/create/shape/",
|
465
469
|
action=f"{create_fine_tuning_details.shape_name}x{create_fine_tuning_details.replica}",
|
466
470
|
**telemetry_kwargs,
|
467
471
|
)
|
468
472
|
# tracks unique fine-tuned models that were created in the user compartment
|
473
|
+
# TODO: retrieve the service model name for FT custom models.
|
469
474
|
self.telemetry.record_event_async(
|
470
475
|
category="aqua/service/finetune",
|
471
476
|
action="create",
|
472
477
|
detail=source.display_name,
|
478
|
+
**telemetry_kwargs,
|
479
|
+
)
|
480
|
+
# track combination of model and shape used for fine-tune creation
|
481
|
+
self.telemetry.record_event_async(
|
482
|
+
category="aqua/service/finetune/create",
|
483
|
+
action="shape",
|
484
|
+
detail=f"{create_fine_tuning_details.shape_name}x{create_fine_tuning_details.replica}",
|
485
|
+
value=source.display_name,
|
473
486
|
)
|
474
487
|
|
475
488
|
return AquaFineTuningSummary(
|
@@ -550,7 +563,7 @@ class AquaFineTuningApp(AquaApp):
|
|
550
563
|
}
|
551
564
|
),
|
552
565
|
"OCI__LAUNCH_CMD": (
|
553
|
-
f"--micro_batch_size {batch_size} --num_epochs {parameters.epochs} --learning_rate {parameters.learning_rate} --training_data {dataset_path} --output_dir {report_path} --val_set_size {val_set_size} "
|
566
|
+
f"--micro_batch_size {batch_size} --num_epochs {parameters.epochs} --learning_rate {parameters.learning_rate} --training_data {dataset_path} --output_dir {report_path} --val_set_size {val_set_size} --sample_packing {parameters.sample_packing} "
|
554
567
|
+ (f"{finetuning_params}" if finetuning_params else "")
|
555
568
|
),
|
556
569
|
"CONDA_BUCKET_NS": CONDA_BUCKET_NS,
|
ads/aqua/model.py
CHANGED
@@ -14,7 +14,7 @@ import oci
|
|
14
14
|
from cachetools import TTLCache
|
15
15
|
from oci.data_science.models import JobRun, Model
|
16
16
|
|
17
|
-
from ads.aqua import logger, utils
|
17
|
+
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger, utils
|
18
18
|
from ads.aqua.base import AquaApp
|
19
19
|
from ads.aqua.constants import (
|
20
20
|
TRAINING_METRICS_FINAL,
|
@@ -26,7 +26,6 @@ from ads.aqua.constants import (
|
|
26
26
|
)
|
27
27
|
from ads.aqua.data import AquaResourceIdentifier, Tags
|
28
28
|
from ads.aqua.exception import AquaRuntimeError
|
29
|
-
|
30
29
|
from ads.aqua.training.exceptions import exit_code_dict
|
31
30
|
from ads.aqua.utils import (
|
32
31
|
LICENSE_TXT,
|
@@ -50,7 +49,6 @@ from ads.config import (
|
|
50
49
|
PROJECT_OCID,
|
51
50
|
TENANCY_OCID,
|
52
51
|
)
|
53
|
-
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
|
54
52
|
from ads.model import DataScienceModel
|
55
53
|
from ads.model.model_metadata import MetadataTaxonomyKeys, ModelCustomMetadata
|
56
54
|
from ads.telemetry import telemetry
|
@@ -228,7 +226,7 @@ class AquaFineTuneModel(AquaModel, AquaEvalFTCommon, DataClassSerializable):
|
|
228
226
|
).value
|
229
227
|
except Exception as e:
|
230
228
|
logger.debug(
|
231
|
-
f"Failed to extract model hyperparameters from {model.id}:" f"{str(e)}"
|
229
|
+
f"Failed to extract model hyperparameters from {model.id}: " f"{str(e)}"
|
232
230
|
)
|
233
231
|
model_hyperparameters = {}
|
234
232
|
|
ads/aqua/utils.py
CHANGED
@@ -10,7 +10,6 @@ import logging
|
|
10
10
|
import os
|
11
11
|
import random
|
12
12
|
import re
|
13
|
-
import sys
|
14
13
|
from enum import Enum
|
15
14
|
from functools import wraps
|
16
15
|
from pathlib import Path
|
@@ -22,22 +21,16 @@ import oci
|
|
22
21
|
from oci.data_science.models import JobRun, Model
|
23
22
|
|
24
23
|
from ads.aqua.constants import RqsAdditionalDetails
|
25
|
-
from ads.aqua.data import AquaResourceIdentifier
|
24
|
+
from ads.aqua.data import AquaResourceIdentifier
|
26
25
|
from ads.aqua.exception import AquaFileNotFoundError, AquaRuntimeError, AquaValueError
|
27
26
|
from ads.common.auth import default_signer
|
28
27
|
from ads.common.object_storage_details import ObjectStorageDetails
|
29
28
|
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
|
30
29
|
from ads.common.utils import get_console_link, upload_to_os
|
31
|
-
from ads.config import
|
32
|
-
AQUA_SERVICE_MODELS_BUCKET,
|
33
|
-
CONDA_BUCKET_NS,
|
34
|
-
TENANCY_OCID,
|
35
|
-
)
|
30
|
+
from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID
|
36
31
|
from ads.model import DataScienceModel, ModelVersionSet
|
37
32
|
|
38
|
-
|
39
|
-
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
40
|
-
logger = logging.getLogger("ODSC_AQUA")
|
33
|
+
logger = logging.getLogger("ads.aqua")
|
41
34
|
|
42
35
|
UNKNOWN = ""
|
43
36
|
UNKNOWN_DICT = {}
|
@@ -79,6 +72,9 @@ NB_SESSION_IDENTIFIER = "NB_SESSION_OCID"
|
|
79
72
|
LIFECYCLE_DETAILS_MISSING_JOBRUN = "The asscociated JobRun resource has been deleted."
|
80
73
|
READY_TO_DEPLOY_STATUS = "ACTIVE"
|
81
74
|
READY_TO_FINE_TUNE_STATUS = "TRUE"
|
75
|
+
AQUA_GA_LIST = ["id19sfcrra6z"]
|
76
|
+
AQUA_MODEL_TYPE_SERVICE = "service"
|
77
|
+
AQUA_MODEL_TYPE_CUSTOM = "custom"
|
82
78
|
|
83
79
|
|
84
80
|
class LifecycleStatus(Enum):
|
@@ -144,10 +140,6 @@ SUPPORTED_FILE_FORMATS = ["jsonl"]
|
|
144
140
|
MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location"
|
145
141
|
|
146
142
|
|
147
|
-
def get_logger():
|
148
|
-
return logger
|
149
|
-
|
150
|
-
|
151
143
|
def random_color_generator(word: str):
|
152
144
|
seed = sum([ord(c) for c in word]) % 13
|
153
145
|
random.seed(seed)
|
@@ -234,7 +226,7 @@ def read_file(file_path: str, **kwargs) -> str:
|
|
234
226
|
with fsspec.open(file_path, "r", **kwargs.get("auth", {})) as f:
|
235
227
|
return f.read()
|
236
228
|
except Exception as e:
|
237
|
-
logger.
|
229
|
+
logger.debug(f"Failed to read file {file_path}. {e}")
|
238
230
|
return UNKNOWN
|
239
231
|
|
240
232
|
|
@@ -484,7 +476,7 @@ def _build_resource_identifier(
|
|
484
476
|
),
|
485
477
|
)
|
486
478
|
except Exception as e:
|
487
|
-
logger.
|
479
|
+
logger.debug(
|
488
480
|
f"Failed to construct AquaResourceIdentifier from given id=`{id}`, and name=`{name}`, {str(e)}"
|
489
481
|
)
|
490
482
|
return AquaResourceIdentifier()
|
@@ -577,20 +569,27 @@ def get_container_image(
|
|
577
569
|
return container_image
|
578
570
|
|
579
571
|
|
580
|
-
def fetch_service_compartment():
|
581
|
-
"""Loads the compartment mapping json from service bucket
|
572
|
+
def fetch_service_compartment() -> Union[str, None]:
|
573
|
+
"""Loads the compartment mapping json from service bucket. This json file has a service-model-compartment key which
|
574
|
+
contains a dictionary of namespaces and the compartment OCID of the service models in that namespace.
|
575
|
+
"""
|
582
576
|
config_file_name = (
|
583
577
|
f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
|
584
578
|
)
|
585
579
|
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
580
|
+
try:
|
581
|
+
config = load_config(
|
582
|
+
file_path=config_file_name,
|
583
|
+
config_file_name=CONTAINER_INDEX,
|
584
|
+
)
|
585
|
+
except AquaFileNotFoundError:
|
586
|
+
logger.error(
|
587
|
+
f"Config file {config_file_name}/{CONTAINER_INDEX} to fetch service compartment OCID could not be found."
|
588
|
+
)
|
589
|
+
return
|
590
590
|
compartment_mapping = config.get(COMPARTMENT_MAPPING_KEY)
|
591
591
|
if compartment_mapping:
|
592
592
|
return compartment_mapping.get(CONDA_BUCKET_NS)
|
593
|
-
return None
|
594
593
|
|
595
594
|
|
596
595
|
def get_max_version(versions):
|
@@ -733,3 +732,20 @@ def _is_valid_mvs(mvs: ModelVersionSet, target_tag: str) -> bool:
|
|
733
732
|
return False
|
734
733
|
|
735
734
|
return target_tag in mvs.freeform_tags
|
735
|
+
|
736
|
+
|
737
|
+
def known_realm():
|
738
|
+
"""This helper function returns True if the Aqua service is available by default in the given namespace.
|
739
|
+
Returns
|
740
|
+
-------
|
741
|
+
bool:
|
742
|
+
Return True if aqua service is available.
|
743
|
+
|
744
|
+
"""
|
745
|
+
return os.environ.get("CONDA_BUCKET_NS") in AQUA_GA_LIST
|
746
|
+
|
747
|
+
|
748
|
+
def get_ocid_substring(ocid: str, key_len: int) -> str:
|
749
|
+
"""This helper function returns the last n characters of the ocid specified by key_len parameter.
|
750
|
+
If ocid is None or length is less than key_len, it returns an empty string."""
|
751
|
+
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
|
ads/cli.py
CHANGED
@@ -8,6 +8,7 @@ import traceback
|
|
8
8
|
import sys
|
9
9
|
|
10
10
|
import fire
|
11
|
+
from dataclasses import is_dataclass
|
11
12
|
from ads.common import logger
|
12
13
|
|
13
14
|
try:
|
@@ -70,11 +71,28 @@ def _SeparateFlagArgs(args):
|
|
70
71
|
fire.core.parser.SeparateFlagArgs = _SeparateFlagArgs
|
71
72
|
|
72
73
|
|
74
|
+
def serialize(data):
|
75
|
+
"""Serialize dataclass objects or lists of dataclass objects.
|
76
|
+
Parameters:
|
77
|
+
data: A dataclass object or a list of dataclass objects.
|
78
|
+
Returns:
|
79
|
+
None
|
80
|
+
Prints:
|
81
|
+
The string representation of each dataclass object.
|
82
|
+
"""
|
83
|
+
if isinstance(data, list):
|
84
|
+
[print(str(item)) for item in data]
|
85
|
+
else:
|
86
|
+
print(str(data))
|
87
|
+
|
88
|
+
|
73
89
|
def cli():
|
74
90
|
if len(sys.argv) > 1 and sys.argv[1] == "aqua":
|
75
91
|
from ads.aqua.cli import AquaCommand
|
76
92
|
|
77
|
-
fire.Fire(
|
93
|
+
fire.Fire(
|
94
|
+
AquaCommand, command=sys.argv[2:], name="ads aqua", serialize=serialize
|
95
|
+
)
|
78
96
|
else:
|
79
97
|
click_cli()
|
80
98
|
|
ads/common/serializer.py
CHANGED
@@ -195,7 +195,10 @@ class Serializable(ABC):
|
|
195
195
|
`None` in case when `uri` provided.
|
196
196
|
"""
|
197
197
|
json_string = json.dumps(
|
198
|
-
self.to_dict(**kwargs),
|
198
|
+
self.to_dict(**kwargs),
|
199
|
+
cls=encoder,
|
200
|
+
default=default or self.serialize,
|
201
|
+
indent=4,
|
199
202
|
)
|
200
203
|
if uri:
|
201
204
|
self._write_to_file(s=json_string, uri=uri, **kwargs)
|
@@ -463,9 +466,7 @@ class DataClassSerializable(Serializable):
|
|
463
466
|
"These fields will be ignored."
|
464
467
|
)
|
465
468
|
|
466
|
-
obj = cls(
|
467
|
-
**{key: obj_dict.get(key) for key in allowed_fields}
|
468
|
-
)
|
469
|
+
obj = cls(**{key: obj_dict.get(key) for key in allowed_fields})
|
469
470
|
|
470
471
|
for key, value in obj_dict.items():
|
471
472
|
if (
|
ads/common/utils.py
CHANGED
@@ -102,6 +102,8 @@ DIMENSION = 2
|
|
102
102
|
# The number of worker processes to use in parallel for uploading individual parts of a multipart upload.
|
103
103
|
DEFAULT_PARALLEL_PROCESS_COUNT = 9
|
104
104
|
|
105
|
+
LOG_LEVELS = ["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
106
|
+
|
105
107
|
|
106
108
|
class FileOverwriteError(Exception): # pragma: no cover
|
107
109
|
pass
|
@@ -1751,7 +1753,7 @@ def get_log_links(
|
|
1751
1753
|
) -> str:
|
1752
1754
|
"""
|
1753
1755
|
This method returns the web console link for the given log ids.
|
1754
|
-
|
1756
|
+
|
1755
1757
|
Parameters
|
1756
1758
|
----------
|
1757
1759
|
log_group_id: str, required
|
@@ -1776,7 +1778,9 @@ def get_log_links(
|
|
1776
1778
|
query_range = f'''search "{compartment_id}/{log_group_id}/{log_id}"'''
|
1777
1779
|
query_source = f"source='{source_id}'"
|
1778
1780
|
sort_condition = f"sort by datetime desc&regions={region}"
|
1779
|
-
search_query =
|
1781
|
+
search_query = (
|
1782
|
+
f"search?searchQuery={query_range} | {query_source} | {sort_condition}"
|
1783
|
+
)
|
1780
1784
|
console_link_url = f"https://cloud.oracle.com/logging/{search_query}"
|
1781
1785
|
elif log_group_id:
|
1782
1786
|
console_link_url = f"https://cloud.oracle.com/logging/log-groups/{log_group_id}?region={region}"
|
ads/config.py
CHANGED
@@ -79,6 +79,7 @@ AQUA_TELEMETRY_BUCKET = os.environ.get(
|
|
79
79
|
"AQUA_TELEMETRY_BUCKET", "service-managed-models"
|
80
80
|
)
|
81
81
|
AQUA_TELEMETRY_BUCKET_NS = os.environ.get("AQUA_TELEMETRY_BUCKET_NS", CONDA_BUCKET_NS)
|
82
|
+
|
82
83
|
DEBUG_TELEMETRY = os.environ.get("DEBUG_TELEMETRY", None)
|
83
84
|
AQUA_SERVICE_NAME = "aqua"
|
84
85
|
DATA_SCIENCE_SERVICE_NAME = "data-science"
|