oracle-ads 2.11.15__py3-none-any.whl → 2.11.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. ads/aqua/common/entities.py +17 -0
  2. ads/aqua/common/enums.py +5 -1
  3. ads/aqua/common/utils.py +32 -2
  4. ads/aqua/config/config.py +1 -1
  5. ads/aqua/config/deployment_config_defaults.json +29 -1
  6. ads/aqua/config/resource_limit_names.json +1 -0
  7. ads/aqua/constants.py +5 -1
  8. ads/aqua/evaluation/entities.py +0 -1
  9. ads/aqua/evaluation/evaluation.py +47 -14
  10. ads/aqua/extension/common_ws_msg_handler.py +57 -0
  11. ads/aqua/extension/deployment_handler.py +14 -13
  12. ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
  13. ads/aqua/extension/errors.py +1 -1
  14. ads/aqua/extension/evaluation_ws_msg_handler.py +28 -6
  15. ads/aqua/extension/model_handler.py +31 -6
  16. ads/aqua/extension/models/ws_models.py +78 -3
  17. ads/aqua/extension/models_ws_msg_handler.py +49 -0
  18. ads/aqua/extension/ui_websocket_handler.py +7 -1
  19. ads/aqua/model/entities.py +11 -1
  20. ads/aqua/model/model.py +260 -90
  21. ads/aqua/modeldeployment/deployment.py +52 -7
  22. ads/aqua/modeldeployment/entities.py +9 -20
  23. ads/aqua/ui.py +152 -28
  24. ads/common/object_storage_details.py +2 -5
  25. ads/common/serializer.py +2 -3
  26. ads/jobs/builders/infrastructure/dsc_job.py +29 -3
  27. ads/jobs/builders/infrastructure/dsc_job_runtime.py +74 -27
  28. ads/jobs/builders/runtimes/container_runtime.py +83 -4
  29. ads/opctl/operator/lowcode/anomaly/const.py +1 -0
  30. ads/opctl/operator/lowcode/anomaly/model/base_model.py +23 -7
  31. ads/opctl/operator/lowcode/anomaly/operator_config.py +1 -0
  32. ads/opctl/operator/lowcode/anomaly/schema.yaml +4 -0
  33. ads/opctl/operator/lowcode/common/errors.py +6 -0
  34. ads/opctl/operator/lowcode/forecast/model/base_model.py +21 -13
  35. ads/opctl/operator/lowcode/forecast/model_evaluator.py +11 -2
  36. ads/pipeline/ads_pipeline_run.py +13 -2
  37. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.16.dist-info}/METADATA +1 -1
  38. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.16.dist-info}/RECORD +41 -37
  39. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.16.dist-info}/LICENSE.txt +0 -0
  40. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.16.dist-info}/WHEEL +0 -0
  41. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.16.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,17 @@
1
+ #!/usr/bin/env python
2
+ # Copyright (c) 2024 Oracle and/or its affiliates.
3
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
+
5
+
6
+ class ContainerSpec:
7
+ """
8
+ Class to hold to hold keys within the container spec.
9
+ """
10
+
11
+ CONTAINER_SPEC = "containerSpec"
12
+ CLI_PARM = "cliParam"
13
+ SERVER_PORT = "serverPort"
14
+ HEALTH_CHECK_PORT = "healthCheckPort"
15
+ ENV_VARS = "envVars"
16
+ RESTRICTED_PARAMS = "restrictedParams"
17
+ EVALUATION_CONFIGURATION = "evaluationConfiguration"
ads/aqua/common/enums.py CHANGED
@@ -1,5 +1,4 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
2
  # Copyright (c) 2024 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
@@ -8,6 +7,7 @@ aqua.common.enums
8
7
  ~~~~~~~~~~~~~~
9
8
  This module contains the set of enums used in AQUA.
