oracle-ads 2.11.15__py3-none-any.whl → 2.11.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. ads/aqua/app.py +5 -6
  2. ads/aqua/common/entities.py +17 -0
  3. ads/aqua/common/enums.py +14 -1
  4. ads/aqua/common/utils.py +160 -3
  5. ads/aqua/config/config.py +1 -1
  6. ads/aqua/config/deployment_config_defaults.json +29 -1
  7. ads/aqua/config/resource_limit_names.json +1 -0
  8. ads/aqua/constants.py +6 -1
  9. ads/aqua/evaluation/entities.py +0 -1
  10. ads/aqua/evaluation/evaluation.py +47 -14
  11. ads/aqua/extension/common_handler.py +75 -5
  12. ads/aqua/extension/common_ws_msg_handler.py +57 -0
  13. ads/aqua/extension/deployment_handler.py +16 -13
  14. ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
  15. ads/aqua/extension/errors.py +1 -1
  16. ads/aqua/extension/evaluation_ws_msg_handler.py +28 -6
  17. ads/aqua/extension/model_handler.py +134 -8
  18. ads/aqua/extension/models/ws_models.py +78 -3
  19. ads/aqua/extension/models_ws_msg_handler.py +49 -0
  20. ads/aqua/extension/ui_websocket_handler.py +7 -1
  21. ads/aqua/model/entities.py +28 -0
  22. ads/aqua/model/model.py +544 -129
  23. ads/aqua/modeldeployment/deployment.py +102 -43
  24. ads/aqua/modeldeployment/entities.py +9 -20
  25. ads/aqua/ui.py +152 -28
  26. ads/common/object_storage_details.py +2 -5
  27. ads/common/serializer.py +2 -3
  28. ads/jobs/builders/infrastructure/dsc_job.py +41 -12
  29. ads/jobs/builders/infrastructure/dsc_job_runtime.py +74 -27
  30. ads/jobs/builders/runtimes/container_runtime.py +83 -4
  31. ads/opctl/operator/lowcode/anomaly/const.py +1 -0
  32. ads/opctl/operator/lowcode/anomaly/model/base_model.py +23 -7
  33. ads/opctl/operator/lowcode/anomaly/operator_config.py +1 -0
  34. ads/opctl/operator/lowcode/anomaly/schema.yaml +4 -0
  35. ads/opctl/operator/lowcode/common/errors.py +6 -0
  36. ads/opctl/operator/lowcode/forecast/model/arima.py +3 -1
  37. ads/opctl/operator/lowcode/forecast/model/base_model.py +21 -13
  38. ads/opctl/operator/lowcode/forecast/model_evaluator.py +11 -2
  39. ads/pipeline/ads_pipeline_run.py +13 -2
  40. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/METADATA +2 -1
  41. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/RECORD +44 -40
  42. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/LICENSE.txt +0 -0
  43. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/WHEEL +0 -0
  44. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,57 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+
