oracle-ads 2.10.0__py3-none-any.whl → 2.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +12 -0
- ads/aqua/base.py +324 -0
- ads/aqua/cli.py +19 -0
- ads/aqua/config/deployment_config_defaults.json +9 -0
- ads/aqua/config/resource_limit_names.json +7 -0
- ads/aqua/constants.py +45 -0
- ads/aqua/data.py +40 -0
- ads/aqua/decorator.py +101 -0
- ads/aqua/deployment.py +643 -0
- ads/aqua/dummy_data/icon.txt +1 -0
- ads/aqua/dummy_data/oci_model_deployments.json +56 -0
- ads/aqua/dummy_data/oci_models.json +1 -0
- ads/aqua/dummy_data/readme.md +26 -0
- ads/aqua/evaluation.py +1751 -0
- ads/aqua/exception.py +82 -0
- ads/aqua/extension/__init__.py +40 -0
- ads/aqua/extension/base_handler.py +138 -0
- ads/aqua/extension/common_handler.py +21 -0
- ads/aqua/extension/deployment_handler.py +202 -0
- ads/aqua/extension/evaluation_handler.py +135 -0
- ads/aqua/extension/finetune_handler.py +66 -0
- ads/aqua/extension/model_handler.py +59 -0
- ads/aqua/extension/ui_handler.py +201 -0
- ads/aqua/extension/utils.py +23 -0
- ads/aqua/finetune.py +579 -0
- ads/aqua/job.py +29 -0
- ads/aqua/model.py +819 -0
- ads/aqua/training/__init__.py +4 -0
- ads/aqua/training/exceptions.py +459 -0
- ads/aqua/ui.py +453 -0
- ads/aqua/utils.py +715 -0
- ads/cli.py +37 -6
- ads/common/auth.py +7 -0
- ads/common/decorator/__init__.py +7 -3
- ads/common/decorator/require_nonempty_arg.py +65 -0
- ads/common/object_storage_details.py +166 -7
- ads/common/oci_client.py +18 -1
- ads/common/oci_logging.py +2 -2
- ads/common/oci_mixin.py +4 -5
- ads/common/serializer.py +34 -5
- ads/common/utils.py +75 -10
- ads/config.py +40 -1
- ads/dataset/correlation_plot.py +10 -12
- ads/jobs/ads_job.py +43 -25
- ads/jobs/builders/infrastructure/base.py +4 -2
- ads/jobs/builders/infrastructure/dsc_job.py +49 -39
- ads/jobs/builders/runtimes/base.py +71 -1
- ads/jobs/builders/runtimes/container_runtime.py +4 -4
- ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
- ads/jobs/templates/driver_pytorch.py +27 -10
- ads/model/artifact_downloader.py +84 -14
- ads/model/artifact_uploader.py +25 -23
- ads/model/datascience_model.py +388 -38
- ads/model/deployment/model_deployment.py +10 -2
- ads/model/generic_model.py +8 -0
- ads/model/model_file_description_schema.json +68 -0
- ads/model/model_metadata.py +1 -1
- ads/model/service/oci_datascience_model.py +34 -5
- ads/opctl/config/merger.py +2 -2
- ads/opctl/operator/__init__.py +3 -1
- ads/opctl/operator/cli.py +7 -1
- ads/opctl/operator/cmd.py +3 -3
- ads/opctl/operator/common/errors.py +2 -1
- ads/opctl/operator/common/operator_config.py +22 -3
- ads/opctl/operator/common/utils.py +16 -0
- ads/opctl/operator/lowcode/anomaly/MLoperator +15 -0
- ads/opctl/operator/lowcode/anomaly/README.md +209 -0
- ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/__main__.py +104 -0
- ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
- ads/opctl/operator/lowcode/anomaly/const.py +88 -0
- ads/opctl/operator/lowcode/anomaly/environment.yaml +12 -0
- ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +147 -0
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +89 -0
- ads/opctl/operator/lowcode/anomaly/model/autots.py +103 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +354 -0
- ads/opctl/operator/lowcode/anomaly/model/factory.py +67 -0
- ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
- ads/opctl/operator/lowcode/anomaly/operator_config.py +105 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +359 -0
- ads/opctl/operator/lowcode/anomaly/utils.py +81 -0
- ads/opctl/operator/lowcode/common/__init__.py +5 -0
- ads/opctl/operator/lowcode/common/const.py +10 -0
- ads/opctl/operator/lowcode/common/data.py +96 -0
- ads/opctl/operator/lowcode/common/errors.py +41 -0
- ads/opctl/operator/lowcode/common/transformations.py +191 -0
- ads/opctl/operator/lowcode/common/utils.py +250 -0
- ads/opctl/operator/lowcode/forecast/README.md +3 -2
- ads/opctl/operator/lowcode/forecast/__main__.py +18 -2
- ads/opctl/operator/lowcode/forecast/cmd.py +8 -7
- ads/opctl/operator/lowcode/forecast/const.py +17 -1
- ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
- ads/opctl/operator/lowcode/forecast/model/arima.py +106 -117
- ads/opctl/operator/lowcode/forecast/model/automlx.py +204 -180
- ads/opctl/operator/lowcode/forecast/model/autots.py +144 -253
- ads/opctl/operator/lowcode/forecast/model/base_model.py +326 -259
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +325 -176
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +293 -237
- ads/opctl/operator/lowcode/forecast/model/prophet.py +191 -208
- ads/opctl/operator/lowcode/forecast/operator_config.py +24 -33
- ads/opctl/operator/lowcode/forecast/schema.yaml +116 -29
- ads/opctl/operator/lowcode/forecast/utils.py +186 -356
- ads/opctl/operator/lowcode/pii/model/guardrails.py +18 -15
- ads/opctl/operator/lowcode/pii/model/report.py +7 -7
- ads/opctl/operator/lowcode/pii/operator_config.py +1 -8
- ads/opctl/operator/lowcode/pii/utils.py +0 -82
- ads/opctl/operator/runtime/runtime.py +3 -2
- ads/telemetry/base.py +62 -0
- ads/telemetry/client.py +105 -0
- ads/telemetry/telemetry.py +6 -3
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/METADATA +44 -7
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/RECORD +116 -59
- ads/opctl/operator/lowcode/forecast/model/transformations.py +0 -125
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/WHEEL +0 -0
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/entry_points.txt +0 -0
ads/aqua/exception.py
ADDED
@@ -0,0 +1,82 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
"""
|
7
|
+
aqua.exception
|
8
|
+
~~~~~~~~~~~~~~
|
9
|
+
|
10
|
+
This module contains the set of Aqua exceptions.
|
11
|
+
"""
|
12
|
+
|
13
|
+
|
14
|
+
class AquaError(Exception):
    """AquaError

    The base exception from which all exceptions raised by Aqua
    will inherit.
    """

    def __init__(
        self,
        reason: str,
        status: int,
        service_payload: dict = None,
    ):
        """Initializes an AquaError.

        Parameters
        ----------
        reason: str
            User friendly error message. It is also passed to the builtin base
            exception, so ``str(error)`` and ``error.args`` carry the message.
        status: int
            Http status code associated with the error.
        service_payload: dict, optional
            Payload to contain more details related to the error.
            Defaults to an empty dict.
        """
        # Forward the message to the builtin base class. Without this call,
        # subclasses that also derive from OSError (e.g. FileNotFoundError)
        # keep empty ``args`` and render as an empty string. Instances built
        # with keyword arguments also lose their message.
        super().__init__(reason)
        self.service_payload = service_payload or {}
        self.status = status
        self.reason = reason