10
9
  """
10
+
11
11
  from ads.common.extended_enum import ExtendedEnumMeta
12
12
 
13
13
 
@@ -38,21 +38,25 @@ class Tags(str, metaclass=ExtendedEnumMeta):
38
38
  READY_TO_IMPORT = "ready_to_import"
39
39
  BASE_MODEL_CUSTOM = "aqua_custom_base_model"
40
40
  AQUA_EVALUATION_MODEL_ID = "evaluation_model_id"
41
+ MODEL_FORMAT = "model_format"
41
42
 
42
43
 
43
44
  class InferenceContainerType(str, metaclass=ExtendedEnumMeta):
44
45
  CONTAINER_TYPE_VLLM = "vllm"
45
46
  CONTAINER_TYPE_TGI = "tgi"
47
+ CONTAINER_TYPE_LLAMA_CPP = "llama-cpp"
46
48
 
47
49
 
48
50
  class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
49
51
  AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
50
52
  AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
53
+ AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
51
54
 
52
55
 
53
56
  class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
54
57
  PARAM_TYPE_VLLM = "VLLM_PARAMS"
55
58
  PARAM_TYPE_TGI = "TGI_PARAMS"
59
+ PARAM_TYPE_LLAMA_CPP = "LLAMA_CPP_PARAMS"
56
60
 
57
61
 
58
62
  class HuggingFaceTags(str, metaclass=ExtendedEnumMeta):
ads/aqua/common/utils.py CHANGED
@@ -10,6 +10,7 @@ import logging
10
10
  import os
11
11
  import random
12
12
  import re
13
+ from datetime import datetime, timedelta
13
14
  from functools import wraps
14
15
  from pathlib import Path
15
16
  from string import Template
@@ -17,7 +18,9 @@ from typing import List, Union
17
18
 
18
19
  import fsspec
19
20
  import oci
21
+ from cachetools import TTLCache, cached
20
22
  from oci.data_science.models import JobRun, Model
23
+ from oci.object_storage.models import ObjectSummary
21
24
 
22
25
  from ads.aqua.common.enums import (
23
26
  InferenceContainerParamType,
@@ -45,7 +48,6 @@ from ads.aqua.constants import (
45
48
  )
46
49
  from ads.aqua.data import AquaResourceIdentifier
47
50
  from ads.common.auth import default_signer
48
- from ads.common.decorator.threaded import threaded
49
51
  from ads.common.extended_enum import ExtendedEnumMeta
50
52
  from ads.common.object_storage_details import ObjectStorageDetails
51
53
  from ads.common.oci_resource import SEARCH_TYPE, OCIResource
@@ -213,7 +215,6 @@ def read_file(file_path: str, **kwargs) -> str:
213
215
  return UNKNOWN
214
216
 
215
217
 
216
- @threaded()
217
218
  def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
218
219
  artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
219
220
  signer = default_signer() if artifact_path.startswith("oci://") else {}
@@ -228,6 +229,32 @@ def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
228
229
  return config
229
230
 
230
231
 
232
+ def list_os_files_with_extension(oss_path: str, extension: str) -> [str]:
233
+ """
234
+ List files in the specified directory with the given extension.
235
+
236
+ Parameters:
237
+ - oss_path: The path to the directory where files are located.
238
+ - extension: The file extension to filter by (e.g., 'txt' for text files).
239
+
240
+ Returns:
241
+ - A list of file paths matching the specified extension.
242
+ """
243
+
244
+ oss_client = ObjectStorageDetails.from_path(oss_path)
245
+
246
+ # Ensure the extension is prefixed with a dot if not already
247
+ if not extension.startswith("."):
248
+ extension = "." + extension
249
+ files: List[ObjectSummary] = oss_client.list_objects().objects
250
+
251
+ return [
252
+ file.name[len(oss_client.filepath) :].lstrip("/")
253
+ for file in files
254
+ if file.name.endswith(extension)
255
+ ]
256
+
257
+
231
258
  def is_valid_ocid(ocid: str) -> bool:
232
259
  """Checks if the given ocid is valid.
233
260
 
@@ -503,6 +530,7 @@ def container_config_path():
503
530
  return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
504
531
 
505
532
 
533
+ @cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
506
534
  def get_container_config():
