oracle-ads 2.10.1__py3-none-any.whl → 2.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. ads/aqua/__init__.py +12 -0
  2. ads/aqua/base.py +324 -0
  3. ads/aqua/cli.py +19 -0
  4. ads/aqua/config/deployment_config_defaults.json +9 -0
  5. ads/aqua/config/resource_limit_names.json +7 -0
  6. ads/aqua/constants.py +45 -0
  7. ads/aqua/data.py +40 -0
  8. ads/aqua/decorator.py +101 -0
  9. ads/aqua/deployment.py +643 -0
  10. ads/aqua/dummy_data/icon.txt +1 -0
  11. ads/aqua/dummy_data/oci_model_deployments.json +56 -0
  12. ads/aqua/dummy_data/oci_models.json +1 -0
  13. ads/aqua/dummy_data/readme.md +26 -0
  14. ads/aqua/evaluation.py +1751 -0
  15. ads/aqua/exception.py +82 -0
  16. ads/aqua/extension/__init__.py +40 -0
  17. ads/aqua/extension/base_handler.py +138 -0
  18. ads/aqua/extension/common_handler.py +21 -0
  19. ads/aqua/extension/deployment_handler.py +202 -0
  20. ads/aqua/extension/evaluation_handler.py +135 -0
  21. ads/aqua/extension/finetune_handler.py +66 -0
  22. ads/aqua/extension/model_handler.py +59 -0
  23. ads/aqua/extension/ui_handler.py +201 -0
  24. ads/aqua/extension/utils.py +23 -0
  25. ads/aqua/finetune.py +579 -0
  26. ads/aqua/job.py +29 -0
  27. ads/aqua/model.py +819 -0
  28. ads/aqua/training/__init__.py +4 -0
  29. ads/aqua/training/exceptions.py +459 -0
  30. ads/aqua/ui.py +453 -0
  31. ads/aqua/utils.py +715 -0
  32. ads/cli.py +37 -6
  33. ads/common/decorator/__init__.py +7 -3
  34. ads/common/decorator/require_nonempty_arg.py +65 -0
  35. ads/common/object_storage_details.py +166 -7
  36. ads/common/oci_client.py +18 -1
  37. ads/common/oci_logging.py +2 -2
  38. ads/common/oci_mixin.py +4 -5
  39. ads/common/serializer.py +34 -5
  40. ads/common/utils.py +75 -10
  41. ads/config.py +40 -1
  42. ads/jobs/ads_job.py +43 -25
  43. ads/jobs/builders/infrastructure/base.py +4 -2
  44. ads/jobs/builders/infrastructure/dsc_job.py +49 -39
  45. ads/jobs/builders/runtimes/base.py +71 -1
  46. ads/jobs/builders/runtimes/container_runtime.py +4 -4
  47. ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
  48. ads/jobs/templates/driver_pytorch.py +27 -10
  49. ads/model/artifact_downloader.py +84 -14
  50. ads/model/artifact_uploader.py +25 -23
  51. ads/model/datascience_model.py +388 -38
  52. ads/model/deployment/model_deployment.py +10 -2
  53. ads/model/generic_model.py +8 -0
  54. ads/model/model_file_description_schema.json +68 -0
  55. ads/model/model_metadata.py +1 -1
  56. ads/model/service/oci_datascience_model.py +34 -5
  57. ads/opctl/operator/lowcode/anomaly/README.md +2 -1
  58. ads/opctl/operator/lowcode/anomaly/__main__.py +10 -4
  59. ads/opctl/operator/lowcode/anomaly/environment.yaml +2 -1
  60. ads/opctl/operator/lowcode/anomaly/model/automlx.py +12 -6
  61. ads/opctl/operator/lowcode/forecast/README.md +3 -2
  62. ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
  63. ads/opctl/operator/lowcode/forecast/model/automlx.py +12 -23
  64. ads/telemetry/base.py +62 -0
  65. ads/telemetry/client.py +105 -0
  66. ads/telemetry/telemetry.py +6 -3
  67. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.1.dist-info}/METADATA +37 -7
  68. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.1.dist-info}/RECORD +71 -36
  69. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.1.dist-info}/LICENSE.txt +0 -0
  70. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.1.dist-info}/WHEEL +0 -0
  71. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.1.dist-info}/entry_points.txt +0 -0