|
41
|
+
|
42
|
+
|
43
|
+
class AquaValueError(AquaError, ValueError):
    """Raised when Aqua receives an unexpected or invalid value.

    The HTTP status defaults to 403.
    """

    def __init__(self, reason, status=403, service_payload=None):
        super().__init__(
            reason=reason,
            status=status,
            service_payload=service_payload,
        )
|
48
|
+
|
49
|
+
|
50
|
+
class AquaFileNotFoundError(AquaError, FileNotFoundError):
    """Raised when a target file is missing.

    The HTTP status defaults to 404.
    """

    def __init__(self, reason, status=404, service_payload=None):
        super().__init__(
            reason=reason,
            status=status,
            service_payload=service_payload,
        )
|
55
|
+
|
56
|
+
|
57
|
+
class AquaRuntimeError(AquaError, RuntimeError):
    """Raised for generic failures that happen at runtime.

    The HTTP status defaults to 400.
    """

    def __init__(self, reason, status=400, service_payload=None):
        super().__init__(
            reason=reason,
            status=status,
            service_payload=service_payload,
        )
|
62
|
+
|
63
|
+
|
64
|
+
class AquaMissingKeyError(AquaError):
    """Raised when required metadata is missing from a resource.

    The HTTP status defaults to 400.
    """

    def __init__(self, reason, status=400, service_payload=None):
        super().__init__(
            reason=reason,
            status=status,
            service_payload=service_payload,
        )