6
+ import json
7
+ from importlib import metadata
8
+ from typing import List, Union
9
+
10
+ from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, fetch_service_compartment
11
+ from ads.aqua.common.decorator import handle_exceptions
12
+ from ads.aqua.common.errors import AquaResourceAccessError
13
+ from ads.aqua.common.utils import known_realm
14
+ from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
15
+ from ads.aqua.extension.models.ws_models import (
16
+ AdsVersionResponse,
17
+ CompatibilityCheckResponse,
18
+ RequestResponseType,
19
+ )
20
+
21
+
22
class AquaCommonWsMsgHandler(AquaWSMsgHandler):
    """Websocket message handler for ADS version and compatibility-check requests."""

    @staticmethod
    def get_message_types() -> List[RequestResponseType]:
        """Return the request kinds this handler declares: AdsVersion and CompatibilityCheck."""
        return [RequestResponseType.AdsVersion, RequestResponseType.CompatibilityCheck]

    def __init__(self, message: Union[str, bytes]):
        """Forward the raw websocket message to the base handler."""
        super().__init__(message)

    @handle_exceptions
    def process(self) -> Union[AdsVersionResponse, CompatibilityCheckResponse]:
        """Decode ``self.message`` as JSON and answer according to its ``kind``.

        Returns
        -------
        Union[AdsVersionResponse, CompatibilityCheckResponse]
            For ``AdsVersion``: the installed ``oracle_ads`` distribution version.
            For ``CompatibilityCheck``: ``{"status": "ok"}`` when a service
            compartment is available, ``{"status": "compatible"}`` when only
            ``known_realm()`` is truthy.
            Any other ``kind`` falls through and yields ``None``.

        Raises
        ------
        AquaResourceAccessError
            For ``CompatibilityCheck`` when neither condition above holds.
        """
        payload = json.loads(self.message)
        kind = payload.get("kind")

        if kind == "AdsVersion":
            installed_version = metadata.version("oracle_ads")
            return AdsVersionResponse(
                message_id=payload.get("message_id"),
                kind=RequestResponseType.AdsVersion,
                data=installed_version,
            )

        if kind == "CompatibilityCheck":
            # Short-circuit order matters: the service compartment is only
            # fetched when the configured OCID is falsy, and known_realm() is
            # only consulted when both of those are falsy.
            if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
                status = "ok"
            elif known_realm():
                status = "compatible"
            else:
                raise AquaResourceAccessError(
                    "The AI Quick actions extension is not compatible in the given region."
                )
            return CompatibilityCheckResponse(
                message_id=payload.get("message_id"),
                kind=RequestResponseType.CompatibilityCheck,
                data={"status": status},
            )
@@ -1,5 +1,4 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
2
  # Copyright (c) 2024 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
@@ -8,8 +7,8 @@ from urllib.parse import urlparse
8
7
  from tornado.web import HTTPError
9
8
 
10
9
  from ads.aqua.common.decorator import handle_exceptions
11
- from ads.aqua.extension.errors import Errors
12
10
  from ads.aqua.extension.base_handler import AquaAPIhandler
11
+ from ads.aqua.extension.errors import Errors
13
12
  from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
14
13
  from ads.aqua.modeldeployment.entities import ModelParams
15
14
  from ads.config import COMPARTMENT_OCID, PROJECT_OCID
@@ -66,8 +65,8 @@ class AquaDeploymentHandler(AquaAPIhandler):
66
65
  """
67
66
  try:
68
67
  input_data = self.get_json_body()
69
- except Exception:
70
- raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
68
+ except Exception as ex:
69
+ raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
71
70
 
72
71
  if not input_data:
73
72
  raise HTTPError(400, Errors.NO_INPUT_DATA)
@@ -100,6 +99,9 @@ class AquaDeploymentHandler(AquaAPIhandler):
100
99
  health_check_port = input_data.get("health_check_port")
101
100
  env_var = input_data.get("env_var")
102
101
  container_family = input_data.get("container_family")
102
+ ocpus = input_data.get("ocpus")
103
+ memory_in_gbs = input_data.get("memory_in_gbs")
104
+ model_file = input_data.get("model_file")
103
105
 
104
106
  self.finish(
105
107
  AquaDeploymentApp().create(
@@ -119,6 +121,9 @@ class AquaDeploymentHandler(AquaAPIhandler):
119
121
  health_check_port=health_check_port,
120
122
  env_var=env_var,
121
123
  container_family=container_family,
124
+ ocpus=ocpus,
125
+ memory_in_gbs=memory_in_gbs,
126
+ model_file=model_file,
122
127
  )
123
128
  )
124
129
 
@@ -153,9 +158,7 @@ class AquaDeploymentInferenceHandler(AquaAPIhandler):
153
158
  return False
154
159
  if not url.netloc:
155
160
  return False
156
- if not url.path.endswith("/predict"):
157
- return False
158
- return True
161
+ return url.path.endswith("/predict")
159
162
  except Exception:
160
163
  return False
161
164
 
@@ -170,8 +173,8 @@ class AquaDeploymentInferenceHandler(AquaAPIhandler):
170
173
  """
