oracle-ads 2.13.3__py3-none-any.whl → 2.13.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/app.py +6 -0
- ads/aqua/client/openai_client.py +305 -0
- ads/aqua/common/entities.py +224 -2
- ads/aqua/common/enums.py +3 -0
- ads/aqua/common/utils.py +105 -3
- ads/aqua/config/container_config.py +9 -0
- ads/aqua/constants.py +29 -1
- ads/aqua/evaluation/entities.py +6 -1
- ads/aqua/evaluation/evaluation.py +191 -7
- ads/aqua/extension/aqua_ws_msg_handler.py +6 -36
- ads/aqua/extension/base_handler.py +13 -71
- ads/aqua/extension/deployment_handler.py +67 -76
- ads/aqua/extension/errors.py +19 -0
- ads/aqua/extension/utils.py +114 -2
- ads/aqua/finetuning/finetuning.py +50 -1
- ads/aqua/model/constants.py +3 -0
- ads/aqua/model/enums.py +5 -0
- ads/aqua/model/model.py +236 -24
- ads/aqua/modeldeployment/deployment.py +671 -152
- ads/aqua/modeldeployment/entities.py +551 -42
- ads/aqua/modeldeployment/inference.py +4 -5
- ads/aqua/modeldeployment/utils.py +525 -0
- ads/aqua/resources/gpu_shapes_index.json +94 -0
- ads/common/utils.py +1 -17
- ads/model/datascience_model.py +81 -21
- ads/model/service/oci_datascience_model.py +50 -42
- ads/opctl/operator/lowcode/forecast/model/factory.py +8 -1
- {oracle_ads-2.13.3.dist-info → oracle_ads-2.13.5.dist-info}/METADATA +8 -4
- {oracle_ads-2.13.3.dist-info → oracle_ads-2.13.5.dist-info}/RECORD +32 -29
- {oracle_ads-2.13.3.dist-info → oracle_ads-2.13.5.dist-info}/WHEEL +1 -1
- {oracle_ads-2.13.3.dist-info → oracle_ads-2.13.5.dist-info}/entry_points.txt +0 -0
- {oracle_ads-2.13.3.dist-info → oracle_ads-2.13.5.dist-info}/licenses/LICENSE.txt +0 -0
@@ -2,19 +2,16 @@
|
|
2
2
|
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
|
5
|
-
|
6
5
|
import json
|
7
|
-
import traceback
|
8
|
-
import uuid
|
9
6
|
from dataclasses import asdict, is_dataclass
|
10
|
-
from http.client import responses
|
11
7
|
from typing import Any
|
12
8
|
|
13
9
|
from notebook.base.handlers import APIHandler
|
14
10
|
from tornado import httputil
|
15
|
-
from tornado.web import Application
|
11
|
+
from tornado.web import Application
|
16
12
|
|
17
|
-
from ads.aqua import
|
13
|
+
from ads.aqua.common.utils import is_pydantic_model
|
14
|
+
from ads.aqua.extension.utils import construct_error
|
18
15
|
from ads.config import AQUA_TELEMETRY_BUCKET, AQUA_TELEMETRY_BUCKET_NS
|
19
16
|
from ads.telemetry.client import TelemetryClient
|
20
17
|
|
@@ -40,7 +37,7 @@ class AquaAPIhandler(APIHandler):
|
|
40
37
|
def prepare(self, *args, **kwargs):
|
41
38
|
"""The base class prepare is not required for Aqua"""
|
42
39
|
pass
|
43
|
-
|
40
|
+
|
44
41
|
@staticmethod
|
45
42
|
def serialize(obj: Any):
|
46
43
|
"""Serialize the object.
|
@@ -52,6 +49,9 @@ class AquaAPIhandler(APIHandler):
|
|
52
49
|
if is_dataclass(obj):
|
53
50
|
return asdict(obj)
|
54
51
|
|
52
|
+
if is_pydantic_model(obj):
|
53
|
+
return obj.model_dump()
|
54
|
+
|
55
55
|
return str(obj)
|
56
56
|
|
57
57
|
def finish(self, payload=None): # pylint: disable=W0221
|
@@ -71,37 +71,11 @@ class AquaAPIhandler(APIHandler):
|
|
71
71
|
|
72
72
|
def write_error(self, status_code, **kwargs):
|
73
73
|
"""AquaAPIhandler errors are JSON, not human pages."""
|
74
|
-
self.set_header("Content-Type", "application/json")
|
75
|
-
reason = kwargs.get("reason")
|
76
|
-
self.set_status(status_code, reason=reason)
|
77
|
-
service_payload = kwargs.get("service_payload", {})
|
78
|
-
default_msg = responses.get(status_code, "Unknown HTTP Error")
|
79
|
-
message = self.get_default_error_messages(
|
80
|
-
service_payload, str(status_code), kwargs.get("message", default_msg)
|
81
|
-
)
|
82
|
-
|
83
|
-
reply = {
|
84
|
-
"status": status_code,
|
85
|
-
"message": message,
|
86
|
-
"service_payload": service_payload,
|
87
|
-
"reason": reason,
|
88
|
-
"request_id": str(uuid.uuid4()),
|
89
|
-
}
|
90
|
-
exc_info = kwargs.get("exc_info")
|
91
|
-
if exc_info:
|
92
|
-
logger.error(
|
93
|
-
f"Error Request ID: {reply['request_id']}\n"
|
94
|
-
f"Error: {''.join(traceback.format_exception(*exc_info))}"
|
95
|
-
)
|
96
|
-
e = exc_info[1]
|
97
|
-
if isinstance(e, HTTPError):
|
98
|
-
reply["message"] = e.log_message or message
|
99
|
-
reply["reason"] = e.reason if e.reason else reply["reason"]
|
100
74
|
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
)
|
75
|
+
reply_details = construct_error(status_code, **kwargs)
|
76
|
+
|
77
|
+
self.set_header("Content-Type", "application/json")
|
78
|
+
self.set_status(status_code, reason=reply_details.reason)
|
105
79
|
|
106
80
|
# telemetry may not be present if there is an error while initializing
|
107
81
|
if hasattr(self, "telemetry"):
|
@@ -109,40 +83,8 @@ class AquaAPIhandler(APIHandler):
|
|
109
83
|
self.telemetry.record_event_async(
|
110
84
|
category="aqua/error",
|
111
85
|
action=str(status_code),
|
112
|
-
value=reason,
|
86
|
+
value=reply_details.reason,
|
113
87
|
**aqua_api_details,
|
114
88
|
)
|
115
89
|
|
116
|
-
self.finish(
|
117
|
-
|
118
|
-
@staticmethod
|
119
|
-
def get_default_error_messages(
|
120
|
-
service_payload: dict,
|
121
|
-
status_code: str,
|
122
|
-
default_msg: str = "Unknown HTTP Error.",
|
123
|
-
):
|
124
|
-
"""Method that maps the error messages based on the operation performed or the status codes encountered."""
|
125
|
-
|
126
|
-
messages = {
|
127
|
-
"400": "Something went wrong with your request.",
|
128
|
-
"403": "We're having trouble processing your request with the information provided.",
|
129
|
-
"404": "Authorization Failed: The resource you're looking for isn't accessible.",
|
130
|
-
"408": "Server is taking too long to response, please try again.",
|
131
|
-
"create": "Authorization Failed: Could not create resource.",
|
132
|
-
"get": "Authorization Failed: The resource you're looking for isn't accessible.",
|
133
|
-
}
|
134
|
-
|
135
|
-
if service_payload and "operation_name" in service_payload:
|
136
|
-
operation_name = service_payload["operation_name"]
|
137
|
-
if operation_name:
|
138
|
-
if operation_name.startswith("create"):
|
139
|
-
return messages["create"] + f" Operation Name: {operation_name}."
|
140
|
-
elif operation_name.startswith("list") or operation_name.startswith(
|
141
|
-
"get"
|
142
|
-
):
|
143
|
-
return messages["get"] + f" Operation Name: {operation_name}."
|
144
|
-
|
145
|
-
if status_code in messages:
|
146
|
-
return messages[status_code]
|
147
|
-
else:
|
148
|
-
return default_msg
|
90
|
+
self.finish(reply_details)
|
@@ -1,7 +1,8 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
|
5
|
+
from typing import List, Union
|
5
6
|
from urllib.parse import urlparse
|
6
7
|
|
7
8
|
from tornado.web import HTTPError
|
@@ -11,7 +12,7 @@ from ads.aqua.extension.base_handler import AquaAPIhandler
|
|
11
12
|
from ads.aqua.extension.errors import Errors
|
12
13
|
from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
|
13
14
|
from ads.aqua.modeldeployment.entities import ModelParams
|
14
|
-
from ads.config import COMPARTMENT_OCID
|
15
|
+
from ads.config import COMPARTMENT_OCID
|
15
16
|
|
16
17
|
|
17
18
|
class AquaDeploymentHandler(AquaAPIhandler):
|
@@ -20,7 +21,7 @@ class AquaDeploymentHandler(AquaAPIhandler):
|
|
20
21
|
|
21
22
|
Methods
|
22
23
|
-------
|
23
|
-
get(self, id
|
24
|
+
get(self, id: Union[str, List[str]])
|
24
25
|
Retrieves a list of AQUA deployments or model info or logs by ID.
|
25
26
|
post(self, *args, **kwargs)
|
26
27
|
Creates a new AQUA deployment.
|
@@ -30,6 +31,8 @@ class AquaDeploymentHandler(AquaAPIhandler):
|
|
30
31
|
Lists all the AQUA deployments.
|
31
32
|
get_deployment_config(self, model_id)
|
32
33
|
Gets the deployment config for Aqua model.
|
34
|
+
list_shapes(self)
|
35
|
+
Lists the valid model deployment shapes.
|
33
36
|
|
34
37
|
Raises
|
35
38
|
------
|
@@ -37,16 +40,23 @@ class AquaDeploymentHandler(AquaAPIhandler):
|
|
37
40
|
"""
|
38
41
|
|
39
42
|
@handle_exceptions
|
40
|
-
def get(self, id=
|
43
|
+
def get(self, id: Union[str, List[str]] = None):
|
41
44
|
"""Handle GET request."""
|
42
45
|
url_parse = urlparse(self.request.path)
|
43
46
|
paths = url_parse.path.strip("/")
|
44
47
|
if paths.startswith("aqua/deployments/config"):
|
45
|
-
if not id:
|
48
|
+
if not id or not isinstance(id, str):
|
46
49
|
raise HTTPError(
|
47
|
-
400,
|
50
|
+
400,
|
51
|
+
f"Invalid request format for {self.request.path}. "
|
52
|
+
"Expected a single model ID or a comma-separated list of model IDs.",
|
48
53
|
)
|
49
|
-
|
54
|
+
id = id.replace(" ", "")
|
55
|
+
return self.get_deployment_config(
|
56
|
+
model_id=id.split(",") if "," in id else id
|
57
|
+
)
|
58
|
+
elif paths.startswith("aqua/deployments/shapes"):
|
59
|
+
return self.list_shapes()
|
50
60
|
elif paths.startswith("aqua/deployments"):
|
51
61
|
if not id:
|
52
62
|
return self.list()
|
@@ -98,71 +108,7 @@ class AquaDeploymentHandler(AquaAPIhandler):
|
|
98
108
|
if not input_data:
|
99
109
|
raise HTTPError(400, Errors.NO_INPUT_DATA)
|
100
110
|
|
101
|
-
|
102
|
-
display_name = input_data.get("display_name")
|
103
|
-
if not display_name:
|
104
|
-
raise HTTPError(
|
105
|
-
400, Errors.MISSING_REQUIRED_PARAMETER.format("display_name")
|
106
|
-
)
|
107
|
-
instance_shape = input_data.get("instance_shape")
|
108
|
-
if not instance_shape:
|
109
|
-
raise HTTPError(
|
110
|
-
400, Errors.MISSING_REQUIRED_PARAMETER.format("instance_shape")
|
111
|
-
)
|
112
|
-
model_id = input_data.get("model_id")
|
113
|
-
if not model_id:
|
114
|
-
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model_id"))
|
115
|
-
|
116
|
-
compartment_id = input_data.get("compartment_id", COMPARTMENT_OCID)
|
117
|
-
project_id = input_data.get("project_id", PROJECT_OCID)
|
118
|
-
log_group_id = input_data.get("log_group_id")
|
119
|
-
access_log_id = input_data.get("access_log_id")
|
120
|
-
predict_log_id = input_data.get("predict_log_id")
|
121
|
-
description = input_data.get("description")
|
122
|
-
instance_count = input_data.get("instance_count")
|
123
|
-
bandwidth_mbps = input_data.get("bandwidth_mbps")
|
124
|
-
web_concurrency = input_data.get("web_concurrency")
|
125
|
-
server_port = input_data.get("server_port")
|
126
|
-
health_check_port = input_data.get("health_check_port")
|
127
|
-
env_var = input_data.get("env_var")
|
128
|
-
container_family = input_data.get("container_family")
|
129
|
-
ocpus = input_data.get("ocpus")
|
130
|
-
memory_in_gbs = input_data.get("memory_in_gbs")
|
131
|
-
model_file = input_data.get("model_file")
|
132
|
-
private_endpoint_id = input_data.get("private_endpoint_id")
|
133
|
-
container_image_uri = input_data.get("container_image_uri")
|
134
|
-
cmd_var = input_data.get("cmd_var")
|
135
|
-
freeform_tags = input_data.get("freeform_tags")
|
136
|
-
defined_tags = input_data.get("defined_tags")
|
137
|
-
|
138
|
-
self.finish(
|
139
|
-
AquaDeploymentApp().create(
|
140
|
-
compartment_id=compartment_id,
|
141
|
-
project_id=project_id,
|
142
|
-
model_id=model_id,
|
143
|
-
display_name=display_name,
|
144
|
-
description=description,
|
145
|
-
instance_count=instance_count,
|
146
|
-
instance_shape=instance_shape,
|
147
|
-
log_group_id=log_group_id,
|
148
|
-
access_log_id=access_log_id,
|
149
|
-
predict_log_id=predict_log_id,
|
150
|
-
bandwidth_mbps=bandwidth_mbps,
|
151
|
-
web_concurrency=web_concurrency,
|
152
|
-
server_port=server_port,
|
153
|
-
health_check_port=health_check_port,
|
154
|
-
env_var=env_var,
|
155
|
-
container_family=container_family,
|
156
|
-
ocpus=ocpus,
|
157
|
-
memory_in_gbs=memory_in_gbs,
|
158
|
-
model_file=model_file,
|
159
|
-
private_endpoint_id=private_endpoint_id,
|
160
|
-
container_image_uri=container_image_uri,
|
161
|
-
cmd_var=cmd_var,
|
162
|
-
freeform_tags=freeform_tags,
|
163
|
-
defined_tags=defined_tags,
|
164
|
-
)
|
165
|
-
)
|
111
|
+
self.finish(AquaDeploymentApp().create(**input_data))
|
166
112
|
|
167
113
|
def read(self, id):
|
168
114
|
"""Read the information of an Aqua model deployment."""
|
@@ -181,9 +127,52 @@ class AquaDeploymentHandler(AquaAPIhandler):
|
|
181
127
|
)
|
182
128
|
)
|
183
129
|
|
184
|
-
def get_deployment_config(self, model_id):
|
185
|
-
"""
|
186
|
-
|
130
|
+
def get_deployment_config(self, model_id: Union[str, List[str]]):
    """
    Send the deployment configuration for one or several Aqua models.

    Parameters
    ----------
    model_id : Union[str, List[str]]
        Either one model ID, or a list of model IDs. A list selects the
        multi-model configuration lookup.

    Returns
    -------
    None
        The configuration is written to the response via ``self.finish``.
    """
    deployment_app = AquaDeploymentApp()
    compartment = self.get_argument("compartment_id", default=COMPARTMENT_OCID)

    # Single model: nothing besides the model ID is forwarded.
    if not isinstance(model_id, list):
        return self.finish(deployment_app.get_deployment_config(model_id=model_id))

    # Several models: the optional primary model and the compartment are
    # forwarded together with the list of IDs.
    return self.finish(
        deployment_app.get_multimodel_deployment_config(
            model_ids=model_id,
            primary_model_id=self.get_argument("primary_model_id", default=None),
            compartment_id=compartment,
        )
    )
|
161
|
+
|
162
|
+
def list_shapes(self):
    """
    Send the list of valid model deployment shapes.

    The compartment is read from the ``compartment_id`` request argument and
    falls back to ``COMPARTMENT_OCID`` when the argument is absent.

    Returns
    -------
    List[ComputeShapeSummary]:
        The list of the model deployment shapes.
    """
    compartment = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
    shapes = AquaDeploymentApp().list_shapes(compartment_id=compartment)
    return self.finish(shapes)
|
187
176
|
|
188
177
|
|
189
178
|
class AquaDeploymentInferenceHandler(AquaAPIhandler):
|
@@ -259,9 +248,10 @@ class AquaDeploymentParamsHandler(AquaAPIhandler):
|
|
259
248
|
def get(self, model_id):
|
260
249
|
"""Handle GET request."""
|
261
250
|
instance_shape = self.get_argument("instance_shape")
|
251
|
+
gpu_count = self.get_argument("gpu_count", default=None)
|
262
252
|
return self.finish(
|
263
253
|
AquaDeploymentApp().get_deployment_default_params(
|
264
|
-
model_id=model_id, instance_shape=instance_shape
|
254
|
+
model_id=model_id, instance_shape=instance_shape, gpu_count=gpu_count
|
265
255
|
)
|
266
256
|
)
|
267
257
|
|
@@ -300,6 +290,7 @@ class AquaDeploymentParamsHandler(AquaAPIhandler):
|
|
300
290
|
__handlers__ = [
|
301
291
|
("deployments/?([^/]*)/params", AquaDeploymentParamsHandler),
|
302
292
|
("deployments/config/?([^/]*)", AquaDeploymentHandler),
|
293
|
+
("deployments/shapes/?([^/]*)", AquaDeploymentHandler),
|
303
294
|
("deployments/?([^/]*)", AquaDeploymentHandler),
|
304
295
|
("deployments/?([^/]*)/activate", AquaDeploymentHandler),
|
305
296
|
("deployments/?([^/]*)/deactivate", AquaDeploymentHandler),
|
ads/aqua/extension/errors.py
CHANGED
@@ -1,7 +1,16 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
2
|
# Copyright (c) 2024 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
|
+
import uuid
|
5
|
+
from typing import Any, Dict, List, Optional
|
4
6
|
|
7
|
+
from pydantic import Field
|
8
|
+
|
9
|
+
from ads.aqua.config.utils.serializer import Serializable
|
10
|
+
|
11
|
+
from ads.aqua.constants import (
|
12
|
+
AQUA_TROUBLESHOOTING_LINK
|
13
|
+
)
|
5
14
|
|
6
15
|
class Errors(str):
|
7
16
|
INVALID_INPUT_DATA_FORMAT = "Invalid format of input data."
|
@@ -9,3 +18,13 @@ class Errors(str):
|
|
9
18
|
MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'"
|
10
19
|
MISSING_ONEOF_REQUIRED_PARAMETER = "Either '{}' or '{}' is required."
|
11
20
|
INVALID_VALUE_OF_PARAMETER = "Invalid value of parameter: '{}'"
|
21
|
+
|
22
|
+
class ReplyDetails(Serializable):
    """Structured reply to be returned to the client when a request fails."""

    # HTTP status code of the failed request (required, no default).
    status: int
    troubleshooting_tips: str = Field(
        f"For general tips on troubleshooting: {AQUA_TROUBLESHOOTING_LINK}",
        description="GitHub Link for troubleshooting documentation",
    )
    message: str = Field(
        "Unknown HTTP Error.",
        description="Error message describing what went wrong",
    )
    service_payload: Optional[Dict[str, Any]] = Field(default_factory=dict)
    reason: str = Field("Unknown error", description="Reason for Error")
    # A factory is required here: a plain `str(uuid.uuid4())` default is
    # evaluated once when the class is defined, so every reply created
    # without an explicit request_id would carry the same ID.
    request_id: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Unique ID for tracking the error.",
    )
|
ads/aqua/extension/utils.py
CHANGED
@@ -1,16 +1,26 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
2
|
# Copyright (c) 2024 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
|
+
|
5
|
+
import re
|
6
|
+
import traceback
|
7
|
+
import uuid
|
4
8
|
from dataclasses import fields
|
5
9
|
from datetime import datetime, timedelta
|
10
|
+
from http.client import responses
|
6
11
|
from typing import Dict, Optional
|
7
12
|
|
8
13
|
from cachetools import TTLCache, cached
|
9
14
|
from tornado.web import HTTPError
|
10
15
|
|
11
|
-
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
|
16
|
+
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
|
12
17
|
from ads.aqua.common.utils import fetch_service_compartment
|
13
|
-
from ads.aqua.
|
18
|
+
from ads.aqua.constants import (
|
19
|
+
AQUA_TROUBLESHOOTING_LINK,
|
20
|
+
OCI_OPERATION_FAILURES,
|
21
|
+
STATUS_CODE_MESSAGES,
|
22
|
+
)
|
23
|
+
from ads.aqua.extension.errors import Errors, ReplyDetails
|
14
24
|
|
15
25
|
|
16
26
|
def validate_function_parameters(data_class, input_data: Dict):
|
@@ -32,3 +42,105 @@ def ui_compatability_check():
|
|
32
42
|
fetched from the configuration. The cached result is returned when multiple calls are made in quick succession
|
33
43
|
from the UI to avoid multiple config file loads."""
|
34
44
|
return ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment()
|
45
|
+
|
46
|
+
|
47
|
+
def get_default_error_messages(
    service_payload: dict,
    status_code: str,
    default_msg: str = "Unknown HTTP Error.",
) -> str:
    """Map a status code, and optionally the failed operation, to an error message.

    Parameters
    ----------
    service_payload : dict
        Error context; the ``operation_name`` and ``message`` keys are read
        when present. May be empty or ``None``.
    status_code : str
        Status code as a string, used as the key into ``STATUS_CODE_MESSAGES``.
    default_msg : str
        Message returned when ``status_code`` has no entry in
        ``STATUS_CODE_MESSAGES``.

    Returns
    -------
    str
        When the payload names an operation and the status code is known: the
        status message, the payload's own message (if any) and the operation
        name, one per line. Otherwise the status message, or ``default_msg``.
    """
    operation_name = service_payload.get("operation_name") if service_payload else None

    if operation_name and status_code in STATUS_CODE_MESSAGES:
        lines = [STATUS_CODE_MESSAGES[status_code]]
        service_message = service_payload.get("message")
        # Skip the payload message when it is absent so the literal text
        # "None" never reaches the user.
        if service_message:
            lines.append(str(service_message))
        lines.append(f"Operation Name: {operation_name}.")
        return "\n".join(lines)

    return STATUS_CODE_MESSAGES.get(status_code, default_msg)