ads/aqua/exception.py ADDED
@@ -0,0 +1,82 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+
6
+ """
7
+ aqua.exception
8
+ ~~~~~~~~~~~~~~
9
+
10
+ This module contains the set of Aqua exceptions.
11
+ """
12
+
13
+
14
class AquaError(Exception):
    """AquaError

    The base exception from which all exceptions raised by Aqua
    will inherit.
    """

    def __init__(
        self,
        reason: str,
        status: int,
        service_payload: dict = None,
    ):
        """Initializes an AquaError.

        Parameters
        ----------
        reason: str
            User friendly error message.
        status: int
            Http status code that is going to be raised.
        service_payload: dict
            Payload to contain more details related to the error.
            Defaults to an empty dict when not provided.
        """
        # Forward the message to the base exception so that ``str(error)``,
        # ``error.args`` and logged tracebacks carry the reason; without this
        # call they would all be empty.
        super().__init__(reason)
        self.service_payload = service_payload or {}
        self.status = status
        self.reason = reason
41
+
42
+
43
class AquaValueError(AquaError, ValueError):
    """Raised when Aqua encounters an unexpected value.

    Defaults to HTTP status 403.
    """

    def __init__(self, reason, status=403, service_payload=None):
        super().__init__(
            reason=reason, status=status, service_payload=service_payload
        )
48
+
49
+
50
class AquaFileNotFoundError(AquaError, FileNotFoundError):
    """Raised when a target file is missing.

    Defaults to HTTP status 404.
    """

    def __init__(self, reason, status=404, service_payload=None):
        super().__init__(
            reason=reason, status=status, service_payload=service_payload
        )
55
+
56
+
57
class AquaRuntimeError(AquaError, RuntimeError):
    """Raised for generic failures at runtime.

    Defaults to HTTP status 400.
    """

    def __init__(self, reason, status=400, service_payload=None):
        super().__init__(
            reason=reason, status=status, service_payload=service_payload
        )
62
+
63
+
64
class AquaMissingKeyError(AquaError):
    """Raised when expected metadata is missing from a resource.

    Defaults to HTTP status 400.
    """

    def __init__(self, reason, status=400, service_payload=None):
        super().__init__(
            reason=reason, status=status, service_payload=service_payload
        )
69
+
70
+
71
class AquaFileExistsError(AquaError, FileExistsError):
    """Raised when a file already exists in the resource.

    Defaults to HTTP status 400.
    """

    def __init__(self, reason, status=400, service_payload=None):
        super().__init__(
            reason=reason, status=status, service_payload=service_payload
        )
76
+
77
+
78
class AquaResourceAccessError(AquaError):
    """Exception raised when a resource cannot be accessed.

    Defaults to HTTP status 404.
    """

    def __init__(self, reason, status=404, service_payload=None):
        super().__init__(reason, status, service_payload)