|
69
|
+
|
70
|
+
|
71
|
+
class AquaFileExistsError(AquaError, FileExistsError):
    """Raised when a file already exists in the target resource.

    The HTTP status defaults to 400.
    """

    def __init__(self, reason, status=400, service_payload=None):
        super().__init__(
            reason=reason,
            status=status,
            service_payload=service_payload,
        )
|
76
|
+
|
77
|
+
|
78
|
+
class AquaResourceAccessError(AquaError):
    """Raised when a resource cannot be accessed.

    The HTTP status defaults to 404.
    """

    def __init__(self, reason, status=404, service_payload=None):
        super().__init__(
            reason=reason,
            status=status,
            service_payload=service_payload,
        )
|
@@ -0,0 +1,40 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
|
4
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
5
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
|
+
|
7
|
+
from jupyter_server.utils import url_path_join
|
8
|
+
|
9
|
+
from ads.aqua.extension.common_handler import __handlers__ as __common_handlers__
|
10
|
+
from ads.aqua.extension.deployment_handler import (
|
11
|
+
__handlers__ as __deployment_handlers__,
|
12
|
+
)
|
13
|
+
from ads.aqua.extension.evaluation_handler import __handlers__ as __eval_handlers__
|
14
|
+
from ads.aqua.extension.finetune_handler import __handlers__ as __finetune_handlers__
|
15
|
+
from ads.aqua.extension.model_handler import __handlers__ as __model_handlers__
|
16
|
+
from ads.aqua.extension.ui_handler import __handlers__ as __ui_handlers__
|
17
|
+
|
18
|
+
# Full route table of the extension, built by concatenating the route lists
# exported by each handler module. The order of the sub-lists is kept as-is
# because Tornado tries URL rules in list order.
__handlers__ = [
    *__finetune_handlers__,
    *__model_handlers__,
    *__common_handlers__,
    *__deployment_handlers__,
    *__ui_handlers__,
    *__eval_handlers__,
]
|
26
|
+
|
27
|
+
|
28
|
+
def load_jupyter_server_extension(nb_server_app):
    """Register every Aqua REST handler on the Jupyter server's web app.

    Each relative URL pattern in ``__handlers__`` is mounted under
    ``<base_url>/aqua`` and served for any host.
    """
    app = nb_server_app.web_app
    any_host = ".*$"
    aqua_root = url_path_join(app.settings["base_url"], "aqua")

    routes = [
        (url_path_join(aqua_root, relative_url), handler_cls)
        for relative_url, handler_cls in __handlers__
    ]
    app.add_handlers(any_host, routes)