|
61
|
+
|
62
|
+
|
63
|
+
def get_documentation_link(key: str) -> str:
    """Build the link to the AQUA troubleshooting section for the given key.

    Underscores in ``key`` are replaced with hyphens and the result is
    appended to ``AQUA_TROUBLESHOOTING_LINK`` as a URL fragment.
    """
    # A literal one-character substitution does not need a regex.
    github_header = key.replace("_", "-")
    return f"{AQUA_TROUBLESHOOTING_LINK}#{github_header}"
|
67
|
+
|
68
|
+
|
69
|
+
def get_troubleshooting_tips(service_payload: dict, status_code: int) -> str:
    """Maps authorization errors to potential solutions on the AQUA troubleshooting page.

    Parameters
    ----------
    service_payload : dict
        Error context; only ``operation_name`` is read. May be empty or ``None``.
    status_code : int
        HTTP status code. A numeric string such as ``"404"`` is accepted too.

    Returns
    -------
    str
        For status 400/404 with an operation listed in
        ``OCI_OPERATION_FAILURES``: that operation's tip followed by the link
        to its documentation section. Otherwise the general troubleshooting tip.
    """
    tip = f"For general tips on troubleshooting: {AQUA_TROUBLESHOOTING_LINK}"

    # Compare on the string form so int codes (what construct_error passes)
    # and string codes are treated alike; comparing a str against
    # (404, 400) would never match.
    if str(status_code) in ("400", "404"):
        failed_operation = (service_payload or {}).get("operation_name")

        if failed_operation in OCI_OPERATION_FAILURES:
            link = get_documentation_link(failed_operation)
            tip = OCI_OPERATION_FAILURES[failed_operation] + link

    return tip