@@ -0,0 +1,40 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*--
3
+
4
+ # Copyright (c) 2024 Oracle and/or its affiliates.
5
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
+
7
+ from jupyter_server.utils import url_path_join
8
+
9
+ from ads.aqua.extension.common_handler import __handlers__ as __common_handlers__
10
+ from ads.aqua.extension.deployment_handler import (
11
+ __handlers__ as __deployment_handlers__,
12
+ )
13
+ from ads.aqua.extension.evaluation_handler import __handlers__ as __eval_handlers__
14
+ from ads.aqua.extension.finetune_handler import __handlers__ as __finetune_handlers__
15
+ from ads.aqua.extension.model_handler import __handlers__ as __model_handlers__
16
+ from ads.aqua.extension.ui_handler import __handlers__ as __ui_handlers__
17
+
18
# Combined route table of every Aqua handler module: a list of
# (url pattern, handler class) pairs. The original concatenation order is
# kept because route order may matter for matching.
__handlers__ = [
    *__finetune_handlers__,
    *__model_handlers__,
    *__common_handlers__,
    *__deployment_handlers__,
    *__ui_handlers__,
    *__eval_handlers__,
]
26
+
27
+
28
def load_jupyter_server_extension(nb_server_app):
    """Register the Aqua REST handlers with the Jupyter server.

    Every ``(url, handler)`` pair in ``__handlers__`` is mounted below
    ``<base_url>/aqua`` for any host (host pattern ``".*$"``).
    """
    web_app = nb_server_app.web_app
    aqua_root = url_path_join(web_app.settings["base_url"], "aqua")

    routes = []
    for url, handler in __handlers__:
        routes.append((url_path_join(aqua_root, url), handler))

    web_app.add_handlers(".*$", routes)