|
37
|
+
|
38
|
+
|
39
|
+
def _jupyter_server_extension_paths():
|
40
|
+
return [{"module": "ads.aqua.extension"}]
|
@@ -0,0 +1,138 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
|
7
|
+
import json
|
8
|
+
import traceback
|
9
|
+
import uuid
|
10
|
+
from dataclasses import asdict, is_dataclass
|
11
|
+
from typing import Any
|
12
|
+
|
13
|
+
from notebook.base.handlers import APIHandler
|
14
|
+
from tornado.web import HTTPError, Application
|
15
|
+
from tornado import httputil
|
16
|
+
from ads.telemetry.client import TelemetryClient
|
17
|
+
from ads.config import AQUA_TELEMETRY_BUCKET, AQUA_TELEMETRY_BUCKET_NS
|
18
|
+
from ads.aqua import logger
|
19
|
+
|
20
|
+
|
21
|
+
class AquaAPIhandler(APIHandler):
    """Base handler for Aqua REST APIs.

    On top of ``APIHandler`` it provides:
    - JSON serialization of ``to_dict``/dataclass payloads;
    - JSON error replies;
    - best-effort error telemetry.
    """

    def __init__(
        self,
        application: "Application",
        request: httputil.HTTPServerRequest,
        **kwargs: Any,
    ):
        super().__init__(application, request, **kwargs)

        # Telemetry is best-effort: failing to build the client must not break
        # the request. ``write_error`` checks for the attribute before using it.
        # ``Exception`` (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate.
        try:
            self.telemetry = TelemetryClient(
                bucket=AQUA_TELEMETRY_BUCKET, namespace=AQUA_TELEMETRY_BUCKET_NS
            )
        except Exception as ex:
            logger.debug("Failed to initialize the telemetry client: %s", ex)

    @staticmethod
    def serialize(obj: Any):
        """Serialize the object.

        Used as the ``default`` hook of ``json.dumps``. Conversion rules:
        - objects with a callable ``to_dict`` are converted with it;
        - dataclasses are converted with ``dataclasses.asdict``;
        - anything else is converted to a string.
        """
        if hasattr(obj, "to_dict") and callable(obj.to_dict):
            return obj.to_dict()

        if is_dataclass(obj):
            return asdict(obj)

        return str(obj)

    def finish(self, payload=None):  # pylint: disable=W0221
        """Ending the HTTP request by returning a payload and status code.

        Tornado finish() only takes one argument.
        Calling finish() with more than one arguments will cause error.

        A list payload is wrapped as ``{"data": payload}``. The payload is then
        round-tripped through ``json`` (using :meth:`serialize`) so that
        dataclasses and other objects become JSON-compatible.
        """
        if payload is None:
            return super().finish()
        # If the payload is a list, put into a dictionary with key=data
        if isinstance(payload, list):
            payload = {"data": payload}
        # Convert the payload to a JSON serializable object
        payload = json.loads(json.dumps(payload, default=self.serialize))
        return super().finish(payload)

    def write_error(self, status_code, **kwargs):
        """AquaAPIhandler errors are JSON, not human pages.

        Parameters
        ----------
        status_code: int
            HTTP status code of the reply.
        kwargs:
            Optional keys:
            - ``reason`` (str);
            - ``service_payload`` (dict);
            - ``exc_info``, a ``sys.exc_info()``-style tuple whose traceback
              is logged.
        """
        self.set_header("Content-Type", "application/json")
        reason = kwargs.get("reason")
        self.set_status(status_code, reason=reason)
        service_payload = kwargs.get("service_payload", {})
        message = self.get_default_error_messages(service_payload, str(status_code))
        reply = {
            "status": status_code,
            "message": message,
            "service_payload": service_payload,
            "reason": reason,
        }
        exc_info = kwargs.get("exc_info")
        if exc_info:
            logger.error("".join(traceback.format_exception(*exc_info)))
            e = exc_info[1]
            if isinstance(e, HTTPError):
                # A Tornado HTTPError carries its own message and reason.
                reply["message"] = e.log_message or message
                reply["reason"] = e.reason
        # Every error reply carries a request id, whether or not exception
        # info was supplied, so the response shape is always the same.
        reply["request_id"] = str(uuid.uuid4())

        logger.warning(reply["message"])

        # telemetry may not be present if there is an error while initializing
        if hasattr(self, "telemetry"):
            self.telemetry.record_event_async(
                category="aqua/error",
                action=str(status_code),
                value=reason,
            )

        self.finish(json.dumps(reply))

    @staticmethod
    def get_default_error_messages(service_payload: dict, status_code: str):
        """Method that maps the error messages based on the operation performed or the status codes encountered.

        Parameters
        ----------
        service_payload: dict
            May contain an ``operation_name``. Names starting with ``create``,
            ``list`` or ``get`` select an operation-specific message.
        status_code: str
            HTTP status code as a string. It is used when no operation-specific
            message applies.

        Returns
        -------
        str
            The user facing error message, or "Unknown HTTP Error." when
            nothing matches.
        """
        messages = {
            "400": "Something went wrong with your request.",
            "403": "We're having trouble processing your request with the information provided.",
            "404": "Authorization Failed: The resource you're looking for isn't accessible.",
            "408": "Server is taking too long to respond, please try again.",
            "500": "An error occurred while creating the resource.",
            "create": "Authorization Failed: Could not create resource.",
            "get": "Authorization Failed: The resource you're looking for isn't accessible.",
        }

        if service_payload and "operation_name" in service_payload:
            operation_name = service_payload["operation_name"]
            if operation_name:
                if operation_name.startswith("create"):
                    return messages["create"]
                if operation_name.startswith(("list", "get")):
                    return messages["get"] + f" Operation Name: {operation_name}."

        return messages.get(status_code, "Unknown HTTP Error.")
|
132
|
+
|
133
|
+
|
134
|
+
# todo: remove after error handler is implemented
|
135
|
+
class Errors(str):
    """Error message templates shared by the Aqua REST handlers.

    ``MISSING_REQUIRED_PARAMETER`` has a ``{}`` placeholder for the parameter
    name. The other messages contain no placeholder.
    """

    NO_INPUT_DATA = "No input data provided."
    INVALID_INPUT_DATA_FORMAT = "Invalid format of input data."
    MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'"