171
174
  try:
172
175
  input_data = self.get_json_body()
173
- except Exception:
174
- raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
176
+ except Exception as ex:
177
+ raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
175
178
 
176
179
  if not input_data:
177
180
  raise HTTPError(400, Errors.NO_INPUT_DATA)
@@ -192,10 +195,10 @@ class AquaDeploymentInferenceHandler(AquaAPIhandler):
192
195
  )
193
196
  try:
194
197
  model_params_obj = ModelParams(**model_params)
195
- except:
198
+ except Exception as ex:
196
199
  raise HTTPError(
197
200
  400, Errors.INVALID_INPUT_DATA_FORMAT.format("model_params")
198
- )
201
+ ) from ex
199
202
 
200
203
  return self.finish(
201
204
  MDInferenceResponse(prompt, model_params_obj).get_model_deployment_response(
@@ -236,8 +239,8 @@ class AquaDeploymentParamsHandler(AquaAPIhandler):
236
239
  """
237
240
  try:
238
241
  input_data = self.get_json_body()
239
- except Exception:
240
- raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
242
+ except Exception as ex:
243
+ raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
241
244
 
242
245
  if not input_data:
243
246
  raise HTTPError(400, Errors.NO_INPUT_DATA)
@@ -0,0 +1,54 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+
6
+ import json
7
+ from typing import List, Union
8
+
9
+ from ads.aqua.common.decorator import handle_exceptions
10
+ from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
11
+ from ads.aqua.extension.models.ws_models import (
12
+ ListDeploymentResponse,
13
+ ModelDeploymentDetailsResponse,
14
+ RequestResponseType,
15
+ )
16
+ from ads.aqua.modeldeployment import AquaDeploymentApp
17
+ from ads.config import COMPARTMENT_OCID
18
+
19
+
20
class AquaDeploymentWSMsgHandler(AquaWSMsgHandler):
    """Websocket message handler for listing deployments and reading deployment details."""

    def __init__(self, message: Union[str, bytes]):
        """Forward the raw websocket message to the base handler."""
        super().__init__(message)

    @staticmethod
    def get_message_types() -> List[RequestResponseType]:
        """Return the request kinds this handler declares: ListDeployments and DeploymentDetails."""
        return [
            RequestResponseType.ListDeployments,
            RequestResponseType.DeploymentDetails,
        ]

    @handle_exceptions
    def process(self) -> Union[ListDeploymentResponse, ModelDeploymentDetailsResponse]:
        """Decode ``self.message`` as JSON and answer according to its ``kind``.

        ``ListDeployments`` calls ``AquaDeploymentApp().list`` with the request's
        ``compartment_id`` (falling back to ``COMPARTMENT_OCID`` when it is
        missing or falsy) and ``project_id``. ``DeploymentDetails`` calls
        ``AquaDeploymentApp().get`` with the request's ``model_deployment_id``.
        Any other ``kind`` falls through and yields ``None``.
        """
        payload = json.loads(self.message)
        kind = payload.get("kind")

        if kind == "ListDeployments":
            deployments = AquaDeploymentApp().list(
                compartment_id=payload.get("compartment_id") or COMPARTMENT_OCID,
                project_id=payload.get("project_id"),
            )
            return ListDeploymentResponse(
                message_id=payload.get("message_id"),
                kind=RequestResponseType.ListDeployments,
                data=deployments,
            )

        if kind == "DeploymentDetails":
            details = AquaDeploymentApp().get(payload.get("model_deployment_id"))
            return ModelDeploymentDetailsResponse(
                message_id=payload.get("message_id"),
                kind=RequestResponseType.DeploymentDetails,
                data=details,
            )
@@ -1,5 +1,4 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
2
  # Copyright (c) 2024 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
@@ -8,3 +7,4 @@ class Errors(str):
8
7
  INVALID_INPUT_DATA_FORMAT = "Invalid format of input data."
9
8
  NO_INPUT_DATA = "No input data provided."
10
9
  MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'"
10
+ MISSING_ONEOF_REQUIRED_PARAMETER = "Either '{}' or '{}' is required."
@@ -3,13 +3,14 @@
3
3
  # Copyright (c) 2024 Oracle and/or its affiliates.
4
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
 
6
+ import json
6
7
  from typing import List, Union
7
8
 
8
9
  from ads.aqua.common.decorator import handle_exceptions
9
10
  from ads.aqua.evaluation import AquaEvaluationApp
10
11
  from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
11
12
  from ads.aqua.extension.models.ws_models import (
12
- ListEvaluationsRequest,
13
+ EvaluationDetailsResponse,
13
14
  ListEvaluationsResponse,
14
15
  RequestResponseType,
15
16
  )
@@ -19,21 +20,42 @@ from ads.config import COMPARTMENT_OCID
19
20
  class AquaEvaluationWSMsgHandler(AquaWSMsgHandler):
20
21
  @staticmethod
21
22
  def get_message_types() -> List[RequestResponseType]:
22
- return [RequestResponseType.ListEvaluations]
23
+ return [
24
+ RequestResponseType.ListEvaluations,
25
+ RequestResponseType.EvaluationDetails,
26
+ ]
23
27
 
24
28
  def __init__(self, message: Union[str, bytes]):
25
29
  super().__init__(message)
26
30
 
27
31
  @handle_exceptions
28
- def process(self) -> ListEvaluationsResponse:
29
- list_eval_request = ListEvaluationsRequest.from_json(self.message)
32
+ def process(self) -> Union[ListEvaluationsResponse, EvaluationDetailsResponse]:
33
+ request = json.loads(self.message)
34
+ if request["kind"] == "ListEvaluations":
35
+ return self.list_evaluations(request)
36
+ if request["kind"] == "EvaluationDetails":
37
+ return self.evaluation_details(request)
30
38
 
39
+ @staticmethod
40
+ def list_evaluations(request) -> ListEvaluationsResponse:
31
41
  eval_list = AquaEvaluationApp().list(
32
- list_eval_request.compartment_id or COMPARTMENT_OCID,
42
+ request.get("compartment_id") or COMPARTMENT_OCID
33
43
  )
34
44
  response = ListEvaluationsResponse(
35
- message_id=list_eval_request.message_id,
45
+ message_id=request["message_id"],
36
46
  kind=RequestResponseType.ListEvaluations,
37
47
  data=eval_list,
38
48
  )
39
49
  return response
50
+
51
+ @staticmethod
52
+ def evaluation_details(request) -> EvaluationDetailsResponse:
53
+ evaluation_details = AquaEvaluationApp().get(
54
+ eval_id=request.get("evaluation_id")
55
+ )
56
+ response = EvaluationDetailsResponse(
57
+ message_id=request.get("message_id"),
58
+ kind=RequestResponseType.EvaluationDetails,
59
+ data=evaluation_details,
60
+ )
61
+ return response
@@ -1,27 +1,66 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
2
  # Copyright (c) 2024 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
6
- import re
7
5
  from typing import Optional
8
6
  from urllib.parse import urlparse
9
7
 
10
8
  from tornado.web import HTTPError
11
- from ads.aqua.extension.errors import Errors
9
+
12
10
  from ads.aqua.common.decorator import handle_exceptions
11
+ from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
12
+ from ads.aqua.common.utils import get_hf_model_info
13
13
  from ads.aqua.extension.base_handler import AquaAPIhandler
14
+ from ads.aqua.extension.errors import Errors
14
15
  from ads.aqua.model import AquaModelApp
16
+ from ads.aqua.model.constants import ModelTask
17
+ from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
18
+ from ads.aqua.ui import ModelFormat
15
19
 
16
20
 
17
21
  class AquaModelHandler(AquaAPIhandler):
18
22
  """Handler for Aqua Model REST APIs."""
19
23
 
20
24
  @handle_exceptions
21
- def get(self, model_id=""):
25
+ def get(
26
+ self,
27
+ model_id="",
28
+ ):
22
29
  """Handle GET request."""
23
- if not model_id:
30
+ url_parse = urlparse(self.request.path)
31
+ paths = url_parse.path.strip("/")
32
+ if paths.startswith("aqua/model/files"):
33
+ os_path = self.get_argument("os_path", None)
34
+ model_name = self.get_argument("model_name", None)
35
+
36
+ model_format = self.get_argument("model_format")
37
+ if not model_format:
38
+ raise HTTPError(
39
+ 400, Errors.MISSING_REQUIRED_PARAMETER.format("model_format")
40
+ )
41
+ try:
42
+ model_format = ModelFormat(model_format.upper())
43
+ except ValueError as err:
44
+ raise AquaValueError(f"Invalid model format: {model_format}") from err
45
+ else:
46
+ if os_path:
47
+ return self.finish(
48
+ AquaModelApp.get_model_files(os_path, model_format)
49
+ )
50
+ elif model_name:
51
+ return self.finish(
52
+ AquaModelApp.get_hf_model_files(model_name, model_format)
53
+ )
54
+ else:
55
+ raise HTTPError(
56
+ 400,
57
+ Errors.MISSING_ONEOF_REQUIRED_PARAMETER.format(
58
+ "os_path", "model_name"
59
+ ),
60
+ )
61
+ elif not model_id:
24
62
  return self.list()
63
+
25
64
  return self.read(model_id)
26
65
 
27
66
  def read(self, model_id):
@@ -29,7 +68,7 @@ class AquaModelHandler(AquaAPIhandler):
29
68
  return self.finish(AquaModelApp().get(model_id))
30
69
 
31
70
  @handle_exceptions
32
- def delete(self, id=""):
71
+ def delete(self):
33
72
  """Handles DELETE request for clearing cache"""
34
73
  url_parse = urlparse(self.request.path)
35
74
  paths = url_parse.path.strip("/")
@@ -63,8 +102,8 @@ class AquaModelHandler(AquaAPIhandler):
63
102
  """
64
103
  try:
65
104
  input_data = self.get_json_body()
66
- except Exception:
67
- raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
105
+ except Exception as ex:
106
+ raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
68
107
 
69
108
  if not input_data:
70
109
  raise HTTPError(400, Errors.NO_INPUT_DATA)
@@ -81,15 +120,21 @@ class AquaModelHandler(AquaAPIhandler):
81
120
  finetuning_container = input_data.get("finetuning_container")
82
121
  compartment_id = input_data.get("compartment_id")
83
122
  project_id = input_data.get("project_id")
123
+ model_file = input_data.get("model_file")
124
+ download_from_hf = (
125
+ str(input_data.get("download_from_hf", "false")).lower() == "true"
126
+ )
84
127
 
85
128
  return self.finish(
86
129
  AquaModelApp().register(
87
130
  model=model,
88
131
  os_path=os_path,
132
+ download_from_hf=download_from_hf,
89
133
  inference_container=inference_container,
90
134
  finetuning_container=finetuning_container,
91
135
  compartment_id=compartment_id,
92
136
  project_id=project_id,
137
+ model_file=model_file,
93
138
  )
94
139
  )
95
140
 
@@ -105,7 +150,88 @@ class AquaModelLicenseHandler(AquaAPIhandler):
105
150
  return self.finish(AquaModelApp().load_license(model_id))
106
151
 
107
152
 
153
class AquaHuggingFaceHandler(AquaAPIhandler):
    """Handler for Aqua Hugging Face REST APIs."""

    @staticmethod
    def _find_matching_aqua_model(model_id: str) -> Optional[AquaModelSummary]:
        """
        Looks up the AQUA model that corresponds to a Hugging Face model ID.

        Parameters
        ----------
        model_id (str): The Hugging Face model ID to match.

        Returns
        -------
        Optional[AquaModelSummary]
            The result of ``AquaModelApp.get`` for the matched OCID (loaded
            without the model card), or None when no OCID was matched.
        """
        # The lookup is performed on the lowercased Hugging Face ID.
        lowered_id = model_id.lower()

        app = AquaModelApp()
        matched_ocid = app._find_matching_aqua_model(model_id=lowered_id)
        if not matched_ocid:
            return None

        return app.get(matched_ocid, load_model_card=False)

    @handle_exceptions
    def post(self, *args, **kwargs):
        """Handles post request for the HF Models APIs.

        Reads ``model_id`` from the JSON body, fetches the model info via
        ``get_hf_model_info``, rejects disabled models and models whose
        pipeline tag is not the text-generation task, then finishes the
        request with an ``HFModelSummary`` that also carries the matching
        AQUA model (if any).

        Raises
        ------
        HTTPError
            Raises HTTPError (400) if the body cannot be parsed, is empty,
            or lacks ``model_id``.
        AquaRuntimeError
            If the model is disabled or its pipeline tag is unsupported.
        """
        try:
            body = self.get_json_body()
        except Exception as ex:
            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex

        if not body:
            raise HTTPError(400, Errors.NO_INPUT_DATA)

        requested_id = body.get("model_id")
        if not requested_id:
            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model_id"))

        # Model metadata as reported by Hugging Face.
        hf_info = get_hf_model_info(repo_id=requested_id)

        # Disabled models cannot be imported.
        if hf_info.disabled:
            raise AquaRuntimeError(
                f"The chosen model '{hf_info.id}' is currently disabled and cannot be imported into AQUA. "
                "Please verify the model's status on the Hugging Face Model Hub or select a different model."
            )

        # Only the text-generation pipeline tag is accepted; a missing tag is rejected too.
        tag = hf_info.pipeline_tag
        if not tag or tag.lower() != ModelTask.TEXT_GENERATION:
            raise AquaRuntimeError(
                f"Unsupported pipeline tag for the chosen model: '{hf_info.pipeline_tag}'. "
                f"AQUA currently supports the following tasks only: {', '.join(ModelTask.values())}. "
                "Please select a model with a compatible pipeline tag."
            )

        # Attach the matching service/verified AQUA model, when one exists.
        matching_model = self._find_matching_aqua_model(model_id=hf_info.id)
        summary = HFModelSummary(model_info=hf_info, aqua_model_info=matching_model)

        return self.finish(summary)
231
+
232
+
108
233
  __handlers__ = [
109
234
  ("model/?([^/]*)", AquaModelHandler),
110
235
  ("model/?([^/]*)/license", AquaModelLicenseHandler),
236
+ ("model/hf/search/?([^/]*)", AquaHuggingFaceHandler),
111
237
  ]
@@ -7,15 +7,22 @@
7
7
  from dataclasses import dataclass
8
8
  from typing import List, Optional
9
9
 
10
- from ads.aqua.evaluation.entities import AquaEvaluationSummary
11
- from ads.aqua.model.entities import AquaModelSummary
10
+ from ads.aqua.evaluation.entities import AquaEvaluationSummary, AquaEvaluationDetail
11
+ from ads.aqua.model.entities import AquaModelSummary, AquaModel
12
+ from ads.aqua.modeldeployment.entities import AquaDeployment, AquaDeploymentDetail
12
13
  from ads.common.extended_enum import ExtendedEnumMeta
13
14
  from ads.common.serializer import DataClassSerializable
14
15
 
15
16
 
16
17
  class RequestResponseType(str, metaclass=ExtendedEnumMeta):
17
18
  ListEvaluations = "ListEvaluations"
19
+ EvaluationDetails = "EvaluationDetails"
20
+ ListDeployments = "ListDeployments"
21
+ DeploymentDetails = "DeploymentDetails"
18
22
  ListModels = "ListModels"
23
+ ModelDetails = "ModelDetails"
24
+ AdsVersion = "AdsVersion"
25
+ CompatibilityCheck = "CompatibilityCheck"
19
26
  Error = "Error"
20
27
 
21
28
 
@@ -23,7 +30,7 @@ class RequestResponseType(str, metaclass=ExtendedEnumMeta):
23
30
  class BaseResponse(DataClassSerializable):
24
31
  message_id: str
25
32
  kind: RequestResponseType
26
- data: object
33
+ data: Optional[object]
27
34
 
28
35
 
29
36
  @dataclass
@@ -40,9 +47,37 @@ class ListEvaluationsRequest(BaseRequest):
40
47
  kind = RequestResponseType.ListEvaluations
41
48
 
42
49
 
50
+ @dataclass
51
+ class EvaluationDetailsRequest(BaseRequest):
52
+ kind = RequestResponseType.EvaluationDetails
53
+ evaluation_id: str
54
+
55
+
43
56
  @dataclass
44
57
  class ListModelsRequest(BaseRequest):
45
58
  compartment_id: Optional[str] = None
59
+ project_id: Optional[str] = None
60
+ model_type: Optional[str] = None
61
+ kind = RequestResponseType.ListDeployments
62
+
63
+
64
+ @dataclass
65
+ class ModelDetailsRequest(BaseRequest):
66
+ kind = RequestResponseType.ModelDetails
67
+ model_id: str
68
+
69
+
70
+ @dataclass
71
+ class ListDeploymentRequest(BaseRequest):
72
+ compartment_id: str
73
+ project_id: Optional[str] = None
74
+ kind = RequestResponseType.ListDeployments
75
+
76
+
77
+ @dataclass
78
+ class DeploymentDetailsRequest(BaseRequest):
79
+ model_deployment_id: str
80
+ kind = RequestResponseType.DeploymentDetails
46
81
 
47
82
 
48
83
  @dataclass
@@ -50,11 +85,51 @@ class ListEvaluationsResponse(BaseResponse):
50
85
  data: List[AquaEvaluationSummary]
51
86
 
52
87
 
88
+ @dataclass
89
+ class EvaluationDetailsResponse(BaseResponse):
90
+ data: AquaEvaluationDetail
91
+
92
+
93
+ @dataclass
94
+ class ListDeploymentResponse(BaseResponse):
95
+ data: List[AquaDeployment]
96
+
97
+
98
+ @dataclass
99
+ class ModelDeploymentDetailsResponse(BaseResponse):
100
+ data: AquaDeploymentDetail
101
+
102
+
53
103
  @dataclass
54
104
  class ListModelsResponse(BaseResponse):
55
105
  data: List[AquaModelSummary]
56
106
 
57
107
 
108
+ @dataclass
109
+ class ModelDetailsResponse(BaseResponse):
110
+ data: AquaModel
111
+
112
+
113
+ @dataclass
114
+ class AdsVersionRequest(BaseRequest):
115
+ kind: RequestResponseType.AdsVersion
116
+
117
+
118
+ @dataclass
119
+ class AdsVersionResponse(BaseResponse):
120
+ data: str
121
+
122
+
123
+ @dataclass
124
+ class CompatibilityCheckRequest(BaseRequest):
125
+ kind: RequestResponseType.CompatibilityCheck
126
+
127
+
128
+ @dataclass
129
+ class CompatibilityCheckResponse(BaseResponse):
130
+ data: object
131
+
132
+
58
133
  @dataclass
59
134
  class AquaWsError(DataClassSerializable):
60
135
  status: str
@@ -0,0 +1,49 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+
6
+ import json
7
+ from typing import List, Union
8
+
9
+ from ads.aqua.common.decorator import handle_exceptions
10
+ from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
11
+ from ads.aqua.extension.models.ws_models import (
12
+ ListModelsResponse,
13
+ ModelDetailsResponse,
14
+ RequestResponseType,
15
+ )
16
+ from ads.aqua.model import AquaModelApp
17
+
18
+
19
class AquaModelWSMsgHandler(AquaWSMsgHandler):
    """Websocket message handler for listing models and reading model details."""

    def __init__(self, message: Union[str, bytes]):
        """Forward the raw websocket message to the base handler."""
        super().__init__(message)

    @staticmethod
    def get_message_types() -> List[RequestResponseType]:
        """Return the request kinds this handler declares: ListModels and ModelDetails."""
        return [RequestResponseType.ListModels, RequestResponseType.ModelDetails]

    @handle_exceptions
    def process(self) -> Union[ListModelsResponse, ModelDetailsResponse]:
        """Decode ``self.message`` as JSON and answer according to its ``kind``.

        ``ListModels`` forwards the request's ``compartment_id``, ``project_id``
        and ``model_type`` to ``AquaModelApp().list``. ``ModelDetails`` passes
        the request's ``model_id`` to ``AquaModelApp().get``. Any other ``kind``
        falls through and yields ``None``.
        """
        payload = json.loads(self.message)
        kind = payload.get("kind")

        if kind == "ListModels":
            models = AquaModelApp().list(
                compartment_id=payload.get("compartment_id"),
                project_id=payload.get("project_id"),
                model_type=payload.get("model_type"),
            )
            return ListModelsResponse(
                message_id=payload.get("message_id"),
                kind=RequestResponseType.ListModels,
                data=models,
            )

        if kind == "ModelDetails":
            model = AquaModelApp().get(payload.get("model_id"))
            return ModelDetailsResponse(
                message_id=payload.get("message_id"),
                kind=RequestResponseType.ModelDetails,
                data=model,
            )
@@ -14,6 +14,8 @@ from tornado.websocket import WebSocketHandler
14
14
 
15
15
  from ads.aqua import logger
16
16
  from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
17
+ from ads.aqua.extension.common_ws_msg_handler import AquaCommonWsMsgHandler
18
+ from ads.aqua.extension.deployment_ws_msg_handler import AquaDeploymentWSMsgHandler
17
19
  from ads.aqua.extension.evaluation_ws_msg_handler import AquaEvaluationWSMsgHandler
18
20
  from ads.aqua.extension.models.ws_models import (
19
21
  AquaWsError,
@@ -22,6 +24,7 @@ from ads.aqua.extension.models.ws_models import (
22
24
  ErrorResponse,
23
25
  RequestResponseType,
24
26
  )
27
+ from ads.aqua.extension.models_ws_msg_handler import AquaModelWSMsgHandler
25
28
 
26
29
  MAX_WORKERS = 20
27
30
 
@@ -43,7 +46,10 @@ def get_aqua_internal_error_response(message_id: str) -> ErrorResponse:
43
46
  class AquaUIWebSocketHandler(WebSocketHandler):
44
47
  """Handler for Aqua Websocket."""
45
48
 
46
- _handlers_: List[Type[AquaWSMsgHandler]] = [AquaEvaluationWSMsgHandler]
49
+ _handlers_: List[Type[AquaWSMsgHandler]] = [AquaEvaluationWSMsgHandler,
50
+ AquaDeploymentWSMsgHandler,
51
+ AquaModelWSMsgHandler,
52
+ AquaCommonWsMsgHandler]
47
53
 
48
54
  thread_pool: ThreadPoolExecutor
49
55