37
+
38
+
39
+ def _jupyter_server_extension_paths():
40
+ return [{"module": "ads.aqua.extension"}]
@@ -0,0 +1,138 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+
6
+
7
+ import json
8
+ import traceback
9
+ import uuid
10
+ from dataclasses import asdict, is_dataclass
11
+ from typing import Any
12
+
13
+ from notebook.base.handlers import APIHandler
14
+ from tornado.web import HTTPError, Application
15
+ from tornado import httputil
16
+ from ads.telemetry.client import TelemetryClient
17
+ from ads.config import AQUA_TELEMETRY_BUCKET, AQUA_TELEMETRY_BUCKET_NS
18
+ from ads.aqua import logger
19
+
20
+
21
class AquaAPIhandler(APIHandler):
    """Base handler for Aqua REST APIs.

    Adds JSON serialization of response payloads, JSON (instead of HTML)
    error responses, and best-effort telemetry of errors on top of the
    notebook ``APIHandler``.
    """

    def __init__(
        self,
        application: "Application",
        request: httputil.HTTPServerRequest,
        **kwargs: Any,
    ):
        super().__init__(application, request, **kwargs)

        # Telemetry is best-effort: failing to build the client must not stop
        # the handler from serving requests. On failure ``self.telemetry`` is
        # deliberately left unset; ``write_error`` checks it with ``hasattr``.
        # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt and
        # SystemExit still propagate.
        try:
            self.telemetry = TelemetryClient(
                bucket=AQUA_TELEMETRY_BUCKET, namespace=AQUA_TELEMETRY_BUCKET_NS
            )
        except Exception as ex:
            logger.debug("Unable to initialize the Aqua telemetry client: %s", ex)

    @staticmethod
    def serialize(obj: Any):
        """Serialize an object that ``json.dumps`` cannot encode natively.

        Used as the ``default`` hook of ``json.dumps`` in ``finish``.
        Objects exposing a callable ``to_dict`` are converted with it,
        dataclass instances with ``dataclasses.asdict``, and anything else
        with ``str``.
        """
        if hasattr(obj, "to_dict") and callable(obj.to_dict):
            return obj.to_dict()

        if is_dataclass(obj):
            return asdict(obj)

        return str(obj)

    def finish(self, payload=None):  # pylint: disable=W0221
        """End the HTTP request, optionally writing ``payload`` as the body.

        Tornado ``finish()`` only takes one argument; calling it with more
        than one argument causes an error.

        A list payload is wrapped as ``{"data": payload}``. The payload is then
        round-tripped through JSON (with ``serialize`` as the fallback
        encoder) so that only JSON-serializable data is handed to tornado.
        """
        if payload is None:
            return super().finish()
        # If the payload is a list, put it into a dictionary with key=data
        if isinstance(payload, list):
            payload = {"data": payload}
        # Convert the payload to a JSON serializable object
        payload = json.loads(json.dumps(payload, default=self.serialize))
        return super().finish(payload)

    def write_error(self, status_code, **kwargs):
        """Write the error response as JSON rather than a human-readable page.

        Parameters
        ----------
        status_code: int
            HTTP status code of the response.
        kwargs:
            ``reason`` and ``service_payload`` are copied into the reply.
            ``exc_info`` (an exception-info triple), when present, is logged.
            If it holds a tornado ``HTTPError``, that error's message and reason
            override the defaults, and a random ``request_id`` is added.
        """
        self.set_header("Content-Type", "application/json")
        reason = kwargs.get("reason")
        self.set_status(status_code, reason=reason)
        service_payload = kwargs.get("service_payload", {})
        message = self.get_default_error_messages(service_payload, str(status_code))
        reply = {
            "status": status_code,
            "message": message,
            "service_payload": service_payload,
            "reason": reason,
        }
        exc_info = kwargs.get("exc_info")
        if exc_info:
            logger.error("".join(traceback.format_exception(*exc_info)))
            e = exc_info[1]
            if isinstance(e, HTTPError):
                reply["message"] = e.log_message or message
                reply["reason"] = e.reason
            # Every error carrying exception info gets a request id.
            reply["request_id"] = str(uuid.uuid4())

        logger.warning(reply["message"])

        # telemetry may not be present if there is an error while initializing
        if hasattr(self, "telemetry"):
            self.telemetry.record_event_async(
                category="aqua/error",
                action=str(status_code),
                value=reason,
            )

        self.finish(json.dumps(reply))

    @staticmethod
    def get_default_error_messages(service_payload: dict, status_code: str):
        """Map an error to a user-facing message.

        A non-empty ``operation_name`` in ``service_payload`` takes precedence.
        Operations starting with ``create`` get the "create" message, and
        operations starting with ``list`` or ``get`` get the "get" message plus
        the operation name. Otherwise the message is looked up by the
        string ``status_code``, falling back to "Unknown HTTP Error.".
        """
        messages = {
            "400": "Something went wrong with your request.",
            "403": "We're having trouble processing your request with the information provided.",
            "404": "Authorization Failed: The resource you're looking for isn't accessible.",
            "408": "Server is taking too long to response, please try again.",
            "500": "An error occurred while creating the resource.",
            "create": "Authorization Failed: Could not create resource.",
            "get": "Authorization Failed: The resource you're looking for isn't accessible.",
        }

        if service_payload and "operation_name" in service_payload:
            operation_name = service_payload["operation_name"]
            if operation_name:
                if operation_name.startswith("create"):
                    return messages["create"]
                if operation_name.startswith(("list", "get")):
                    return messages["get"] + f" Operation Name: {operation_name}."

        return messages.get(status_code, "Unknown HTTP Error.")
132
+
133
+
134
# todo: remove after error handler is implemented
class Errors(str):
    """Message templates used when rejecting malformed handler requests.

    ``MISSING_REQUIRED_PARAMETER`` contains a ``{}`` placeholder for the name
    of the missing parameter and is meant to be filled with ``str.format``.
    """

    NO_INPUT_DATA = "No input data provided."
    INVALID_INPUT_DATA_FORMAT = "Invalid format of input data."
    MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'"
@@ -0,0 +1,21 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+
6
+
7
+ from importlib import metadata
8
+
9
+ from ads.aqua.extension.base_handler import AquaAPIhandler
10
+
11
+
12
class ADSVersionHandler(AquaAPIhandler):
    """Report the installed version of the ``oracle_ads`` distribution."""

    def get(self):
        """Respond with ``{"data": <installed oracle_ads version>}``."""
        installed_version = metadata.version("oracle_ads")
        self.finish({"data": installed_version})