|
@@ -0,0 +1,21 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
|
7
|
+
from importlib import metadata
|
8
|
+
|
9
|
+
from ads.aqua.extension.base_handler import AquaAPIhandler
|
10
|
+
|
11
|
+
|
12
|
+
class ADSVersionHandler(AquaAPIhandler):
    """Report the installed version of the ``oracle_ads`` distribution."""

    def get(self):
        installed_version = metadata.version("oracle_ads")
        self.finish({"data": installed_version})
|
17
|
+
|
18
|
+
|
19
|
+
# Route table of this module: (relative URL pattern, handler class).
__handlers__ = [("ads_version", ADSVersionHandler)]
|
@@ -0,0 +1,202 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
from urllib.parse import urlparse
|
7
|
+
|
8
|
+
from tornado.web import HTTPError
|
9
|
+
|
10
|
+
from ads.aqua.deployment import AquaDeploymentApp, MDInferenceResponse, ModelParams
|
11
|
+
from ads.aqua.extension.base_handler import AquaAPIhandler, Errors
|
12
|
+
from ads.config import COMPARTMENT_OCID, PROJECT_OCID
|
13
|
+
from ads.aqua.decorator import handle_exceptions
|
14
|
+
|
15
|
+
|
16
|
+
class AquaDeploymentHandler(AquaAPIhandler):
    """
    Handler for Aqua Deployment REST APIs.

    Methods
    -------
    get(self, id="")
        Retrieves a list of AQUA deployments or model info or logs by ID.
    post(self, *args, **kwargs)
        Creates a new AQUA deployment.
    read(self, id: str)
        Reads the AQUA deployment information.
    list(self)
        Lists all the AQUA deployments.
    get_deployment_config(self, model_id)
        Gets the deployment config for Aqua model.

    Raises
    ------
    HTTPError: For various failure scenarios such as invalid input format, missing data, etc.
    """

    @handle_exceptions
    def get(self, id=""):
        """Handle GET request.

        Routing:
        - ``aqua/deployments/config/<model_id>`` → :meth:`get_deployment_config`;
        - ``aqua/deployments`` → :meth:`list`;
        - ``aqua/deployments/<deployment_id>`` → :meth:`read`.
        """
        url_parse = urlparse(self.request.path)
        paths = url_parse.path.strip("/")
        if paths.startswith("aqua/deployments/config"):
            # A bare "deployments/config" request is also matched by the
            # generic "deployments/?([^/]*)" route. That route captures the
            # literal segment "config" as ``id``, so treat it as a missing
            # model id instead of forwarding "config" to the service.
            if not id or paths == "aqua/deployments/config":
                raise HTTPError(
                    400, f"The request {self.request.path} requires model id."
                )
            return self.get_deployment_config(id)
        elif paths.startswith("aqua/deployments"):
            if not id:
                return self.list()
            return self.read(id)
        else:
            raise HTTPError(400, f"The request {self.request.path} is invalid.")

    @handle_exceptions
    def post(self, *args, **kwargs):
        """
        Handles post request for the deployment APIs

        Parameters in the JSON body:
        - required: ``display_name``, ``instance_shape`` and ``model_id``;
        - defaulted from configuration: ``compartment_id`` and ``project_id``;
        - optional: the logging, description, instance count and bandwidth
          fields.

        Raises
        ------
        HTTPError
            Raises HTTPError if inputs are missing or are invalid
        """
        try:
            input_data = self.get_json_body()
        except Exception:
            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)

        if not input_data:
            raise HTTPError(400, Errors.NO_INPUT_DATA)

        # required input parameters
        display_name = input_data.get("display_name")
        if not display_name:
            raise HTTPError(
                400, Errors.MISSING_REQUIRED_PARAMETER.format("display_name")
            )
        instance_shape = input_data.get("instance_shape")
        if not instance_shape:
            raise HTTPError(
                400, Errors.MISSING_REQUIRED_PARAMETER.format("instance_shape")
            )
        model_id = input_data.get("model_id")
        if not model_id:
            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model_id"))

        compartment_id = input_data.get("compartment_id", COMPARTMENT_OCID)
        project_id = input_data.get("project_id", PROJECT_OCID)
        log_group_id = input_data.get("log_group_id")
        access_log_id = input_data.get("access_log_id")
        predict_log_id = input_data.get("predict_log_id")
        description = input_data.get("description")
        instance_count = input_data.get("instance_count")
        bandwidth_mbps = input_data.get("bandwidth_mbps")

        self.finish(
            AquaDeploymentApp().create(
                compartment_id=compartment_id,
                project_id=project_id,
                model_id=model_id,
                display_name=display_name,
                description=description,
                instance_count=instance_count,
                instance_shape=instance_shape,
                log_group_id=log_group_id,
                access_log_id=access_log_id,
                predict_log_id=predict_log_id,
                bandwidth_mbps=bandwidth_mbps,
            )
        )

    @handle_exceptions
    def read(self, id):
        """Read the information of an Aqua model deployment."""
        return self.finish(AquaDeploymentApp().get(model_deployment_id=id))

    @handle_exceptions
    def list(self):
        """List Aqua model deployments."""
        # If default is not specified,
        # jupyterlab will raise 400 error when argument is not provided by the HTTP request.
        compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
        # project_id is optional.
        project_id = self.get_argument("project_id", default=None)
        return self.finish(
            AquaDeploymentApp().list(
                compartment_id=compartment_id, project_id=project_id
            )
        )

    @handle_exceptions
    def get_deployment_config(self, model_id):
        """Gets the deployment config for Aqua model."""
        return self.finish(AquaDeploymentApp().get_deployment_config(model_id=model_id))