|
83
|
+
|
84
|
+
|
85
|
+
def construct_error(status_code: int, **kwargs) -> ReplyDetails:
    """
    Formats an error response based on the provided status code and optional details.

    Args:
        status_code (int): The HTTP status code of the error.
        **kwargs: Additional optional parameters:
            - reason (str, optional): A brief reason for the error.
            - service_payload (dict, optional): Contextual error data from OCI SDK methods
            - message (str, optional): A custom error message, from error raised from failed AQUA methods calling OCI SDK methods
            - exc_info (tuple, optional): Exception information (e.g., from `sys.exc_info()`), used for logging.

    Returns:
        ReplyDetails: A Pydantic object with these fields:
            - "status" (int): The HTTP status code.
            - "troubleshooting_tips" (str): a GitHub link to AQUA troubleshooting docs, may be linked to a specific header.
            - "message" (str): error message.
            - "service_payload" (Dict[str, Any], optional) : Additional context from OCI Python SDK call.
            - "reason" (str): The reason for the error.
            - "request_id" (str): A unique identifier for tracking the error.

    Logs:
        - Logs the error details with a unique request ID.
        - If `exc_info` is provided, the full traceback is logged as well; if it
          contains an `HTTPError`, the response message and reason are updated
          from it.
    """
    # `or` instead of a .get() default: a key that is present with value None
    # must still fall back, otherwise None reaches the str-typed reply fields
    # and the helpers that call .get() on the payload.
    reason = kwargs.get("reason") or "Unknown Error"
    service_payload = kwargs.get("service_payload") or {}
    default_msg = responses.get(status_code, "Unknown HTTP Error")
    message = get_default_error_messages(
        service_payload, str(status_code), kwargs.get("message") or default_msg
    )

    tips = get_troubleshooting_tips(service_payload, status_code)

    reply = ReplyDetails(
        status=status_code,
        troubleshooting_tips=tips,
        message=message,
        service_payload=service_payload,
        reason=reason,
        request_id=str(uuid.uuid4()),
    )

    exc_info = kwargs.get("exc_info")
    if exc_info:
        logger.error(
            f"Error Request ID: {reply.request_id}\n"
            f"Error: {''.join(traceback.format_exception(*exc_info))}"
        )
        e = exc_info[1]
        # An HTTPError carries its own message and reason; prefer those.
        if isinstance(e, HTTPError):
            reply.message = e.log_message or message
            reply.reason = e.reason if e.reason else reply.reason

    # Summary line, emitted whether or not a traceback was logged above.
    logger.error(
        f"Error Request ID: {reply.request_id}\n"
        f"Error: {reply.message} {reply.reason}"
    )
    return reply