17
+
18
+
19
# Route table for the common handlers: (url pattern, handler class) pairs.
__handlers__ = [("ads_version", ADSVersionHandler)]
@@ -0,0 +1,202 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+
6
+ from urllib.parse import urlparse
7
+
8
+ from tornado.web import HTTPError
9
+
10
+ from ads.aqua.deployment import AquaDeploymentApp, MDInferenceResponse, ModelParams
11
+ from ads.aqua.extension.base_handler import AquaAPIhandler, Errors
12
+ from ads.config import COMPARTMENT_OCID, PROJECT_OCID
13
+ from ads.aqua.decorator import handle_exceptions
14
+
15
+
16
class AquaDeploymentHandler(AquaAPIhandler):
    """
    Handler for Aqua Deployment REST APIs.

    Methods
    -------
    get(self, id="")
        Retrieves the list of AQUA deployments, a single deployment by ID, or
        the deployment config of a model, depending on the request path.
    post(self, *args, **kwargs)
        Creates a new AQUA deployment.
    read(self, id: str)
        Reads the AQUA deployment information.
    list(self)
        Lists all the AQUA deployments.
    get_deployment_config(self, model_id)
        Gets the deployment config for Aqua model.

    Raises
    ------
    HTTPError: For various failure scenarios such as invalid input format, missing data, etc.
    """

    @handle_exceptions
    def get(self, id=""):
        """Handle GET request.

        ``aqua/deployments/config/<model_id>`` returns the deployment config of
        a model; ``aqua/deployments`` lists deployments; and
        ``aqua/deployments/<id>`` reads a single deployment.
        """
        url_parse = urlparse(self.request.path)
        paths = url_parse.path.strip("/")
        if paths.startswith("aqua/deployments/config"):
            if not id:
                raise HTTPError(
                    400, f"The request {self.request.path} requires model id."
                )
            return self.get_deployment_config(id)
        elif paths.startswith("aqua/deployments"):
            if not id:
                return self.list()
            return self.read(id)
        else:
            raise HTTPError(400, f"The request {self.request.path} is invalid.")

    @handle_exceptions
    def post(self, *args, **kwargs):
        """
        Handles post request for the deployment APIs.

        The JSON body must contain ``display_name``, ``instance_shape`` and
        ``model_id``. ``compartment_id`` and ``project_id`` default to
        ``COMPARTMENT_OCID`` and ``PROJECT_OCID``. The remaining fields are
        optional and passed on as ``None`` when absent.

        Raises
        ------
        HTTPError
            Raises HTTPError if inputs are missing or are invalid
        """
        try:
            input_data = self.get_json_body()
        except Exception:
            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)

        if not input_data:
            raise HTTPError(400, Errors.NO_INPUT_DATA)

        # Required input parameters, checked in this order so the first
        # missing one is the one reported.
        for required_param in ("display_name", "instance_shape", "model_id"):
            if not input_data.get(required_param):
                raise HTTPError(
                    400, Errors.MISSING_REQUIRED_PARAMETER.format(required_param)
                )

        self.finish(
            AquaDeploymentApp().create(
                compartment_id=input_data.get("compartment_id", COMPARTMENT_OCID),
                project_id=input_data.get("project_id", PROJECT_OCID),
                model_id=input_data.get("model_id"),
                display_name=input_data.get("display_name"),
                description=input_data.get("description"),
                instance_count=input_data.get("instance_count"),
                instance_shape=input_data.get("instance_shape"),
                log_group_id=input_data.get("log_group_id"),
                access_log_id=input_data.get("access_log_id"),
                predict_log_id=input_data.get("predict_log_id"),
                bandwidth_mbps=input_data.get("bandwidth_mbps"),
            )
        )

    @handle_exceptions
    def read(self, id):
        """Read the information of an Aqua model deployment."""
        return self.finish(AquaDeploymentApp().get(model_deployment_id=id))

    @handle_exceptions
    def list(self):
        """List Aqua model deployments."""
        # If default is not specified,
        # jupyterlab will raise 400 error when argument is not provided by the HTTP request.
        compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
        # project_id is optional.
        project_id = self.get_argument("project_id", default=None)
        return self.finish(
            AquaDeploymentApp().list(
                compartment_id=compartment_id, project_id=project_id
            )
        )

    @handle_exceptions
    def get_deployment_config(self, model_id):
        """Gets the deployment config for Aqua model."""
        return self.finish(AquaDeploymentApp().get_deployment_config(model_id=model_id))