|
136
|
+
|
137
|
+
|
138
|
+
class AquaDeploymentInferenceHandler(AquaAPIhandler):
    """Handler that sends a prompt to an active Aqua model deployment."""

    @staticmethod
    def validate_predict_url(endpoint):
        """Check that ``endpoint`` looks like a model deployment predict URL.

        Returns True only for an ``https`` URL that has a host and whose path
        ends with ``/predict``. Returns False otherwise, including when the
        value cannot be parsed.
        """
        try:
            url = urlparse(endpoint)
            if url.scheme != "https":
                return False
            if not url.netloc:
                return False
            if not url.path.endswith("/predict"):
                return False
            return True
        except Exception:
            return False

    @handle_exceptions
    def post(self, *args, **kwargs):
        """
        Handles inference request for the Active Model Deployments

        The JSON body needs:
        - ``endpoint``, a valid ``https://.../predict`` URL;
        - ``prompt``;
        - optionally ``model_params``, a mapping used to build ``ModelParams``.

        Raises
        ------
        HTTPError
            Raises HTTPError if inputs are missing or are invalid
        """
        try:
            input_data = self.get_json_body()
        except Exception:
            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)

        if not input_data:
            raise HTTPError(400, Errors.NO_INPUT_DATA)

        endpoint = input_data.get("endpoint")
        if not endpoint:
            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("endpoint"))

        # INVALID_INPUT_DATA_FORMAT has no "{}" placeholder, so it is used as-is.
        if not self.validate_predict_url(endpoint):
            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)

        prompt = input_data.get("prompt")
        if not prompt:
            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt"))

        model_params = input_data.get("model_params") or {}
        try:
            model_params_obj = ModelParams(**model_params)
        except Exception as ex:
            # Unknown keys or a non-mapping value are client errors. Chain the
            # cause so the logged traceback shows what was wrong.
            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex

        return self.finish(
            MDInferenceResponse(prompt, model_params_obj).get_model_deployment_response(
                endpoint
            )
        )