|
@@ -87,13 +87,62 @@ class AquaFineTuningApp(AquaApp):
|
|
87
87
|
def create(
|
88
88
|
self, create_fine_tuning_details: CreateFineTuningDetails = None, **kwargs
|
89
89
|
) -> "AquaFineTuningSummary":
|
90
|
-
"""Creates Aqua fine tuning for model
|
90
|
+
"""Creates Aqua fine tuning for model.\n
|
91
|
+
For detailed information about CLI flags see: https://github.com/oracle-samples/oci-data-science-ai-samples/blob/f271ca63d12e3c256718f23a14d93da4b4fc086b/ai-quick-actions/cli-tips.md#create-fine-tuned-model
|
91
92
|
|
92
93
|
Parameters
|
93
94
|
----------
|
94
95
|
create_fine_tuning_details: CreateFineTuningDetails
|
95
96
|
The CreateFineTuningDetails data class which contains all
|
96
97
|
required and optional fields to create the aqua fine tuning.
|
98
|
+
kwargs:
|
99
|
+
ft_source_id: str The fine tuning source id. Must be model OCID.
|
100
|
+
ft_name: str
|
101
|
+
The name for fine tuning.
|
102
|
+
dataset_path: str
|
103
|
+
The dataset path for fine tuning. Could be either a local path from notebook session
|
104
|
+
or an object storage path.
|
105
|
+
report_path: str
|
106
|
+
The report path for fine tuning. Must be an object storage path.
|
107
|
+
ft_parameters: dict
|
108
|
+
The parameters for fine tuning.
|
109
|
+
shape_name: str
|
110
|
+
The shape name for fine tuning job infrastructure.
|
111
|
+
replica: int
|
112
|
+
The replica for fine tuning job runtime.
|
113
|
+
validation_set_size: float
|
114
|
+
The validation set size for fine tuning job. Must be a float in between [0,1).
|
115
|
+
ft_description: (str, optional). Defaults to `None`.
|
116
|
+
The description for fine tuning.
|
117
|
+
compartment_id: (str, optional). Defaults to `None`.
|
118
|
+
The compartment id for fine tuning.
|
119
|
+
project_id: (str, optional). Defaults to `None`.
|
120
|
+
The project id for fine tuning.
|
121
|
+
experiment_id: (str, optional). Defaults to `None`.
|
122
|
+
The fine tuning model version set id. If provided,
|
123
|
+
fine tuning model will be associated with it.
|
124
|
+
experiment_name: (str, optional). Defaults to `None`.
|
125
|
+
The fine tuning model version set name. If provided,
|
126
|
+
the fine tuning version set with the same name will be used if exists,
|
127
|
+
otherwise a new model version set will be created with the name.
|
128
|
+
experiment_description: (str, optional). Defaults to `None`.
|
129
|
+
The description for fine tuning model version set.
|
130
|
+
block_storage_size: (int, optional). Defaults to 256.
|
131
|
+
The storage for fine tuning job infrastructure.
|
132
|
+
subnet_id: (str, optional). Defaults to `None`.
|
133
|
+
The custom egress for fine tuning job.
|
134
|
+
log_group_id: (str, optional). Defaults to `None`.
|
135
|
+
The log group id for fine tuning job infrastructure.
|
136
|
+
log_id: (str, optional). Defaults to `None`.
|
137
|
+
The log id for fine tuning job infrastructure.
|
138
|
+
watch_logs: (bool, optional). Defaults to `False`.
|
139
|
+
The flag to watch the job run logs when a fine-tuning job is created.
|
140
|
+
force_overwrite: (bool, optional). Defaults to `False`.
|
141
|
+
Whether to force overwrite the existing file in object storage.
|
142
|
+
freeform_tags: (dict, optional)
|
143
|
+
Freeform tags for the fine-tuning model
|
144
|
+
defined_tags: (dict, optional)
|
145
|
+
Defined tags for the fine-tuning model
|
97
146
|
kwargs:
|
98
147
|
The kwargs for creating CreateFineTuningDetails instance if
|
99
148
|
no create_fine_tuning_details provided.
|
ads/aqua/model/constants.py
CHANGED
@@ -18,6 +18,8 @@ class ModelCustomMetadataFields(ExtendedEnum):
|
|
18
18
|
EVALUATION_CONTAINER = "evaluation-container"
|
19
19
|
FINETUNE_CONTAINER = "finetune-container"
|
20
20
|
DEPLOYMENT_CONTAINER_URI = "deployment-container-uri"
|
21
|
+
MULTIMODEL_GROUP_COUNT = "model_group_count"
|
22
|
+
MULTIMODEL_METADATA = "multi_model_metadata"
|
21
23
|
|
22
24
|
|
23
25
|
class ModelTask(ExtendedEnum):
|
@@ -34,6 +36,7 @@ class FineTuningMetricCategories(ExtendedEnum):
|
|
34
36
|
class ModelType(ExtendedEnum):
|
35
37
|
FT = "FT" # Fine Tuned Model
|
36
38
|
BASE = "BASE" # Base model
|
39
|
+
MULTIMODEL = "MULTIMODEL"
|
37
40
|
|
38
41
|
|
39
42
|
# TODO: merge metadata key used in create FT
|
ads/aqua/model/enums.py
CHANGED
@@ -23,3 +23,8 @@ class FineTuningCustomMetadata(ExtendedEnum):
|
|
23
23
|
VALIDATION_METRICS_FINAL = "val_metrics_final"
|
24
24
|
TRAINING_METRICS_EPOCH = "train_metrics_epoch"
|
25
25
|
VALIDATION_METRICS_EPOCH = "val_metrics_epoch"
|
26
|
+
|
27
|
+
|
28
|
+
class MultiModelSupportedTaskType(ExtendedEnum):
    """Task-type strings enumerated for multi-model support.

    NOTE(review): presumably both spellings identify the same text-generation
    task; confirm against the code that consumes this enum.
    """

    # Hyphen-separated spelling.
    TEXT_GENERATION = "text-generation"
    # Underscore-separated spelling.
    TEXT_GENERATION_ALT = "text_generation"
|