136
+
137
+
138
class AquaDeploymentInferenceHandler(AquaAPIhandler):
    """Handler that sends a prompt to an active model deployment endpoint."""

    @staticmethod
    def validate_predict_url(endpoint):
        """Check that ``endpoint`` looks like a model deployment predict URL.

        Returns True only for an ``https`` URL with a network location whose
        path ends with ``/predict``. Returns False for anything else,
        including values that fail to parse.
        """
        try:
            url = urlparse(endpoint)
            return (
                url.scheme == "https"
                and bool(url.netloc)
                and url.path.endswith("/predict")
            )
        except Exception:
            return False

    @handle_exceptions
    def post(self, *args, **kwargs):
        """
        Handles inference request for the Active Model Deployments.

        The JSON body must contain ``endpoint`` (a valid predict URL) and
        ``prompt``. ``model_params`` is optional and is used to build
        ``ModelParams``.

        Raises
        ------
        HTTPError
            Raises HTTPError if inputs are missing or are invalid
        """
        try:
            input_data = self.get_json_body()
        except Exception:
            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)

        if not input_data:
            raise HTTPError(400, Errors.NO_INPUT_DATA)

        endpoint = input_data.get("endpoint")
        if not endpoint:
            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("endpoint"))

        # INVALID_INPUT_DATA_FORMAT has no placeholder, so calling .format()
        # on it was a no-op; the message is the same generic text.
        if not self.validate_predict_url(endpoint):
            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)

        prompt = input_data.get("prompt")
        if not prompt:
            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt"))

        model_params = input_data.get("model_params") or {}
        # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
        # are not converted into a 400 response.
        try:
            model_params_obj = ModelParams(**model_params)
        except Exception:
            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)

        return self.finish(
            MDInferenceResponse(prompt, model_params_obj).get_model_deployment_response(
                endpoint
            )
        )