|
196
|
+
|
197
|
+
|
198
|
+
# Route table of this module: (relative URL pattern, handler class).
__handlers__ = [
    # "deployments" and "deployments/<deployment_id>"
    ("deployments/?([^/]*)", AquaDeploymentHandler),
    # "deployments/config/<model_id>"
    ("deployments/config/?([^/]*)", AquaDeploymentHandler),
    # "inference"
    ("inference", AquaDeploymentInferenceHandler),
]
|
@@ -0,0 +1,135 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
from urllib.parse import urlparse
|
7
|
+
|
8
|
+
from requests import HTTPError
|
9
|
+
|
10
|
+
from ads.aqua.decorator import handle_exceptions
|
11
|
+
from ads.aqua.evaluation import AquaEvaluationApp, CreateAquaEvaluationDetails
|
12
|
+
from ads.aqua.exception import AquaError
|
13
|
+
from ads.aqua.extension.base_handler import AquaAPIhandler, Errors
|
14
|
+
from ads.aqua.extension.utils import validate_function_parameters
|
15
|
+
from ads.config import COMPARTMENT_OCID
|
16
|
+
|
17
|
+
|
18
|
+
class AquaEvaluationHandler(AquaAPIhandler):
    """Handler for Aqua Model Evaluation REST APIs."""

    @handle_exceptions
    def get(self, eval_id=""):
        """Handle GET request.

        Routing:
        - ``aqua/evaluation/metrics`` → the supported metrics;
        - no id → the list of evaluations;
        - otherwise → the evaluation identified by ``eval_id``.
        """
        url_parse = urlparse(self.request.path)
        paths = url_parse.path.strip("/")
        if paths.startswith("aqua/evaluation/metrics"):
            return self.get_default_metrics()
        if not eval_id:
            return self.list()
        return self.read(eval_id)

    @handle_exceptions
    def post(self, *args, **kwargs):
        """Handles post request for the evaluation APIs

        Raises
        ------
        HTTPError
            Raises HTTPError if inputs are missing or are invalid.
        """
        # This module's top-level ``HTTPError`` comes from ``requests``. The
        # Aqua handlers raise Tornado's HTTPError, which is also the class
        # ``AquaAPIhandler.write_error`` inspects, so use that one here.
        from tornado.web import HTTPError as TornadoHTTPError

        try:
            input_data = self.get_json_body()
        except Exception:
            raise TornadoHTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)

        if not input_data:
            raise TornadoHTTPError(400, Errors.NO_INPUT_DATA)

        validate_function_parameters(
            data_class=CreateAquaEvaluationDetails, input_data=input_data
        )

        self.finish(
            # TODO: decide what other kwargs will be needed for create aqua evaluation.
            AquaEvaluationApp().create(
                create_aqua_evaluation_details=(
                    CreateAquaEvaluationDetails(**input_data)
                )
            )
        )

    @handle_exceptions
    def put(self, eval_id):
        """Handles PUT request for the evaluation APIs (cancels an evaluation).

        ``eval_id`` may be the captured ``"<id>/cancel"`` segment. Only the part
        before the first ``/`` is used.
        """
        eval_id = eval_id.split("/")[0]
        return self.finish(AquaEvaluationApp().cancel(eval_id))

    @handle_exceptions
    def delete(self, eval_id):
        """Handles DELETE request: deletes the evaluation ``eval_id``."""
        return self.finish(AquaEvaluationApp().delete(eval_id))

    def read(self, eval_id):
        """Read the information of an Aqua evaluation."""
        return self.finish(AquaEvaluationApp().get(eval_id))

    def list(self):
        """List Aqua evaluations."""
        compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
        # project_id is optional.
        project_id = self.get_argument("project_id", default=None)
        return self.finish(AquaEvaluationApp().list(compartment_id, project_id))

    def get_default_metrics(self):
        """Lists supported metrics for evaluation."""
        return self.finish(AquaEvaluationApp().get_supported_metrics())
|
86
|
+
|
87
|
+
|
88
|
+
class AquaEvaluationStatusHandler(AquaAPIhandler):
    """REST handler that reports the status of an Aqua evaluation."""

    @handle_exceptions
    def get(self, eval_id):
        """Handle GET request.

        ``eval_id`` is the captured path segment (e.g. ``"<id>/status"``).
        Only the part before the first ``/`` identifies the evaluation.
        """
        evaluation_id, _, _ = eval_id.partition("/")
        status = AquaEvaluationApp().get_status(evaluation_id)
        return self.finish(status)
|
96
|
+
|
97
|
+
|
98
|
+
class AquaEvaluationReportHandler(AquaAPIhandler):
    """REST handler that downloads the report of an Aqua evaluation."""

    @handle_exceptions
    def get(self, eval_id):
        """Handle GET request.

        ``eval_id`` is the captured path segment (e.g. ``"<id>/report"``).
        Only the part before the first ``/`` identifies the evaluation.
        """
        evaluation_id, _, _ = eval_id.partition("/")
        report = AquaEvaluationApp().download_report(evaluation_id)
        return self.finish(report)
|
106
|
+
|
107
|
+
|
108
|
+
class AquaEvaluationMetricsHandler(AquaAPIhandler):
    """REST handler that loads the metrics of an Aqua evaluation."""

    @handle_exceptions
    def get(self, eval_id):
        """Handle GET request.

        ``eval_id`` is the captured path segment (e.g. ``"<id>/metrics"``).
        Only the part before the first ``/`` identifies the evaluation.
        """
        evaluation_id, _, _ = eval_id.partition("/")
        metrics = AquaEvaluationApp().load_metrics(evaluation_id)
        return self.finish(metrics)
