oracle-ads 2.11.15__py3-none-any.whl → 2.11.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/app.py +5 -6
- ads/aqua/common/entities.py +17 -0
- ads/aqua/common/enums.py +14 -1
- ads/aqua/common/utils.py +160 -3
- ads/aqua/config/config.py +1 -1
- ads/aqua/config/deployment_config_defaults.json +29 -1
- ads/aqua/config/resource_limit_names.json +1 -0
- ads/aqua/constants.py +6 -1
- ads/aqua/evaluation/entities.py +0 -1
- ads/aqua/evaluation/evaluation.py +47 -14
- ads/aqua/extension/common_handler.py +75 -5
- ads/aqua/extension/common_ws_msg_handler.py +57 -0
- ads/aqua/extension/deployment_handler.py +16 -13
- ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
- ads/aqua/extension/errors.py +1 -1
- ads/aqua/extension/evaluation_ws_msg_handler.py +28 -6
- ads/aqua/extension/model_handler.py +134 -8
- ads/aqua/extension/models/ws_models.py +78 -3
- ads/aqua/extension/models_ws_msg_handler.py +49 -0
- ads/aqua/extension/ui_websocket_handler.py +7 -1
- ads/aqua/model/entities.py +28 -0
- ads/aqua/model/model.py +544 -129
- ads/aqua/modeldeployment/deployment.py +102 -43
- ads/aqua/modeldeployment/entities.py +9 -20
- ads/aqua/ui.py +152 -28
- ads/common/object_storage_details.py +2 -5
- ads/common/serializer.py +2 -3
- ads/jobs/builders/infrastructure/dsc_job.py +41 -12
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +74 -27
- ads/jobs/builders/runtimes/container_runtime.py +83 -4
- ads/opctl/operator/lowcode/anomaly/const.py +1 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +23 -7
- ads/opctl/operator/lowcode/anomaly/operator_config.py +1 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +4 -0
- ads/opctl/operator/lowcode/common/errors.py +6 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +3 -1
- ads/opctl/operator/lowcode/forecast/model/base_model.py +21 -13
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +11 -2
- ads/pipeline/ads_pipeline_run.py +13 -2
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/METADATA +2 -1
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/RECORD +44 -40
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/WHEEL +0 -0
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,57 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
import json
|
7
|
+
from importlib import metadata
|
8
|
+
from typing import List, Union
|
9
|
+
|
10
|
+
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, fetch_service_compartment
|
11
|
+
from ads.aqua.common.decorator import handle_exceptions
|
12
|
+
from ads.aqua.common.errors import AquaResourceAccessError
|
13
|
+
from ads.aqua.common.utils import known_realm
|
14
|
+
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
|
15
|
+
from ads.aqua.extension.models.ws_models import (
|
16
|
+
AdsVersionResponse,
|
17
|
+
CompatibilityCheckResponse,
|
18
|
+
RequestResponseType,
|
19
|
+
)
|
20
|
+
|
21
|
+
|
22
|
+
class AquaCommonWsMsgHandler(AquaWSMsgHandler):
    """Websocket message handler for general-purpose AQUA requests.

    Serves ``AdsVersion`` messages (reports the installed ``oracle_ads``
    distribution version) and ``CompatibilityCheck`` messages (reports
    whether the extension can be used in the current region).
    """

    @staticmethod
    def get_message_types() -> List[RequestResponseType]:
        """Return the request kinds this handler is responsible for."""
        return [RequestResponseType.AdsVersion, RequestResponseType.CompatibilityCheck]

    def __init__(self, message: Union[str, bytes]):
        super().__init__(message)

    @handle_exceptions
    def process(self) -> Union[AdsVersionResponse, CompatibilityCheckResponse]:
        """Decode the websocket message and build the matching response.

        Returns
        -------
        AdsVersionResponse | CompatibilityCheckResponse | None
            The response for a supported ``kind``; ``None`` for any other kind.

        Raises
        ------
        AquaResourceAccessError
            For ``CompatibilityCheck`` when no service compartment is available
            and the realm is not a known one.
        """
        payload = json.loads(self.message)
        kind = payload.get("kind")
        msg_id = payload.get("message_id")

        if kind == "AdsVersion":
            return AdsVersionResponse(
                message_id=msg_id,
                kind=RequestResponseType.AdsVersion,
                data=metadata.version("oracle_ads"),
            )

        if kind != "CompatibilityCheck":
            return None

        # "ok" when a service compartment is configured or discoverable;
        # "compatible" when only the realm is recognised.
        if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
            status = "ok"
        elif known_realm():
            status = "compatible"
        else:
            raise AquaResourceAccessError(
                "The AI Quick actions extension is not compatible in the given region."
            )

        return CompatibilityCheckResponse(
            message_id=msg_id,
            kind=RequestResponseType.CompatibilityCheck,
            data={"status": status},
        )