196
+
197
+
198
# Route table: (url pattern, handler class) pairs. The config route is listed
# before the generic one. Tornado uses the first matching pattern, and
# "deployments/?([^/]*)" would otherwise capture "deployments/config" with
# id="config", bypassing the "requires model id" check in
# AquaDeploymentHandler.get.
__handlers__ = [
    ("deployments/config/?([^/]*)", AquaDeploymentHandler),
    ("deployments/?([^/]*)", AquaDeploymentHandler),
    ("inference", AquaDeploymentInferenceHandler),
]
@@ -0,0 +1,135 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+
6
+ from urllib.parse import urlparse
7
+
8
+ from requests import HTTPError
9
+
10
+ from ads.aqua.decorator import handle_exceptions
11
+ from ads.aqua.evaluation import AquaEvaluationApp, CreateAquaEvaluationDetails
12
+ from ads.aqua.exception import AquaError
13
+ from ads.aqua.extension.base_handler import AquaAPIhandler, Errors
14
+ from ads.aqua.extension.utils import validate_function_parameters
15
+ from ads.config import COMPARTMENT_OCID
16
+
17
+
18
class AquaEvaluationHandler(AquaAPIhandler):
    """Handler for Aqua Model Evaluation REST APIs."""

    @handle_exceptions
    def get(self, eval_id=""):
        """Handle GET request.

        ``aqua/evaluation/metrics`` lists the supported metrics. Without an
        ``eval_id`` all evaluations are listed; otherwise the given
        evaluation is read.
        """
        url_parse = urlparse(self.request.path)
        paths = url_parse.path.strip("/")
        if paths.startswith("aqua/evaluation/metrics"):
            return self.get_default_metrics()
        if not eval_id:
            return self.list()
        return self.read(eval_id)

    @handle_exceptions
    def post(self, *args, **kwargs):
        """Handles post request for the evaluation APIs

        Raises
        ------
        HTTPError
            Raises HTTPError if inputs are missing or are invalid.
        """
        # This module imports ``HTTPError`` from ``requests``. That class carries
        # no HTTP status for the tornado error handling (``write_error`` checks
        # for tornado's ``HTTPError``), so raise tornado's class here, as the
        # other Aqua handlers do.
        from tornado.web import HTTPError as TornadoHTTPError

        try:
            input_data = self.get_json_body()
        except Exception:
            raise TornadoHTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)

        if not input_data:
            raise TornadoHTTPError(400, Errors.NO_INPUT_DATA)

        validate_function_parameters(
            data_class=CreateAquaEvaluationDetails, input_data=input_data
        )

        self.finish(
            # TODO: decide what other kwargs will be needed for create aqua evaluation.
            AquaEvaluationApp().create(
                create_aqua_evaluation_details=(
                    CreateAquaEvaluationDetails(**input_data)
                )
            )
        )

    @handle_exceptions
    def put(self, eval_id):
        """Handles PUT request for the evaluation APIs (cancels an evaluation).

        ``eval_id`` may include a trailing ``/cancel`` segment captured by the
        route; only the part before the first ``/`` is used.
        """
        eval_id = eval_id.split("/")[0]
        return self.finish(AquaEvaluationApp().cancel(eval_id))

    @handle_exceptions
    def delete(self, eval_id):
        """Handles DELETE request: deletes the given evaluation."""
        return self.finish(AquaEvaluationApp().delete(eval_id))

    def read(self, eval_id):
        """Read the information of an Aqua evaluation."""
        return self.finish(AquaEvaluationApp().get(eval_id))

    def list(self):
        """List Aqua evaluations."""
        compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
        # project_id is optional.
        project_id = self.get_argument("project_id", default=None)
        return self.finish(AquaEvaluationApp().list(compartment_id, project_id))

    def get_default_metrics(self):
        """Lists supported metrics for evaluation."""
        return self.finish(AquaEvaluationApp().get_supported_metrics())
86
+
87
+
88
class AquaEvaluationStatusHandler(AquaAPIhandler):
    """REST handler that reports the status of an Aqua evaluation."""

    @handle_exceptions
    def get(self, eval_id):
        """Respond with the evaluation's status.

        ``eval_id`` may carry a trailing ``/status`` segment captured by the
        route; only the text before the first ``/`` is used.
        """
        evaluation_id = eval_id.partition("/")[0]
        status = AquaEvaluationApp().get_status(evaluation_id)
        return self.finish(status)
96
+
97
+
98
class AquaEvaluationReportHandler(AquaAPIhandler):
    """REST handler that serves the report of an Aqua evaluation."""

    @handle_exceptions
    def get(self, eval_id):
        """Respond with the evaluation report.

        ``eval_id`` may carry a trailing ``/report`` segment captured by the
        route; only the text before the first ``/`` is used.
        """
        evaluation_id = eval_id.partition("/")[0]
        report = AquaEvaluationApp().download_report(evaluation_id)
        return self.finish(report)
106
+
107
+
108
class AquaEvaluationMetricsHandler(AquaAPIhandler):
    """REST handler that serves the metrics of an Aqua evaluation."""

    @handle_exceptions
    def get(self, eval_id):
        """Respond with the evaluation metrics.

        ``eval_id`` may carry a trailing ``/metrics`` segment captured by the
        route; only the text before the first ``/`` is used.
        """
        evaluation_id = eval_id.partition("/")[0]
        metrics = AquaEvaluationApp().load_metrics(evaluation_id)
        return self.finish(metrics)