507
535
  config = load_config(
508
536
  file_path=container_config_path(),
@@ -881,6 +909,8 @@ def get_container_params_type(container_type_name: str) -> str:
881
909
  return InferenceContainerParamType.PARAM_TYPE_VLLM
882
910
  elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_name.lower():
883
911
  return InferenceContainerParamType.PARAM_TYPE_TGI
912
+ elif InferenceContainerType.CONTAINER_TYPE_LLAMA_CPP in container_type_name.lower():
913
+ return InferenceContainerParamType.PARAM_TYPE_LLAMA_CPP
884
914
  else:
885
915
  return UNKNOWN
886
916
 
ads/aqua/config/config.py CHANGED
@@ -1,5 +1,4 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
2
  # Copyright (c) 2024 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
@@ -14,5 +13,6 @@ def get_finetuning_config_defaults():
14
13
  "BM.GPU.A10.4": {"batch_size": 1, "replica": 1},
15
14
  "BM.GPU4.8": {"batch_size": 4, "replica": 1},
16
15
  "BM.GPU.A100-v2.8": {"batch_size": 6, "replica": 1},
16
+ "BM.GPU.H100.8": {"batch_size": 6, "replica": 1},
17
17
  }
18
18
  }
@@ -1,9 +1,37 @@
1
1
  {
2
+ "configuration": {
3
+ "VM.Standard.A1.Flex": {
4
+ "parameters": {},
5
+ "shape_info": {
6
+ "configs": [
7
+ {
8
+ "memory_in_gbs": 128,
9
+ "ocpu": 20
10
+ },
11
+ {
12
+ "memory_in_gbs": 256,
13
+ "ocpu": 40
14
+ },
15
+ {
16
+ "memory_in_gbs": 384,
17
+ "ocpu": 60
18
+ },
19
+ {
20
+ "memory_in_gbs": 512,
21
+ "ocpu": 80
22
+ }
23
+ ],
24
+ "type": "CPU"
25
+ }
26
+ }
27
+ },
2
28
  "shape": [
3
29
  "VM.GPU.A10.1",
4
30
  "VM.GPU.A10.2",
5
31
  "BM.GPU.A10.4",
6
32
  "BM.GPU4.8",
7
- "BM.GPU.A100-v2.8"
33
+ "BM.GPU.A100-v2.8",
34
+ "BM.GPU.H100.8",
35
+ "VM.Standard.A1.Flex"
8
36
  ]
9
37
  }