|
@@ -1,5 +1,4 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*-
|
3
2
|
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
|
@@ -8,8 +7,8 @@ from urllib.parse import urlparse
|
|
8
7
|
from tornado.web import HTTPError
|
9
8
|
|
10
9
|
from ads.aqua.common.decorator import handle_exceptions
|
11
|
-
from ads.aqua.extension.errors import Errors
|
12
10
|
from ads.aqua.extension.base_handler import AquaAPIhandler
|
11
|
+
from ads.aqua.extension.errors import Errors
|
13
12
|
from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
|
14
13
|
from ads.aqua.modeldeployment.entities import ModelParams
|
15
14
|
from ads.config import COMPARTMENT_OCID, PROJECT_OCID
|
@@ -66,8 +65,8 @@ class AquaDeploymentHandler(AquaAPIhandler):
|
|
66
65
|
"""
|
67
66
|
try:
|
68
67
|
input_data = self.get_json_body()
|
69
|
-
except Exception:
|
70
|
-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
|
68
|
+
except Exception as ex:
|
69
|
+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
|
71
70
|
|
72
71
|
if not input_data:
|
73
72
|
raise HTTPError(400, Errors.NO_INPUT_DATA)
|
@@ -100,6 +99,9 @@ class AquaDeploymentHandler(AquaAPIhandler):
|
|
100
99
|
health_check_port = input_data.get("health_check_port")
|
101
100
|
env_var = input_data.get("env_var")
|
102
101
|
container_family = input_data.get("container_family")
|
102
|
+
ocpus = input_data.get("ocpus")
|
103
|
+
memory_in_gbs = input_data.get("memory_in_gbs")
|
104
|
+
model_file = input_data.get("model_file")
|
103
105
|
|
104
106
|
self.finish(
|
105
107
|
AquaDeploymentApp().create(
|
@@ -119,6 +121,9 @@ class AquaDeploymentHandler(AquaAPIhandler):
|
|
119
121
|
health_check_port=health_check_port,
|
120
122
|
env_var=env_var,
|
121
123
|
container_family=container_family,
|
124
|
+
ocpus=ocpus,
|
125
|
+
memory_in_gbs=memory_in_gbs,
|
126
|
+
model_file=model_file,
|
122
127
|
)
|
123
128
|
)
|
124
129
|
|
@@ -153,9 +158,7 @@ class AquaDeploymentInferenceHandler(AquaAPIhandler):
|
|
153
158
|
return False
|
154
159
|
if not url.netloc:
|
155
160
|
return False
|
156
|
-
|
157
|
-
return False
|
158
|
-
return True
|
161
|
+
return url.path.endswith("/predict")
|
159
162
|
except Exception:
|
160
163
|
return False
|
161
164
|
|
@@ -170,8 +173,8 @@ class AquaDeploymentInferenceHandler(AquaAPIhandler):
|
|
170
173
|
"""
|
171
174
|
try:
|
172
175
|
input_data = self.get_json_body()
|
173
|
-
except Exception:
|
174
|
-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
|
176
|
+
except Exception as ex:
|
177
|
+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
|
175
178
|
|
176
179
|
if not input_data:
|
177
180
|
raise HTTPError(400, Errors.NO_INPUT_DATA)
|
@@ -192,10 +195,10 @@ class AquaDeploymentInferenceHandler(AquaAPIhandler):
|
|
192
195
|
)
|
193
196
|
try:
|
194
197
|
model_params_obj = ModelParams(**model_params)
|
195
|
-
except:
|
198
|
+
except Exception as ex:
|
196
199
|
raise HTTPError(
|
197
200
|
400, Errors.INVALID_INPUT_DATA_FORMAT.format("model_params")
|
198
|
-
)
|
201
|
+
) from ex
|
199
202
|
|
200
203
|
return self.finish(
|
201
204
|
MDInferenceResponse(prompt, model_params_obj).get_model_deployment_response(
|
@@ -236,8 +239,8 @@ class AquaDeploymentParamsHandler(AquaAPIhandler):
|
|
236
239
|
"""
|
237
240
|
try:
|
238
241
|
input_data = self.get_json_body()
|
239
|
-
except Exception:
|
240
|
-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
|
242
|
+
except Exception as ex:
|
243
|
+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
|
241
244
|
|
242
245
|
if not input_data:
|
243
246
|
raise HTTPError(400, Errors.NO_INPUT_DATA)
|
@@ -0,0 +1,54 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
import json
|
7
|
+
from typing import List, Union
|
8
|
+
|
9
|
+
from ads.aqua.common.decorator import handle_exceptions
|
10
|
+
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
|
11
|
+
from ads.aqua.extension.models.ws_models import (
|
12
|
+
ListDeploymentResponse,
|
13
|
+
ModelDeploymentDetailsResponse,
|
14
|
+
RequestResponseType,
|
15
|
+
)
|
16
|
+
from ads.aqua.modeldeployment import AquaDeploymentApp
|
17
|
+
from ads.config import COMPARTMENT_OCID
|
18
|
+
|
19
|
+
|
20
|
+
class AquaDeploymentWSMsgHandler(AquaWSMsgHandler):
    """Websocket message handler for model deployment requests.

    Serves ``ListDeployments`` and ``DeploymentDetails`` messages through
    ``AquaDeploymentApp``.
    """

    def __init__(self, message: Union[str, bytes]):
        super().__init__(message)

    @staticmethod
    def get_message_types() -> List[RequestResponseType]:
        """Return the request kinds this handler is responsible for."""
        return [
            RequestResponseType.ListDeployments,
            RequestResponseType.DeploymentDetails,
        ]

    @handle_exceptions
    def process(self) -> Union[ListDeploymentResponse, ModelDeploymentDetailsResponse]:
        """Decode the websocket message and build the matching response.

        ``ListDeployments`` falls back to ``COMPARTMENT_OCID`` when the request
        carries no (or an empty) ``compartment_id``. Returns ``None`` for any
        other ``kind``.
        """
        payload = json.loads(self.message)
        kind = payload.get("kind")

        if kind == "ListDeployments":
            deployments = AquaDeploymentApp().list(
                compartment_id=payload.get("compartment_id") or COMPARTMENT_OCID,
                project_id=payload.get("project_id"),
            )
            return ListDeploymentResponse(
                message_id=payload.get("message_id"),
                kind=RequestResponseType.ListDeployments,
                data=deployments,
            )

        if kind == "DeploymentDetails":
            details = AquaDeploymentApp().get(payload.get("model_deployment_id"))
            return ModelDeploymentDetailsResponse(
                message_id=payload.get("message_id"),
                kind=RequestResponseType.DeploymentDetails,
                data=details,
            )

        return None
|
ads/aqua/extension/errors.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*-
|
3
2
|
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
|
@@ -8,3 +7,4 @@ class Errors(str):
|
|
8
7
|
INVALID_INPUT_DATA_FORMAT = "Invalid format of input data."
|
9
8
|
NO_INPUT_DATA = "No input data provided."
|
10
9
|
MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'"
|
10
|
+
MISSING_ONEOF_REQUIRED_PARAMETER = "Either '{}' or '{}' is required."
|
@@ -3,13 +3,14 @@
|
|
3
3
|
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
5
|
|
6
|
+
import json
|
6
7
|
from typing import List, Union
|
7
8
|
|
8
9
|
from ads.aqua.common.decorator import handle_exceptions
|
9
10
|
from ads.aqua.evaluation import AquaEvaluationApp
|
10
11
|
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
|
11
12
|
from ads.aqua.extension.models.ws_models import (
|
12
|
-
|
13
|
+
EvaluationDetailsResponse,
|
13
14
|
ListEvaluationsResponse,
|
14
15
|
RequestResponseType,
|
15
16
|
)
|
@@ -19,21 +20,42 @@ from ads.config import COMPARTMENT_OCID
|
|
19
20
|
class AquaEvaluationWSMsgHandler(AquaWSMsgHandler):
    """Websocket message handler for model evaluation requests.

    Serves ``ListEvaluations`` and ``EvaluationDetails`` messages through
    ``AquaEvaluationApp``.
    """

    @staticmethod
    def get_message_types() -> List[RequestResponseType]:
        """Return the request kinds this handler is responsible for."""
        return [
            RequestResponseType.ListEvaluations,
            RequestResponseType.EvaluationDetails,
        ]

    def __init__(self, message: Union[str, bytes]):
        super().__init__(message)

    @handle_exceptions
    def process(self) -> Union[ListEvaluationsResponse, EvaluationDetailsResponse]:
        """Decode the websocket message and dispatch on its ``kind``.

        Returns ``None`` for a ``kind`` this handler does not serve.
        """
        request = json.loads(self.message)
        # Use .get() like the other websocket handlers, so a missing key
        # does not surface as a bare KeyError.
        kind = request.get("kind")
        if kind == "ListEvaluations":
            return self.list_evaluations(request)
        if kind == "EvaluationDetails":
            return self.evaluation_details(request)
        return None

    @staticmethod
    def list_evaluations(request: dict) -> ListEvaluationsResponse:
        """List evaluations for the requested compartment.

        Falls back to ``COMPARTMENT_OCID`` when ``compartment_id`` is missing
        or empty.
        """
        eval_list = AquaEvaluationApp().list(
            request.get("compartment_id") or COMPARTMENT_OCID
        )
        return ListEvaluationsResponse(
            message_id=request.get("message_id"),
            kind=RequestResponseType.ListEvaluations,
            data=eval_list,
        )

    @staticmethod
    def evaluation_details(request: dict) -> EvaluationDetailsResponse:
        """Fetch the details of the evaluation named by ``evaluation_id``."""
        evaluation_details = AquaEvaluationApp().get(
            eval_id=request.get("evaluation_id")
        )
        return EvaluationDetailsResponse(
            message_id=request.get("message_id"),
            kind=RequestResponseType.EvaluationDetails,
            data=evaluation_details,
        )
|
@@ -1,27 +1,66 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*-
|
3
2
|
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
|
6
|
-
import re
|
7
5
|
from typing import Optional
|
8
6
|
from urllib.parse import urlparse
|
9
7
|
|
10
8
|
from tornado.web import HTTPError
|
11
|
-
|
9
|
+
|
12
10
|
from ads.aqua.common.decorator import handle_exceptions
|
11
|
+
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
|
12
|
+
from ads.aqua.common.utils import get_hf_model_info
|
13
13
|
from ads.aqua.extension.base_handler import AquaAPIhandler
|
14
|
+
from ads.aqua.extension.errors import Errors
|
14
15
|
from ads.aqua.model import AquaModelApp
|
16
|
+
from ads.aqua.model.constants import ModelTask
|
17
|
+
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
|
18
|
+
from ads.aqua.ui import ModelFormat
|
15
19
|
|
16
20
|
|
17
21
|
class AquaModelHandler(AquaAPIhandler):
|
18
22
|
"""Handler for Aqua Model REST APIs."""
|
19
23
|
|
20
24
|
@handle_exceptions
def get(
    self,
    model_id="",
):
    """Handle GET request.

    Routes:
    * ``aqua/model/files...`` - list a model's files, located either by its
      Object Storage path (``os_path``) or by its Hugging Face name
      (``model_name``). The ``model_format`` argument is required.
    * no ``model_id`` - list models.
    * otherwise - read the model identified by ``model_id``.

    Raises
    ------
    HTTPError
        400 when ``model_format`` is missing or empty, or when neither
        ``os_path`` nor ``model_name`` is supplied.
    AquaValueError
        When ``model_format`` is not a valid ``ModelFormat`` value.
    """
    paths = urlparse(self.request.path).path.strip("/")
    if paths.startswith("aqua/model/files"):
        os_path = self.get_argument("os_path", None)
        model_name = self.get_argument("model_name", None)

        # An explicit default lets a missing argument reach the handler's own
        # MISSING_REQUIRED_PARAMETER message instead of tornado raising its
        # generic MissingArgumentError first (matches os_path/model_name above).
        model_format = self.get_argument("model_format", None)
        if not model_format:
            raise HTTPError(
                400, Errors.MISSING_REQUIRED_PARAMETER.format("model_format")
            )
        try:
            model_format = ModelFormat(model_format.upper())
        except ValueError as err:
            raise AquaValueError(f"Invalid model format: {model_format}") from err

        if os_path:
            return self.finish(AquaModelApp.get_model_files(os_path, model_format))
        if model_name:
            return self.finish(
                AquaModelApp.get_hf_model_files(model_name, model_format)
            )
        raise HTTPError(
            400,
            Errors.MISSING_ONEOF_REQUIRED_PARAMETER.format("os_path", "model_name"),
        )

    if not model_id:
        return self.list()

    return self.read(model_id)
|
26
65
|
|
27
66
|
def read(self, model_id):
|
@@ -29,7 +68,7 @@ class AquaModelHandler(AquaAPIhandler):
|
|
29
68
|
return self.finish(AquaModelApp().get(model_id))
|
30
69
|
|
31
70
|
@handle_exceptions
|
32
|
-
def delete(self
|
71
|
+
def delete(self):
|
33
72
|
"""Handles DELETE request for clearing cache"""
|
34
73
|
url_parse = urlparse(self.request.path)
|
35
74
|
paths = url_parse.path.strip("/")
|
@@ -63,8 +102,8 @@ class AquaModelHandler(AquaAPIhandler):
|
|
63
102
|
"""
|
64
103
|
try:
|
65
104
|
input_data = self.get_json_body()
|
66
|
-
except Exception:
|
67
|
-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
|
105
|
+
except Exception as ex:
|
106
|
+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
|
68
107
|
|
69
108
|
if not input_data:
|
70
109
|
raise HTTPError(400, Errors.NO_INPUT_DATA)
|
@@ -81,15 +120,21 @@ class AquaModelHandler(AquaAPIhandler):
|
|
81
120
|
finetuning_container = input_data.get("finetuning_container")
|
82
121
|
compartment_id = input_data.get("compartment_id")
|
83
122
|
project_id = input_data.get("project_id")
|
123
|
+
model_file = input_data.get("model_file")
|
124
|
+
download_from_hf = (
|
125
|
+
str(input_data.get("download_from_hf", "false")).lower() == "true"
|
126
|
+
)
|
84
127
|
|
85
128
|
return self.finish(
|
86
129
|
AquaModelApp().register(
|
87
130
|
model=model,
|
88
131
|
os_path=os_path,
|
132
|
+
download_from_hf=download_from_hf,
|
89
133
|
inference_container=inference_container,
|
90
134
|
finetuning_container=finetuning_container,
|
91
135
|
compartment_id=compartment_id,
|
92
136
|
project_id=project_id,
|
137
|
+
model_file=model_file,
|
93
138
|
)
|
94
139
|
)
|
95
140
|
|
@@ -105,7 +150,88 @@ class AquaModelLicenseHandler(AquaAPIhandler):
|
|
105
150
|
return self.finish(AquaModelApp().load_license(model_id))
|
106
151
|
|
107
152
|
|
153
|
+
class AquaHuggingFaceHandler(AquaAPIhandler):
    """Handler for Aqua Hugging Face REST APIs."""

    @staticmethod
    def _find_matching_aqua_model(model_id: str) -> Optional[AquaModelSummary]:
        """Look up the AQUA model that corresponds to a Hugging Face model ID.

        Parameters
        ----------
        model_id : str
            The Hugging Face model ID. It is lower-cased before the lookup.

        Returns
        -------
        Optional[AquaModelSummary]
            The matching AQUA model, fetched without its model card, or
            ``None`` when there is no match.
        """
        app = AquaModelApp()
        matched_ocid = app._find_matching_aqua_model(model_id=model_id.lower())
        if not matched_ocid:
            return None
        return app.get(matched_ocid, load_model_card=False)

    @handle_exceptions
    def post(self, *args, **kwargs):
        """Validate a Hugging Face model and return its summary for AQUA.

        The JSON body must contain ``model_id``.

        Raises
        ------
        HTTPError
            400 when the body is not valid JSON, is empty, or lacks ``model_id``.
        AquaRuntimeError
            When the model is disabled on Hugging Face, or its pipeline tag is
            missing or is not ``ModelTask.TEXT_GENERATION``.
        """
        try:
            body = self.get_json_body()
        except Exception as ex:
            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex

        if not body:
            raise HTTPError(400, Errors.NO_INPUT_DATA)

        requested_id = body.get("model_id")
        if not requested_id:
            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model_id"))

        hf_model_info = get_hf_model_info(repo_id=requested_id)

        # Disabled models cannot be imported.
        if hf_model_info.disabled:
            raise AquaRuntimeError(
                f"The chosen model '{hf_model_info.id}' is currently disabled and cannot be imported into AQUA. "
                "Please verify the model's status on the Hugging Face Model Hub or select a different model."
            )

        # Only text-generation models are accepted.
        tag = hf_model_info.pipeline_tag
        if not tag or tag.lower() != ModelTask.TEXT_GENERATION:
            raise AquaRuntimeError(
                f"Unsupported pipeline tag for the chosen model: '{hf_model_info.pipeline_tag}'. "
                f"AQUA currently supports the following tasks only: {', '.join(ModelTask.values())}. "
                "Please select a model with a compatible pipeline tag."
            )

        # Attach the AQUA (service/verified) counterpart when one exists.
        matched: AquaModelSummary = self._find_matching_aqua_model(
            model_id=hf_model_info.id
        )
        return self.finish(
            HFModelSummary(model_info=hf_model_info, aqua_model_info=matched)
        )
|
231
|
+
|
232
|
+
|
108
233
|
# (route pattern, handler class) pairs exported by this module.
__handlers__ = [
    # Model list/detail/file listing, registration and cache clearing.
    ("model/?([^/]*)", AquaModelHandler),
    # Model license text.
    ("model/?([^/]*)/license", AquaModelLicenseHandler),
    # Hugging Face model lookup.
    ("model/hf/search/?([^/]*)", AquaHuggingFaceHandler),
]
|
@@ -7,15 +7,22 @@
|
|
7
7
|
from dataclasses import dataclass
|
8
8
|
from typing import List, Optional
|
9
9
|
|
10
|
-
from ads.aqua.evaluation.entities import AquaEvaluationSummary
|
11
|
-
from ads.aqua.model.entities import AquaModelSummary
|
10
|
+
from ads.aqua.evaluation.entities import AquaEvaluationSummary, AquaEvaluationDetail
|
11
|
+
from ads.aqua.model.entities import AquaModelSummary, AquaModel
|
12
|
+
from ads.aqua.modeldeployment.entities import AquaDeployment, AquaDeploymentDetail
|
12
13
|
from ads.common.extended_enum import ExtendedEnumMeta
|
13
14
|
from ads.common.serializer import DataClassSerializable
|
14
15
|
|
15
16
|
|
16
17
|
class RequestResponseType(str, metaclass=ExtendedEnumMeta):
    """Message kinds exchanged over the AQUA UI websocket."""

    # Evaluations
    ListEvaluations = "ListEvaluations"
    EvaluationDetails = "EvaluationDetails"
    # Model deployments
    ListDeployments = "ListDeployments"
    DeploymentDetails = "DeploymentDetails"
    # Models
    ListModels = "ListModels"
    ModelDetails = "ModelDetails"
    # General
    AdsVersion = "AdsVersion"
    CompatibilityCheck = "CompatibilityCheck"
    # Error reporting
    Error = "Error"
|
20
27
|
|
21
28
|
|
@@ -23,7 +30,7 @@ class RequestResponseType(str, metaclass=ExtendedEnumMeta):
|
|
23
30
|
class BaseResponse(DataClassSerializable):
|
24
31
|
message_id: str
|
25
32
|
kind: RequestResponseType
|
26
|
-
data: object
|
33
|
+
data: Optional[object]
|
27
34
|
|
28
35
|
|
29
36
|
@dataclass
|
@@ -40,9 +47,37 @@ class ListEvaluationsRequest(BaseRequest):
|
|
40
47
|
kind = RequestResponseType.ListEvaluations
|
41
48
|
|
42
49
|
|
50
|
+
@dataclass
class EvaluationDetailsRequest(BaseRequest):
    """Request for the details of one evaluation."""

    kind = RequestResponseType.EvaluationDetails
    evaluation_id: str


@dataclass
class ListModelsRequest(BaseRequest):
    """Request to list models, optionally filtered by compartment/project/type."""

    compartment_id: Optional[str] = None
    project_id: Optional[str] = None
    model_type: Optional[str] = None
    # Must identify a model listing (was mistakenly ListDeployments).
    kind = RequestResponseType.ListModels


@dataclass
class ModelDetailsRequest(BaseRequest):
    """Request for the details of one model."""

    kind = RequestResponseType.ModelDetails
    model_id: str


@dataclass
class ListDeploymentRequest(BaseRequest):
    """Request to list model deployments in a compartment."""

    compartment_id: str
    project_id: Optional[str] = None
    kind = RequestResponseType.ListDeployments


@dataclass
class DeploymentDetailsRequest(BaseRequest):
    """Request for the details of one model deployment."""

    model_deployment_id: str
    kind = RequestResponseType.DeploymentDetails
|
46
81
|
|
47
82
|
|
48
83
|
@dataclass
|
@@ -50,11 +85,51 @@ class ListEvaluationsResponse(BaseResponse):
|
|
50
85
|
data: List[AquaEvaluationSummary]
|
51
86
|
|
52
87
|
|
88
|
+
@dataclass
class EvaluationDetailsResponse(BaseResponse):
    """Response carrying one evaluation's details."""

    data: AquaEvaluationDetail


@dataclass
class ListDeploymentResponse(BaseResponse):
    """Response carrying a list of model deployments."""

    data: List[AquaDeployment]


@dataclass
class ModelDeploymentDetailsResponse(BaseResponse):
    """Response carrying one model deployment's details."""

    data: AquaDeploymentDetail


@dataclass
class ListModelsResponse(BaseResponse):
    """Response carrying a list of model summaries."""

    data: List[AquaModelSummary]


@dataclass
class ModelDetailsResponse(BaseResponse):
    """Response carrying one model's details."""

    data: AquaModel


@dataclass
class AdsVersionRequest(BaseRequest):
    """Request for the installed ``oracle_ads`` version."""

    # A real type plus a default: the value used to sit in the annotation
    # slot, which made ``kind`` a required field that was never populated.
    # Passing ``kind`` explicitly still works.
    kind: RequestResponseType = RequestResponseType.AdsVersion


@dataclass
class AdsVersionResponse(BaseResponse):
    """Response carrying the ``oracle_ads`` version string."""

    data: str


@dataclass
class CompatibilityCheckRequest(BaseRequest):
    """Request to check whether the extension is usable in this region."""

    # Same fix as AdsVersionRequest: typed field with a default value.
    kind: RequestResponseType = RequestResponseType.CompatibilityCheck


@dataclass
class CompatibilityCheckResponse(BaseResponse):
    """Response carrying the compatibility status payload."""

    data: object
|
131
|
+
|
132
|
+
|
58
133
|
@dataclass
|
59
134
|
class AquaWsError(DataClassSerializable):
|
60
135
|
status: str
|
@@ -0,0 +1,49 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
import json
|
7
|
+
from typing import List, Union
|
8
|
+
|
9
|
+
from ads.aqua.common.decorator import handle_exceptions
|
10
|
+
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
|
11
|
+
from ads.aqua.extension.models.ws_models import (
|
12
|
+
ListModelsResponse,
|
13
|
+
ModelDetailsResponse,
|
14
|
+
RequestResponseType,
|
15
|
+
)
|
16
|
+
from ads.aqua.model import AquaModelApp
|
17
|
+
|
18
|
+
|
19
|
+
class AquaModelWSMsgHandler(AquaWSMsgHandler):
    """Websocket message handler for model requests.

    Serves ``ListModels`` and ``ModelDetails`` messages through
    ``AquaModelApp``.
    """

    def __init__(self, message: Union[str, bytes]):
        super().__init__(message)

    @staticmethod
    def get_message_types() -> List[RequestResponseType]:
        """Return the request kinds this handler is responsible for."""
        return [RequestResponseType.ListModels, RequestResponseType.ModelDetails]

    @handle_exceptions
    def process(self) -> Union[ListModelsResponse, ModelDetailsResponse]:
        """Decode the websocket message and build the matching response.

        Request fields are forwarded as-is; ``compartment_id`` has no
        fallback here. Returns ``None`` for any other ``kind``.
        """
        payload = json.loads(self.message)
        kind = payload.get("kind")

        if kind == "ListModels":
            models = AquaModelApp().list(
                compartment_id=payload.get("compartment_id"),
                project_id=payload.get("project_id"),
                model_type=payload.get("model_type"),
            )
            return ListModelsResponse(
                message_id=payload.get("message_id"),
                kind=RequestResponseType.ListModels,
                data=models,
            )

        if kind == "ModelDetails":
            model = AquaModelApp().get(payload.get("model_id"))
            return ModelDetailsResponse(
                message_id=payload.get("message_id"),
                kind=RequestResponseType.ModelDetails,
                data=model,
            )

        return None
|
@@ -14,6 +14,8 @@ from tornado.websocket import WebSocketHandler
|
|
14
14
|
|
15
15
|
from ads.aqua import logger
|
16
16
|
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
|
17
|
+
from ads.aqua.extension.common_ws_msg_handler import AquaCommonWsMsgHandler
|
18
|
+
from ads.aqua.extension.deployment_ws_msg_handler import AquaDeploymentWSMsgHandler
|
17
19
|
from ads.aqua.extension.evaluation_ws_msg_handler import AquaEvaluationWSMsgHandler
|
18
20
|
from ads.aqua.extension.models.ws_models import (
|
19
21
|
AquaWsError,
|
@@ -22,6 +24,7 @@ from ads.aqua.extension.models.ws_models import (
|
|
22
24
|
ErrorResponse,
|
23
25
|
RequestResponseType,
|
24
26
|
)
|
27
|
+
from ads.aqua.extension.models_ws_msg_handler import AquaModelWSMsgHandler
|
25
28
|
|
26
29
|
MAX_WORKERS = 20
|
27
30
|
|
@@ -43,7 +46,10 @@ def get_aqua_internal_error_response(message_id: str) -> ErrorResponse:
|
|
43
46
|
class AquaUIWebSocketHandler(WebSocketHandler):
|
44
47
|
"""Handler for Aqua Websocket."""
|
45
48
|
|
46
|
-
_handlers_: List[Type[AquaWSMsgHandler]] = [AquaEvaluationWSMsgHandler
|
49
|
+
_handlers_: List[Type[AquaWSMsgHandler]] = [AquaEvaluationWSMsgHandler,
|
50
|
+
AquaDeploymentWSMsgHandler,
|
51
|
+
AquaModelWSMsgHandler,
|
52
|
+
AquaCommonWsMsgHandler]
|
47
53
|
|
48
54
|
thread_pool: ThreadPoolExecutor
|
49
55
|
|