116
+
117
+
118
class AquaEvaluationConfigHandler(AquaAPIhandler):
    """REST handler that serves the evaluation config of a model."""

    @handle_exceptions
    def get(self, model_id):
        """Respond with the evaluation config loaded for ``model_id``."""
        evaluation_config = AquaEvaluationApp().load_evaluation_config(model_id)
        return self.finish(evaluation_config)
126
+
127
+
128
# Route table: (url pattern, handler class) pairs, kept in the original
# order. The config route comes first so "evaluation/config/..." is not
# captured by the generic "evaluation/?([^/]*)" pattern.
__handlers__ = [
    ("evaluation/config/?([^/]*)", AquaEvaluationConfigHandler),
    ("evaluation/?([^/]*)", AquaEvaluationHandler),
    # Sub-resource routes capture "<eval_id>/<suffix>"; the handlers keep
    # only the part before the first "/".
    ("evaluation/?([^/]*/report)", AquaEvaluationReportHandler),
    ("evaluation/?([^/]*/metrics)", AquaEvaluationMetricsHandler),
    ("evaluation/?([^/]*/status)", AquaEvaluationStatusHandler),
    ("evaluation/?([^/]*/cancel)", AquaEvaluationHandler),
]
@@ -0,0 +1,66 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+
6
+
7
+ from urllib.parse import urlparse
8
+
9
+ from tornado.web import HTTPError
10
+
11
+ from ads.aqua.decorator import handle_exceptions
12
+ from ads.aqua.extension.base_handler import AquaAPIhandler, Errors
13
+ from ads.aqua.extension.utils import validate_function_parameters
14
+ from ads.aqua.finetune import AquaFineTuningApp, CreateFineTuningDetails
15
+
16
+
17
class AquaFineTuneHandler(AquaAPIhandler):
    """REST handler for Aqua fine-tuning jobs.

    GET serves the fine-tuning config of a model; POST creates a fine-tuning
    job from the JSON request body.
    """

    @handle_exceptions
    def get(self, id=""):
        """Serve GET requests; only the ``aqua/finetuning/config`` path is supported."""
        request_path = self.request.path
        normalized = urlparse(request_path).path.strip("/")
        if not normalized.startswith("aqua/finetuning/config"):
            raise HTTPError(400, f"The request {request_path} is invalid.")
        if not id:
            raise HTTPError(400, f"The request {request_path} requires model id.")
        return self.get_finetuning_config(id)

    @handle_exceptions
    def post(self, *args, **kwargs):
        """Create a fine-tuning job from the JSON request body.

        Raises
        ------
        HTTPError
            When the body is not valid JSON, is empty, or fails parameter
            validation.
        """
        try:
            payload = self.get_json_body()
        except Exception:
            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)

        if not payload:
            raise HTTPError(400, Errors.NO_INPUT_DATA)

        validate_function_parameters(
            data_class=CreateFineTuningDetails, input_data=payload
        )

        finetuning_app = AquaFineTuningApp()
        self.finish(finetuning_app.create(CreateFineTuningDetails(**payload)))

    @handle_exceptions
    def get_finetuning_config(self, model_id):
        """Respond with the fine-tuning config of the given Aqua model."""
        finetuning_config = AquaFineTuningApp().get_finetuning_config(
            model_id=model_id
        )
        return self.finish(finetuning_config)
61
+
62
+
63
# Route table: (url pattern, handler class) pairs. The config route is listed
# first. Tornado uses the first matching pattern, and "finetuning/?([^/]*)"
# would otherwise capture "finetuning/config" with id="config", bypassing
# the "requires model id" check in AquaFineTuneHandler.get.
__handlers__ = [
    ("finetuning/config/?([^/]*)", AquaFineTuneHandler),
    ("finetuning/?([^/]*)", AquaFineTuneHandler),
]