@@ -1,6 +1,7 @@
1
1
  {
2
2
  "BM.GPU.A10.4": "ds-gpu-a10-count",
3
3
  "BM.GPU.A100-v2.8": "ds-gpu-a100-v2-count",
4
+ "BM.GPU.H100.8": "ds-gpu-h100-count",
4
5
  "BM.GPU4.8": "ds-gpu4-count",
5
6
  "VM.GPU.A10.1": "ds-gpu-a10-count",
6
7
  "VM.GPU.A10.2": "ds-gpu-a10-count"
ads/aqua/constants.py CHANGED
@@ -21,7 +21,6 @@ DEFAULT_FT_BLOCK_STORAGE_SIZE = 750
21
21
  DEFAULT_FT_REPLICA = 1
22
22
  DEFAULT_FT_BATCH_SIZE = 1
23
23
  DEFAULT_FT_VALIDATION_SET_SIZE = 0.1
24
-
25
24
  MAXIMUM_ALLOWED_DATASET_IN_BYTE = 52428800 # 1024 x 1024 x 50 = 50MB
26
25
  JOB_INFRASTRUCTURE_TYPE_DEFAULT_NETWORKING = "ME_STANDALONE"
27
26
  NB_SESSION_IDENTIFIER = "NB_SESSION_OCID"
@@ -34,6 +33,7 @@ AQUA_MODEL_TYPE_CUSTOM = "custom"
34
33
  AQUA_MODEL_ARTIFACT_CONFIG = "config.json"
35
34
  AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME = "_name_or_path"
36
35
  AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE = "model_type"
36
+ AQUA_MODEL_ARTIFACT_FILE = "model_file"
37
37
 
38
38
  TRAINING_METRICS_FINAL = "training_metrics_final"
39
39
  VALIDATION_METRICS_FINAL = "validation_metrics_final"
@@ -74,3 +74,7 @@ TGI_INFERENCE_RESTRICTED_PARAMS = {
74
74
  "--sharded",
75
75
  "--trust-remote-code",
76
76
  }
77
+ LLAMA_CPP_INFERENCE_RESTRICTED_PARAMS = {
78
+ "--port",
79
+ "--host",
80
+ }
@@ -1,5 +1,4 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
2
  # Copyright (c) 2024 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
@@ -7,7 +7,7 @@ import os
7
7
  import re
8
8
  import tempfile
9
9
  from concurrent.futures import ThreadPoolExecutor, as_completed
10
- from dataclasses import asdict
10
+ from dataclasses import asdict, fields
11
11
  from datetime import datetime, timedelta
12
12
  from pathlib import Path
13
13
  from threading import Lock
@@ -76,6 +76,7 @@ from ads.aqua.evaluation.entities import (
76
76
  ModelParams,
77
77
  )
78
78
  from ads.aqua.evaluation.errors import EVALUATION_JOB_EXIT_CODE_MESSAGE
79
+ from ads.aqua.ui import AquaContainerConfig
79
80
  from ads.common.auth import default_signer
80
81
  from ads.common.object_storage_details import ObjectStorageDetails
81
82
  from ads.common.utils import get_console_link, get_files, get_log_links
@@ -90,7 +91,9 @@ from ads.jobs.builders.infrastructure.dsc_job import DataScienceJob
90
91
  from ads.jobs.builders.runtimes.base import Runtime
91
92
  from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
92
93
  from ads.model.datascience_model import DataScienceModel
94
+ from ads.model.deployment import ModelDeploymentContainerRuntime
93
95
  from ads.model.deployment.model_deployment import ModelDeployment
96
+ from ads.model.generic_model import ModelDeploymentRuntimeType
94
97
  from ads.model.model_metadata import (
95
98
  MetadataTaxonomyKeys,
96
99
  ModelCustomMetadata,
@@ -157,8 +160,9 @@ class AquaEvaluationApp(AquaApp):
157
160
  create_aqua_evaluation_details = CreateAquaEvaluationDetails(**kwargs)
158
161
  except Exception as ex:
159
162
  raise AquaValueError(
160
- "Invalid create evaluation parameters. Allowable parameters are: "
161
- f"{', '.join(list(asdict(CreateAquaEvaluationDetails).keys()))}."
163
+ "Invalid create evaluation parameters. "
164
+ "Allowable parameters are: "
165
+ f"{', '.join([field.name for field in fields(CreateAquaEvaluationDetails)])}."
162
166
  ) from ex
163
167
 
164
168
  if not is_valid_ocid(create_aqua_evaluation_details.evaluation_source_id):
@@ -166,8 +170,8 @@ class AquaEvaluationApp(AquaApp):
166
170
  f"Invalid evaluation source {create_aqua_evaluation_details.evaluation_source_id}. "
167
171
  "Specify either a model or model deployment id."
168
172
  )
169
-
170
173
  evaluation_source = None
174
+ eval_inference_configuration = None
171
175
  if (
172
176
  DataScienceResource.MODEL_DEPLOYMENT
173
177
  in create_aqua_evaluation_details.evaluation_source_id
@@ -175,6 +179,28 @@ class AquaEvaluationApp(AquaApp):
175
179
  evaluation_source = ModelDeployment.from_id(
176
180
  create_aqua_evaluation_details.evaluation_source_id
177
181
  )
182
+ try:
183
+ if (
184
+ evaluation_source.runtime.type
185
+ == ModelDeploymentRuntimeType.CONTAINER
186
+ ):
187
+ runtime = ModelDeploymentContainerRuntime.from_dict(
188
+ evaluation_source.runtime.to_dict()
189
+ )
190
+ inference_config = AquaContainerConfig.from_container_index_json(
191
+ enable_spec=True
192
+ ).inference
193
+ for container in inference_config.values():
194
+ if container.name == runtime.image.split(":")[0]:
195
+ eval_inference_configuration = (
196
+ container.spec.evaluation_configuration
197
+ )
198
+ except Exception:
199
+ logger.debug(
200
+ f"Could not load inference config details for the evaluation id: "
201
+ f"{create_aqua_evaluation_details.evaluation_source_id}. Please check if the container"
202
+ f" runtime has the correct SMC image information."
203
+ )
178
204
  elif (
179
205
  DataScienceResource.MODEL
180
206
  in create_aqua_evaluation_details.evaluation_source_id
@@ -390,6 +416,9 @@ class AquaEvaluationApp(AquaApp):
390
416
  report_path=create_aqua_evaluation_details.report_path,
391
417
  model_parameters=create_aqua_evaluation_details.model_parameters,
392
418
  metrics=create_aqua_evaluation_details.metrics,
419
+ inference_configuration=eval_inference_configuration.to_filtered_dict()
420
+ if eval_inference_configuration
421
+ else {},
393
422
  )
394
423
  ).create(**kwargs) ## TODO: decide what parameters will be needed
395
424
  logger.debug(
@@ -511,6 +540,7 @@ class AquaEvaluationApp(AquaApp):
511
540
  report_path: str,
512
541
  model_parameters: dict,
513
542
  metrics: List = None,
543
+ inference_configuration: dict = None,
514
544
  ) -> Runtime:
515
545
  """Builds evaluation runtime for Job."""
516
546
  # TODO the image name needs to be extracted from the mapping index.json file.
@@ -520,16 +550,19 @@ class AquaEvaluationApp(AquaApp):
520
550
  .with_environment_variable(
521
551
  **{
522
552
  "AIP_SMC_EVALUATION_ARGUMENTS": json.dumps(
523
- asdict(
524
- self._build_launch_cmd(
525
- evaluation_id=evaluation_id,
526
- evaluation_source_id=evaluation_source_id,
527
- dataset_path=dataset_path,
528
- report_path=report_path,
529
- model_parameters=model_parameters,
530
- metrics=metrics,
531
- )
532
- )
553
+ {
554
+ **asdict(
555
+ self._build_launch_cmd(
556
+ evaluation_id=evaluation_id,
557
+ evaluation_source_id=evaluation_source_id,
558
+ dataset_path=dataset_path,
559
+ report_path=report_path,
560
+ model_parameters=model_parameters,
561
+ metrics=metrics,
562
+ ),
563
+ ),
564
+ **(inference_configuration or {}),
565
+ },
533
566
  ),
534
567
  "CONDA_BUCKET_NS": CONDA_BUCKET_NS,
535
568
  },
@@ -0,0 +1,57 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+
6
+ import json
7
+ from importlib import metadata
8
+ from typing import List, Union
9
+
10
+ from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, fetch_service_compartment
11
+ from ads.aqua.common.decorator import handle_exceptions
12
+ from ads.aqua.common.errors import AquaResourceAccessError
13
+ from ads.aqua.common.utils import known_realm
14
+ from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
15
+ from ads.aqua.extension.models.ws_models import (
16
+ AdsVersionResponse,
17
+ CompatibilityCheckResponse,
18
+ RequestResponseType,
19
+ )
20
+
21
+
22
+ class AquaCommonWsMsgHandler(AquaWSMsgHandler):
23
+ @staticmethod
24
+ def get_message_types() -> List[RequestResponseType]:
25
+ return [RequestResponseType.AdsVersion, RequestResponseType.CompatibilityCheck]
26
+
27
+ def __init__(self, message: Union[str, bytes]):
28
+ super().__init__(message)
29
+
30
+ @handle_exceptions
31
+ def process(self) -> Union[AdsVersionResponse, CompatibilityCheckResponse]:
32
+ request = json.loads(self.message)
33
+ if request.get("kind") == "AdsVersion":
34
+ version = metadata.version("oracle_ads")
35
+ response = AdsVersionResponse(
36
+ message_id=request.get("message_id"),
37
+ kind=RequestResponseType.AdsVersion,
38
+ data=version,
39
+ )
40
+ return response
41
+ if request.get("kind") == "CompatibilityCheck":
42
+ if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
43
+ return CompatibilityCheckResponse(
44
+ message_id=request.get("message_id"),
45
+ kind=RequestResponseType.CompatibilityCheck,
46
+ data={"status": "ok"},
47
+ )
48
+ elif known_realm():
49
+ return CompatibilityCheckResponse(
50
+ message_id=request.get("message_id"),
51
+ kind=RequestResponseType.CompatibilityCheck,
52
+ data={"status": "compatible"},
53
+ )
54
+ else:
55
+ raise AquaResourceAccessError(
56
+ "The AI Quick actions extension is not compatible in the given region."
57
+ )
@@ -1,5 +1,4 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
2
  # Copyright (c) 2024 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
@@ -8,8 +7,8 @@ from urllib.parse import urlparse
8
7
  from tornado.web import HTTPError
9
8
 
10
9
  from ads.aqua.common.decorator import handle_exceptions
11
- from ads.aqua.extension.errors import Errors
12
10
  from ads.aqua.extension.base_handler import AquaAPIhandler
11
+ from ads.aqua.extension.errors import Errors
13
12
  from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
14
13
  from ads.aqua.modeldeployment.entities import ModelParams
15
14
  from ads.config import COMPARTMENT_OCID, PROJECT_OCID
@@ -66,8 +65,8 @@ class AquaDeploymentHandler(AquaAPIhandler):
66
65
  """
67
66
  try:
68
67
  input_data = self.get_json_body()
69
- except Exception:
70
- raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
68
+ except Exception as ex:
69
+ raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
71
70
 
72
71
  if not input_data:
73
72
  raise HTTPError(400, Errors.NO_INPUT_DATA)
@@ -100,6 +99,8 @@ class AquaDeploymentHandler(AquaAPIhandler):
100
99
  health_check_port = input_data.get("health_check_port")
101
100
  env_var = input_data.get("env_var")
102
101
  container_family = input_data.get("container_family")
102
+ ocpus = input_data.get("ocpus")
103
+ memory_in_gbs = input_data.get("memory_in_gbs")
103
104
 
104
105
  self.finish(
105
106
  AquaDeploymentApp().create(
@@ -119,6 +120,8 @@ class AquaDeploymentHandler(AquaAPIhandler):
119
120
  health_check_port=health_check_port,
120
121
  env_var=env_var,
121
122
  container_family=container_family,
123
+ ocpus=ocpus,
124
+ memory_in_gbs=memory_in_gbs,
122
125
  )
123
126
  )
124
127
 
@@ -153,9 +156,7 @@ class AquaDeploymentInferenceHandler(AquaAPIhandler):
153
156
  return False
154
157
  if not url.netloc:
155
158
  return False
156
- if not url.path.endswith("/predict"):
157
- return False
158
- return True
159
+ return url.path.endswith("/predict")
159
160
  except Exception:
160
161
  return False
161
162
 
@@ -170,8 +171,8 @@ class AquaDeploymentInferenceHandler(AquaAPIhandler):
170
171
  """
171
172
  try:
172
173
  input_data = self.get_json_body()
173
- except Exception:
174
- raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
174
+ except Exception as ex:
175
+ raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
175
176
 
176
177
  if not input_data:
177
178
  raise HTTPError(400, Errors.NO_INPUT_DATA)
@@ -192,10 +193,10 @@ class AquaDeploymentInferenceHandler(AquaAPIhandler):
192
193
  )
193
194
  try:
194
195
  model_params_obj = ModelParams(**model_params)
195
- except:
196
+ except Exception as ex:
196
197
  raise HTTPError(
197
198
  400, Errors.INVALID_INPUT_DATA_FORMAT.format("model_params")
198
- )
199
+ ) from ex
199
200
 