|
116
|
+
|
117
|
+
|
118
|
+
class AquaEvaluationConfigHandler(AquaAPIhandler):
    """REST handler that returns the evaluation config of a model."""

    @handle_exceptions
    def get(self, model_id):
        """Handle GET request: load the evaluation config for ``model_id``."""
        evaluation_config = AquaEvaluationApp().load_evaluation_config(model_id)
        return self.finish(evaluation_config)
|
126
|
+
|
127
|
+
|
128
|
+
# Route table of this module: (relative URL pattern, handler class).
__handlers__ = [
    # "evaluation/config/<model_id>"
    ("evaluation/config/?([^/]*)", AquaEvaluationConfigHandler),
    # "evaluation", "evaluation/<eval_id>" and "evaluation/metrics"
    ("evaluation/?([^/]*)", AquaEvaluationHandler),
    # "evaluation/<eval_id>/report"
    ("evaluation/?([^/]*/report)", AquaEvaluationReportHandler),
    # "evaluation/<eval_id>/metrics"
    ("evaluation/?([^/]*/metrics)", AquaEvaluationMetricsHandler),
    # "evaluation/<eval_id>/status"
    ("evaluation/?([^/]*/status)", AquaEvaluationStatusHandler),
    # "evaluation/<eval_id>/cancel" (handled by AquaEvaluationHandler.put)
    ("evaluation/?([^/]*/cancel)", AquaEvaluationHandler),
]
|
@@ -0,0 +1,66 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
|
7
|
+
from urllib.parse import urlparse
|
8
|
+
|
9
|
+
from tornado.web import HTTPError
|
10
|
+
|
11
|
+
from ads.aqua.decorator import handle_exceptions
|
12
|
+
from ads.aqua.extension.base_handler import AquaAPIhandler, Errors
|
13
|
+
from ads.aqua.extension.utils import validate_function_parameters
|
14
|
+
from ads.aqua.finetune import AquaFineTuningApp, CreateFineTuningDetails
|
15
|
+
|
16
|
+
|
17
|
+
class AquaFineTuneHandler(AquaAPIhandler):
    """Handler for Aqua fine-tuning job REST APIs."""

    @handle_exceptions
    def get(self, id=""):
        """Handle GET request.

        Only ``aqua/finetuning/config/<model_id>`` is supported; it returns the
        fine-tuning config of that model. Any other path is rejected with 400.
        """
        url_parse = urlparse(self.request.path)
        paths = url_parse.path.strip("/")
        if paths.startswith("aqua/finetuning/config"):
            # A bare "finetuning/config" request is also matched by the generic
            # "finetuning/?([^/]*)" route. That route captures the literal
            # segment "config" as ``id``, so treat it as a missing model id.
            if not id or paths == "aqua/finetuning/config":
                raise HTTPError(
                    400, f"The request {self.request.path} requires model id."
                )
            return self.get_finetuning_config(id)
        else:
            raise HTTPError(400, f"The request {self.request.path} is invalid.")

    @handle_exceptions
    def post(self, *args, **kwargs):
        """Handles post request for the fine-tuning API

        The JSON body is validated against ``CreateFineTuningDetails`` and
        then used to create the fine-tuning job.

        Raises
        ------
        HTTPError
            Raises HTTPError if inputs are missing or are invalid.
        """
        try:
            input_data = self.get_json_body()
        except Exception:
            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)

        if not input_data:
            raise HTTPError(400, Errors.NO_INPUT_DATA)

        validate_function_parameters(
            data_class=CreateFineTuningDetails, input_data=input_data
        )

        self.finish(AquaFineTuningApp().create(CreateFineTuningDetails(**input_data)))

    @handle_exceptions
    def get_finetuning_config(self, model_id):
        """Gets the finetuning config for Aqua model."""
        return self.finish(AquaFineTuningApp().get_finetuning_config(model_id=model_id))
|
61
|
+
|
62
|
+
|
63
|
+
# Route table of this module: (relative URL pattern, handler class).
__handlers__ = [
    # "finetuning" and "finetuning/<id>"
    ("finetuning/?([^/]*)", AquaFineTuneHandler),
    # "finetuning/config/<model_id>"
    ("finetuning/config/?([^/]*)", AquaFineTuneHandler),
]
|