200
201
  return self.finish(
201
202
  MDInferenceResponse(prompt, model_params_obj).get_model_deployment_response(
@@ -236,8 +237,8 @@ class AquaDeploymentParamsHandler(AquaAPIhandler):
236
237
  """
237
238
  try:
238
239
  input_data = self.get_json_body()
239
- except Exception:
240
- raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
240
+ except Exception as ex:
241
+ raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
241
242
 
242
243
  if not input_data:
243
244
  raise HTTPError(400, Errors.NO_INPUT_DATA)
@@ -0,0 +1,54 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+
6
+ import json
7
+ from typing import List, Union
8
+
9
+ from ads.aqua.common.decorator import handle_exceptions
10
+ from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
11
+ from ads.aqua.extension.models.ws_models import (
12
+ ListDeploymentResponse,
13
+ ModelDeploymentDetailsResponse,
14
+ RequestResponseType,
15
+ )
16
+ from ads.aqua.modeldeployment import AquaDeploymentApp
17
+ from ads.config import COMPARTMENT_OCID
18
+
19
+
20
+ class AquaDeploymentWSMsgHandler(AquaWSMsgHandler):
21
+ def __init__(self, message: Union[str, bytes]):
22
+ super().__init__(message)
23
+
24
+ @staticmethod
25
+ def get_message_types() -> List[RequestResponseType]:
26
+ return [
27
+ RequestResponseType.ListDeployments,
28
+ RequestResponseType.DeploymentDetails,
29
+ ]
30
+
31
+ @handle_exceptions
32
+ def process(self) -> Union[ListDeploymentResponse, ModelDeploymentDetailsResponse]:
33
+ request = json.loads(self.message)
34
+ if request.get("kind") == "ListDeployments":
35
+ deployment_list = AquaDeploymentApp().list(
36
+ compartment_id=request.get("compartment_id") or COMPARTMENT_OCID,
37
+ project_id=request.get("project_id"),
38
+ )
39
+ response = ListDeploymentResponse(
40
+ message_id=request.get("message_id"),
41
+ kind=RequestResponseType.ListDeployments,
42
+ data=deployment_list,
43
+ )
44
+ return response
45
+ elif request.get("kind") == "DeploymentDetails":
46
+ deployment_details = AquaDeploymentApp().get(
47
+ request.get("model_deployment_id")
48
+ )
49
+ response = ModelDeploymentDetailsResponse(
50
+ message_id=request.get("message_id"),
51
+ kind=RequestResponseType.DeploymentDetails,
52
+ data=deployment_details,
53
+ )
54
+ return response
@@ -1,5 +1,4 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
2
  # Copyright (c) 2024 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
@@ -8,3 +7,4 @@ class Errors(str):
8
7
  INVALID_INPUT_DATA_FORMAT = "Invalid format of input data."
9
8
  NO_INPUT_DATA = "No input data provided."
10
9
  MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'"
10
+ MISSING_ONEOF_REQUIRED_PARAMETER = "Either '{}' or '{}' is required."
@@ -3,13 +3,14 @@
3
3
  # Copyright (c) 2024 Oracle and/or its affiliates.
4
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
 
6
+ import json
6
7
  from typing import List, Union
7
8
 
8
9
  from ads.aqua.common.decorator import handle_exceptions
9
10
  from ads.aqua.evaluation import AquaEvaluationApp
10
11
  from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
11
12
  from ads.aqua.extension.models.ws_models import (
12
- ListEvaluationsRequest,
13
+ EvaluationDetailsResponse,
13
14
  ListEvaluationsResponse,
14
15
  RequestResponseType,
15
16
  )
@@ -19,21 +20,42 @@ from ads.config import COMPARTMENT_OCID
19
20
  class AquaEvaluationWSMsgHandler(AquaWSMsgHandler):
20
21
  @staticmethod
21
22
  def get_message_types() -> List[RequestResponseType]:
22
- return [RequestResponseType.ListEvaluations]
23
+ return [
24
+ RequestResponseType.ListEvaluations,
25
+ RequestResponseType.EvaluationDetails,
26
+ ]
23
27
 
24
28
  def __init__(self, message: Union[str, bytes]):
25
29
  super().__init__(message)
26
30
 
27
31
  @handle_exceptions
28
- def process(self) -> ListEvaluationsResponse:
29
- list_eval_request = ListEvaluationsRequest.from_json(self.message)
32
+ def process(self) -> Union[ListEvaluationsResponse, EvaluationDetailsResponse]:
33
+ request = json.loads(self.message)
34
+ if request["kind"] == "ListEvaluations":
35
+ return self.list_evaluations(request)
36
+ if request["kind"] == "EvaluationDetails":
37
+ return self.evaluation_details(request)
30
38
 
39
+ @staticmethod
40
+ def list_evaluations(request) -> ListEvaluationsResponse:
31
41
  eval_list = AquaEvaluationApp().list(
32
- list_eval_request.compartment_id or COMPARTMENT_OCID,
42
+ request.get("compartment_id") or COMPARTMENT_OCID
33
43
  )
34
44
  response = ListEvaluationsResponse(
35
- message_id=list_eval_request.message_id,
45
+ message_id=request["message_id"],
36
46
  kind=RequestResponseType.ListEvaluations,
37
47
  data=eval_list,
38
48
  )
39
49
  return response
50
+
51
+ @staticmethod
52
+ def evaluation_details(request) -> EvaluationDetailsResponse:
53
+ evaluation_details = AquaEvaluationApp().get(
54
+ eval_id=request.get("evaluation_id")
55
+ )
56
+ response = EvaluationDetailsResponse(
57
+ message_id=request.get("message_id"),
58
+ kind=RequestResponseType.EvaluationDetails,
59
+ data=evaluation_details,
60
+ )
61
+ return response