oracle-ads 2.11.15__py3-none-any.whl → 2.11.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/common/entities.py +17 -0
- ads/aqua/common/enums.py +5 -1
- ads/aqua/common/utils.py +32 -2
- ads/aqua/config/config.py +1 -1
- ads/aqua/config/deployment_config_defaults.json +29 -1
- ads/aqua/config/resource_limit_names.json +1 -0
- ads/aqua/constants.py +5 -1
- ads/aqua/evaluation/entities.py +0 -1
- ads/aqua/evaluation/evaluation.py +47 -14
- ads/aqua/extension/common_ws_msg_handler.py +57 -0
- ads/aqua/extension/deployment_handler.py +14 -13
- ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
- ads/aqua/extension/errors.py +1 -1
- ads/aqua/extension/evaluation_ws_msg_handler.py +28 -6
- ads/aqua/extension/model_handler.py +31 -6
- ads/aqua/extension/models/ws_models.py +78 -3
- ads/aqua/extension/models_ws_msg_handler.py +49 -0
- ads/aqua/extension/ui_websocket_handler.py +7 -1
- ads/aqua/model/entities.py +11 -1
- ads/aqua/model/model.py +260 -90
- ads/aqua/modeldeployment/deployment.py +52 -7
- ads/aqua/modeldeployment/entities.py +9 -20
- ads/aqua/ui.py +152 -28
- ads/common/object_storage_details.py +2 -5
- ads/common/serializer.py +2 -3
- ads/jobs/builders/infrastructure/dsc_job.py +29 -3
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +74 -27
- ads/jobs/builders/runtimes/container_runtime.py +83 -4
- ads/opctl/operator/lowcode/anomaly/const.py +1 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +23 -7
- ads/opctl/operator/lowcode/anomaly/operator_config.py +1 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +4 -0
- ads/opctl/operator/lowcode/common/errors.py +6 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +21 -13
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +11 -2
- ads/pipeline/ads_pipeline_run.py +13 -2
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.16.dist-info}/METADATA +1 -1
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.16.dist-info}/RECORD +41 -37
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.16.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.16.dist-info}/WHEEL +0 -0
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.16.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,17 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
3
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
|
+
|
5
|
+
|
6
|
+
class ContainerSpec:
    """
    Class to hold keys within the container spec.

    Each attribute is a string key used to look up the corresponding entry
    of a container specification.
    """

    CONTAINER_SPEC = "containerSpec"
    # NOTE: attribute name is spelled "PARM" (not "PARAM"); kept as-is for
    # backward compatibility with existing callers.
    CLI_PARM = "cliParam"
    SERVER_PORT = "serverPort"
    HEALTH_CHECK_PORT = "healthCheckPort"
    ENV_VARS = "envVars"
    RESTRICTED_PARAMS = "restrictedParams"
    EVALUATION_CONFIGURATION = "evaluationConfiguration"
|
ads/aqua/common/enums.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*-
|
3
2
|
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
|
@@ -8,6 +7,7 @@ aqua.common.enums
|
|
8
7
|
~~~~~~~~~~~~~~
|
9
8
|
This module contains the set of enums used in AQUA.
|
10
9
|
"""
|
10
|
+
|
11
11
|
from ads.common.extended_enum import ExtendedEnumMeta
|
12
12
|
|
13
13
|
|
@@ -38,21 +38,25 @@ class Tags(str, metaclass=ExtendedEnumMeta):
|
|
38
38
|
READY_TO_IMPORT = "ready_to_import"
|
39
39
|
BASE_MODEL_CUSTOM = "aqua_custom_base_model"
|
40
40
|
AQUA_EVALUATION_MODEL_ID = "evaluation_model_id"
|
41
|
+
MODEL_FORMAT = "model_format"
|
41
42
|
|
42
43
|
|
43
44
|
class InferenceContainerType(str, metaclass=ExtendedEnumMeta):
    # Short identifiers of the supported inference containers. Member order is
    # kept unchanged in case the metaclass exposes values in definition order.
    CONTAINER_TYPE_VLLM = "vllm"  # vLLM serving container
    CONTAINER_TYPE_TGI = "tgi"  # text-generation-inference serving container
    CONTAINER_TYPE_LLAMA_CPP = "llama-cpp"  # llama.cpp serving container
|
46
48
|
|
47
49
|
|
48
50
|
class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
    # Container family names of the AQUA inference serving images, one per
    # supported inference container type. Definition order is preserved.
    AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
    AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
    AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
|
51
54
|
|
52
55
|
|
53
56
|
class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
    # Names of the parameter groups associated with each inference container
    # type. Definition order is preserved.
    PARAM_TYPE_VLLM = "VLLM_PARAMS"
    PARAM_TYPE_TGI = "TGI_PARAMS"
    PARAM_TYPE_LLAMA_CPP = "LLAMA_CPP_PARAMS"
|
56
60
|
|
57
61
|
|
58
62
|
class HuggingFaceTags(str, metaclass=ExtendedEnumMeta):
|
ads/aqua/common/utils.py
CHANGED
@@ -10,6 +10,7 @@ import logging
|
|
10
10
|
import os
|
11
11
|
import random
|
12
12
|
import re
|
13
|
+
from datetime import datetime, timedelta
|
13
14
|
from functools import wraps
|
14
15
|
from pathlib import Path
|
15
16
|
from string import Template
|
@@ -17,7 +18,9 @@ from typing import List, Union
|
|
17
18
|
|
18
19
|
import fsspec
|
19
20
|
import oci
|
21
|
+
from cachetools import TTLCache, cached
|
20
22
|
from oci.data_science.models import JobRun, Model
|
23
|
+
from oci.object_storage.models import ObjectSummary
|
21
24
|
|
22
25
|
from ads.aqua.common.enums import (
|
23
26
|
InferenceContainerParamType,
|
@@ -45,7 +48,6 @@ from ads.aqua.constants import (
|
|
45
48
|
)
|
46
49
|
from ads.aqua.data import AquaResourceIdentifier
|
47
50
|
from ads.common.auth import default_signer
|
48
|
-
from ads.common.decorator.threaded import threaded
|
49
51
|
from ads.common.extended_enum import ExtendedEnumMeta
|
50
52
|
from ads.common.object_storage_details import ObjectStorageDetails
|
51
53
|
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
|
@@ -213,7 +215,6 @@ def read_file(file_path: str, **kwargs) -> str:
|
|
213
215
|
return UNKNOWN
|
214
216
|
|
215
217
|
|
216
|
-
@threaded()
|
217
218
|
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
|
218
219
|
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
|
219
220
|
signer = default_signer() if artifact_path.startswith("oci://") else {}
|
@@ -228,6 +229,32 @@ def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
|
|
228
229
|
return config
|
229
230
|
|
230
231
|
|
232
|
+
def list_os_files_with_extension(oss_path: str, extension: str) -> List[str]:
    """
    List objects under an Object Storage path whose names end with the given extension.

    Parameters:
    - oss_path: The Object Storage path (prefix) where the files are located.
    - extension: The file extension to filter by, with or without the leading
      dot (e.g., 'txt' or '.txt' for text files).

    Returns:
    - A list of matching object names, relative to `oss_path`: the path prefix
      and any leading '/' are stripped from each name.
    """
    oss_client = ObjectStorageDetails.from_path(oss_path)

    # Normalize the extension so that "txt" and ".txt" behave the same.
    if not extension.startswith("."):
        extension = "." + extension

    # NOTE(review): only the objects returned by a single `list_objects()` call
    # are considered; confirm whether it pages through large prefixes.
    files: List[ObjectSummary] = oss_client.list_objects().objects

    prefix_len = len(oss_client.filepath)
    return [
        file.name[prefix_len:].lstrip("/")
        for file in files
        if file.name.endswith(extension)
    ]
|
256
|
+
|
257
|
+
|
231
258
|
def is_valid_ocid(ocid: str) -> bool:
|
232
259
|
"""Checks if the given ocid is valid.
|
233
260
|
|
@@ -503,6 +530,7 @@ def container_config_path():
|
|
503
530
|
return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
|
504
531
|
|
505
532
|
|
533
|
+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
|
506
534
|
def get_container_config():
|
507
535
|
config = load_config(
|
508
536
|
file_path=container_config_path(),
|
@@ -881,6 +909,8 @@ def get_container_params_type(container_type_name: str) -> str:
|
|
881
909
|
return InferenceContainerParamType.PARAM_TYPE_VLLM
|
882
910
|
elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_name.lower():
|
883
911
|
return InferenceContainerParamType.PARAM_TYPE_TGI
|
912
|
+
elif InferenceContainerType.CONTAINER_TYPE_LLAMA_CPP in container_type_name.lower():
|
913
|
+
return InferenceContainerParamType.PARAM_TYPE_LLAMA_CPP
|
884
914
|
else:
|
885
915
|
return UNKNOWN
|
886
916
|
|
ads/aqua/config/config.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*-
|
3
2
|
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
|
@@ -14,5 +13,6 @@ def get_finetuning_config_defaults():
|
|
14
13
|
"BM.GPU.A10.4": {"batch_size": 1, "replica": 1},
|
15
14
|
"BM.GPU4.8": {"batch_size": 4, "replica": 1},
|
16
15
|
"BM.GPU.A100-v2.8": {"batch_size": 6, "replica": 1},
|
16
|
+
"BM.GPU.H100.8": {"batch_size": 6, "replica": 1},
|
17
17
|
}
|
18
18
|
}
|
@@ -1,9 +1,37 @@
|
|
1
1
|
{
|
2
|
+
"configuration": {
|
3
|
+
"VM.Standard.A1.Flex": {
|
4
|
+
"parameters": {},
|
5
|
+
"shape_info": {
|
6
|
+
"configs": [
|
7
|
+
{
|
8
|
+
"memory_in_gbs": 128,
|
9
|
+
"ocpu": 20
|
10
|
+
},
|
11
|
+
{
|
12
|
+
"memory_in_gbs": 256,
|
13
|
+
"ocpu": 40
|
14
|
+
},
|
15
|
+
{
|
16
|
+
"memory_in_gbs": 384,
|
17
|
+
"ocpu": 60
|
18
|
+
},
|
19
|
+
{
|
20
|
+
"memory_in_gbs": 512,
|
21
|
+
"ocpu": 80
|
22
|
+
}
|
23
|
+
],
|
24
|
+
"type": "CPU"
|
25
|
+
}
|
26
|
+
}
|
27
|
+
},
|
2
28
|
"shape": [
|
3
29
|
"VM.GPU.A10.1",
|
4
30
|
"VM.GPU.A10.2",
|
5
31
|
"BM.GPU.A10.4",
|
6
32
|
"BM.GPU4.8",
|
7
|
-
"BM.GPU.A100-v2.8"
|
33
|
+
"BM.GPU.A100-v2.8",
|
34
|
+
"BM.GPU.H100.8",
|
35
|
+
"VM.Standard.A1.Flex"
|
8
36
|
]
|
9
37
|
}
|
ads/aqua/constants.py
CHANGED
@@ -21,7 +21,6 @@ DEFAULT_FT_BLOCK_STORAGE_SIZE = 750
|
|
21
21
|
DEFAULT_FT_REPLICA = 1
|
22
22
|
DEFAULT_FT_BATCH_SIZE = 1
|
23
23
|
DEFAULT_FT_VALIDATION_SET_SIZE = 0.1
|
24
|
-
|
25
24
|
MAXIMUM_ALLOWED_DATASET_IN_BYTE = 52428800 # 1024 x 1024 x 50 = 50MB
|
26
25
|
JOB_INFRASTRUCTURE_TYPE_DEFAULT_NETWORKING = "ME_STANDALONE"
|
27
26
|
NB_SESSION_IDENTIFIER = "NB_SESSION_OCID"
|
@@ -34,6 +33,7 @@ AQUA_MODEL_TYPE_CUSTOM = "custom"
|
|
34
33
|
AQUA_MODEL_ARTIFACT_CONFIG = "config.json"
|
35
34
|
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME = "_name_or_path"
|
36
35
|
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE = "model_type"
|
36
|
+
AQUA_MODEL_ARTIFACT_FILE = "model_file"
|
37
37
|
|
38
38
|
TRAINING_METRICS_FINAL = "training_metrics_final"
|
39
39
|
VALIDATION_METRICS_FINAL = "validation_metrics_final"
|
@@ -74,3 +74,7 @@ TGI_INFERENCE_RESTRICTED_PARAMS = {
|
|
74
74
|
"--sharded",
|
75
75
|
"--trust-remote-code",
|
76
76
|
}
|
77
|
+
# Command-line parameters of the llama.cpp inference container that are treated
# as restricted (presumably because host/port are managed by the service —
# confirm).
LLAMA_CPP_INFERENCE_RESTRICTED_PARAMS = {"--host", "--port"}
|
ads/aqua/evaluation/entities.py
CHANGED
@@ -7,7 +7,7 @@ import os
|
|
7
7
|
import re
|
8
8
|
import tempfile
|
9
9
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
10
|
-
from dataclasses import asdict
|
10
|
+
from dataclasses import asdict, fields
|
11
11
|
from datetime import datetime, timedelta
|
12
12
|
from pathlib import Path
|
13
13
|
from threading import Lock
|
@@ -76,6 +76,7 @@ from ads.aqua.evaluation.entities import (
|
|
76
76
|
ModelParams,
|
77
77
|
)
|
78
78
|
from ads.aqua.evaluation.errors import EVALUATION_JOB_EXIT_CODE_MESSAGE
|
79
|
+
from ads.aqua.ui import AquaContainerConfig
|
79
80
|
from ads.common.auth import default_signer
|
80
81
|
from ads.common.object_storage_details import ObjectStorageDetails
|
81
82
|
from ads.common.utils import get_console_link, get_files, get_log_links
|
@@ -90,7 +91,9 @@ from ads.jobs.builders.infrastructure.dsc_job import DataScienceJob
|
|
90
91
|
from ads.jobs.builders.runtimes.base import Runtime
|
91
92
|
from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
|
92
93
|
from ads.model.datascience_model import DataScienceModel
|
94
|
+
from ads.model.deployment import ModelDeploymentContainerRuntime
|
93
95
|
from ads.model.deployment.model_deployment import ModelDeployment
|
96
|
+
from ads.model.generic_model import ModelDeploymentRuntimeType
|
94
97
|
from ads.model.model_metadata import (
|
95
98
|
MetadataTaxonomyKeys,
|
96
99
|
ModelCustomMetadata,
|
@@ -157,8 +160,9 @@ class AquaEvaluationApp(AquaApp):
|
|
157
160
|
create_aqua_evaluation_details = CreateAquaEvaluationDetails(**kwargs)
|
158
161
|
except Exception as ex:
|
159
162
|
raise AquaValueError(
|
160
|
-
"Invalid create evaluation parameters.
|
161
|
-
|
163
|
+
"Invalid create evaluation parameters. "
|
164
|
+
"Allowable parameters are: "
|
165
|
+
f"{', '.join([field.name for field in fields(CreateAquaEvaluationDetails)])}."
|
162
166
|
) from ex
|
163
167
|
|
164
168
|
if not is_valid_ocid(create_aqua_evaluation_details.evaluation_source_id):
|
@@ -166,8 +170,8 @@ class AquaEvaluationApp(AquaApp):
|
|
166
170
|
f"Invalid evaluation source {create_aqua_evaluation_details.evaluation_source_id}. "
|
167
171
|
"Specify either a model or model deployment id."
|
168
172
|
)
|
169
|
-
|
170
173
|
evaluation_source = None
|
174
|
+
eval_inference_configuration = None
|
171
175
|
if (
|
172
176
|
DataScienceResource.MODEL_DEPLOYMENT
|
173
177
|
in create_aqua_evaluation_details.evaluation_source_id
|
@@ -175,6 +179,28 @@ class AquaEvaluationApp(AquaApp):
|
|
175
179
|
evaluation_source = ModelDeployment.from_id(
|
176
180
|
create_aqua_evaluation_details.evaluation_source_id
|
177
181
|
)
|
182
|
+
try:
|
183
|
+
if (
|
184
|
+
evaluation_source.runtime.type
|
185
|
+
== ModelDeploymentRuntimeType.CONTAINER
|
186
|
+
):
|
187
|
+
runtime = ModelDeploymentContainerRuntime.from_dict(
|
188
|
+
evaluation_source.runtime.to_dict()
|
189
|
+
)
|
190
|
+
inference_config = AquaContainerConfig.from_container_index_json(
|
191
|
+
enable_spec=True
|
192
|
+
).inference
|
193
|
+
for container in inference_config.values():
|
194
|
+
if container.name == runtime.image.split(":")[0]:
|
195
|
+
eval_inference_configuration = (
|
196
|
+
container.spec.evaluation_configuration
|
197
|
+
)
|
198
|
+
except Exception:
|
199
|
+
logger.debug(
|
200
|
+
f"Could not load inference config details for the evaluation id: "
|
201
|
+
f"{create_aqua_evaluation_details.evaluation_source_id}. Please check if the container"
|
202
|
+
f" runtime has the correct SMC image information."
|
203
|
+
)
|
178
204
|
elif (
|
179
205
|
DataScienceResource.MODEL
|
180
206
|
in create_aqua_evaluation_details.evaluation_source_id
|
@@ -390,6 +416,9 @@ class AquaEvaluationApp(AquaApp):
|
|
390
416
|
report_path=create_aqua_evaluation_details.report_path,
|
391
417
|
model_parameters=create_aqua_evaluation_details.model_parameters,
|
392
418
|
metrics=create_aqua_evaluation_details.metrics,
|
419
|
+
inference_configuration=eval_inference_configuration.to_filtered_dict()
|
420
|
+
if eval_inference_configuration
|
421
|
+
else {},
|
393
422
|
)
|
394
423
|
).create(**kwargs) ## TODO: decide what parameters will be needed
|
395
424
|
logger.debug(
|
@@ -511,6 +540,7 @@ class AquaEvaluationApp(AquaApp):
|
|
511
540
|
report_path: str,
|
512
541
|
model_parameters: dict,
|
513
542
|
metrics: List = None,
|
543
|
+
inference_configuration: dict = None,
|
514
544
|
) -> Runtime:
|
515
545
|
"""Builds evaluation runtime for Job."""
|
516
546
|
# TODO the image name needs to be extracted from the mapping index.json file.
|
@@ -520,16 +550,19 @@ class AquaEvaluationApp(AquaApp):
|
|
520
550
|
.with_environment_variable(
|
521
551
|
**{
|
522
552
|
"AIP_SMC_EVALUATION_ARGUMENTS": json.dumps(
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
553
|
+
{
|
554
|
+
**asdict(
|
555
|
+
self._build_launch_cmd(
|
556
|
+
evaluation_id=evaluation_id,
|
557
|
+
evaluation_source_id=evaluation_source_id,
|
558
|
+
dataset_path=dataset_path,
|
559
|
+
report_path=report_path,
|
560
|
+
model_parameters=model_parameters,
|
561
|
+
metrics=metrics,
|
562
|
+
),
|
563
|
+
),
|
564
|
+
**(inference_configuration or {}),
|
565
|
+
},
|
533
566
|
),
|
534
567
|
"CONDA_BUCKET_NS": CONDA_BUCKET_NS,
|
535
568
|
},
|
@@ -0,0 +1,57 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
import json
|
7
|
+
from importlib import metadata
|
8
|
+
from typing import List, Union
|
9
|
+
|
10
|
+
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, fetch_service_compartment
|
11
|
+
from ads.aqua.common.decorator import handle_exceptions
|
12
|
+
from ads.aqua.common.errors import AquaResourceAccessError
|
13
|
+
from ads.aqua.common.utils import known_realm
|
14
|
+
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
|
15
|
+
from ads.aqua.extension.models.ws_models import (
|
16
|
+
AdsVersionResponse,
|
17
|
+
CompatibilityCheckResponse,
|
18
|
+
RequestResponseType,
|
19
|
+
)
|
20
|
+
|
21
|
+
|
22
|
+
class AquaCommonWsMsgHandler(AquaWSMsgHandler):
    """Websocket message handler for requests not tied to a specific resource.

    Serves ``AdsVersion`` and ``CompatibilityCheck`` messages.
    """

    def __init__(self, message: Union[str, bytes]):
        super().__init__(message)

    @staticmethod
    def get_message_types() -> List[RequestResponseType]:
        """Return the request kinds routed to this handler."""
        return [RequestResponseType.AdsVersion, RequestResponseType.CompatibilityCheck]

    @handle_exceptions
    def process(self) -> Union[AdsVersionResponse, CompatibilityCheckResponse]:
        """Decode the incoming JSON message and build the matching response.

        ``AdsVersion`` replies with the installed ``oracle_ads`` version.
        ``CompatibilityCheck`` replies with status ``"ok"`` when
        ``ODSC_MODEL_COMPARTMENT_OCID`` is set or ``fetch_service_compartment()``
        returns a truthy value, otherwise ``"compatible"`` when ``known_realm()``
        is truthy. Any other kind yields ``None``.

        Raises
        ------
        AquaResourceAccessError
            For ``CompatibilityCheck`` when none of the above conditions hold.
        """
        payload = json.loads(self.message)
        kind = payload.get("kind")

        if kind == "AdsVersion":
            installed_version = metadata.version("oracle_ads")
            return AdsVersionResponse(
                message_id=payload.get("message_id"),
                kind=RequestResponseType.AdsVersion,
                data=installed_version,
            )

        if kind != "CompatibilityCheck":
            return None

        # Short-circuit order matters: the service compartment lookup is only
        # attempted when the OCID is not configured, and the realm check only
        # when both of those fail.
        if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
            status = "ok"
        elif known_realm():
            status = "compatible"
        else:
            raise AquaResourceAccessError(
                "The AI Quick actions extension is not compatible in the given region."
            )

        return CompatibilityCheckResponse(
            message_id=payload.get("message_id"),
            kind=RequestResponseType.CompatibilityCheck,
            data={"status": status},
        )
|
@@ -1,5 +1,4 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*-
|
3
2
|
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
|
@@ -8,8 +7,8 @@ from urllib.parse import urlparse
|
|
8
7
|
from tornado.web import HTTPError
|
9
8
|
|
10
9
|
from ads.aqua.common.decorator import handle_exceptions
|
11
|
-
from ads.aqua.extension.errors import Errors
|
12
10
|
from ads.aqua.extension.base_handler import AquaAPIhandler
|
11
|
+
from ads.aqua.extension.errors import Errors
|
13
12
|
from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
|
14
13
|
from ads.aqua.modeldeployment.entities import ModelParams
|
15
14
|
from ads.config import COMPARTMENT_OCID, PROJECT_OCID
|
@@ -66,8 +65,8 @@ class AquaDeploymentHandler(AquaAPIhandler):
|
|
66
65
|
"""
|
67
66
|
try:
|
68
67
|
input_data = self.get_json_body()
|
69
|
-
except Exception:
|
70
|
-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
|
68
|
+
except Exception as ex:
|
69
|
+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
|
71
70
|
|
72
71
|
if not input_data:
|
73
72
|
raise HTTPError(400, Errors.NO_INPUT_DATA)
|
@@ -100,6 +99,8 @@ class AquaDeploymentHandler(AquaAPIhandler):
|
|
100
99
|
health_check_port = input_data.get("health_check_port")
|
101
100
|
env_var = input_data.get("env_var")
|
102
101
|
container_family = input_data.get("container_family")
|
102
|
+
ocpus = input_data.get("ocpus")
|
103
|
+
memory_in_gbs = input_data.get("memory_in_gbs")
|
103
104
|
|
104
105
|
self.finish(
|
105
106
|
AquaDeploymentApp().create(
|
@@ -119,6 +120,8 @@ class AquaDeploymentHandler(AquaAPIhandler):
|
|
119
120
|
health_check_port=health_check_port,
|
120
121
|
env_var=env_var,
|
121
122
|
container_family=container_family,
|
123
|
+
ocpus=ocpus,
|
124
|
+
memory_in_gbs=memory_in_gbs,
|
122
125
|
)
|
123
126
|
)
|
124
127
|
|
@@ -153,9 +156,7 @@ class AquaDeploymentInferenceHandler(AquaAPIhandler):
|
|
153
156
|
return False
|
154
157
|
if not url.netloc:
|
155
158
|
return False
|
156
|
-
|
157
|
-
return False
|
158
|
-
return True
|
159
|
+
return url.path.endswith("/predict")
|
159
160
|
except Exception:
|
160
161
|
return False
|
161
162
|
|
@@ -170,8 +171,8 @@ class AquaDeploymentInferenceHandler(AquaAPIhandler):
|
|
170
171
|
"""
|
171
172
|
try:
|
172
173
|
input_data = self.get_json_body()
|
173
|
-
except Exception:
|
174
|
-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
|
174
|
+
except Exception as ex:
|
175
|
+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
|
175
176
|
|
176
177
|
if not input_data:
|
177
178
|
raise HTTPError(400, Errors.NO_INPUT_DATA)
|
@@ -192,10 +193,10 @@ class AquaDeploymentInferenceHandler(AquaAPIhandler):
|
|
192
193
|
)
|
193
194
|
try:
|
194
195
|
model_params_obj = ModelParams(**model_params)
|
195
|
-
except:
|
196
|
+
except Exception as ex:
|
196
197
|
raise HTTPError(
|
197
198
|
400, Errors.INVALID_INPUT_DATA_FORMAT.format("model_params")
|
198
|
-
)
|
199
|
+
) from ex
|
199
200
|
|
200
201
|
return self.finish(
|
201
202
|
MDInferenceResponse(prompt, model_params_obj).get_model_deployment_response(
|
@@ -236,8 +237,8 @@ class AquaDeploymentParamsHandler(AquaAPIhandler):
|
|
236
237
|
"""
|
237
238
|
try:
|
238
239
|
input_data = self.get_json_body()
|
239
|
-
except Exception:
|
240
|
-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
|
240
|
+
except Exception as ex:
|
241
|
+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
|
241
242
|
|
242
243
|
if not input_data:
|
243
244
|
raise HTTPError(400, Errors.NO_INPUT_DATA)
|
@@ -0,0 +1,54 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
import json
|
7
|
+
from typing import List, Union
|
8
|
+
|
9
|
+
from ads.aqua.common.decorator import handle_exceptions
|
10
|
+
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
|
11
|
+
from ads.aqua.extension.models.ws_models import (
|
12
|
+
ListDeploymentResponse,
|
13
|
+
ModelDeploymentDetailsResponse,
|
14
|
+
RequestResponseType,
|
15
|
+
)
|
16
|
+
from ads.aqua.modeldeployment import AquaDeploymentApp
|
17
|
+
from ads.config import COMPARTMENT_OCID
|
18
|
+
|
19
|
+
|
20
|
+
class AquaDeploymentWSMsgHandler(AquaWSMsgHandler):
    """Websocket message handler for model deployment requests.

    Serves ``ListDeployments`` and ``DeploymentDetails`` messages.
    """

    def __init__(self, message: Union[str, bytes]):
        super().__init__(message)

    @staticmethod
    def get_message_types() -> List[RequestResponseType]:
        """Return the request kinds routed to this handler."""
        return [
            RequestResponseType.ListDeployments,
            RequestResponseType.DeploymentDetails,
        ]

    @handle_exceptions
    def process(self) -> Union[ListDeploymentResponse, ModelDeploymentDetailsResponse]:
        """Decode the incoming JSON message and build the matching response.

        ``ListDeployments`` lists deployments for ``compartment_id`` (falling
        back to ``COMPARTMENT_OCID`` when absent or empty) and the optional
        ``project_id``. ``DeploymentDetails`` fetches the deployment identified
        by ``model_deployment_id``. Any other kind yields ``None``.
        """
        payload = json.loads(self.message)
        kind = payload.get("kind")

        if kind == "ListDeployments":
            compartment = payload.get("compartment_id") or COMPARTMENT_OCID
            deployments = AquaDeploymentApp().list(
                compartment_id=compartment,
                project_id=payload.get("project_id"),
            )
            return ListDeploymentResponse(
                message_id=payload.get("message_id"),
                kind=RequestResponseType.ListDeployments,
                data=deployments,
            )

        if kind == "DeploymentDetails":
            details = AquaDeploymentApp().get(payload.get("model_deployment_id"))
            return ModelDeploymentDetailsResponse(
                message_id=payload.get("message_id"),
                kind=RequestResponseType.DeploymentDetails,
                data=details,
            )

        return None
|
ads/aqua/extension/errors.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*-
|
3
2
|
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
|
@@ -8,3 +7,4 @@ class Errors(str):
|
|
8
7
|
INVALID_INPUT_DATA_FORMAT = "Invalid format of input data."
|
9
8
|
NO_INPUT_DATA = "No input data provided."
|
10
9
|
MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'"
|
10
|
+
MISSING_ONEOF_REQUIRED_PARAMETER = "Either '{}' or '{}' is required."
|
@@ -3,13 +3,14 @@
|
|
3
3
|
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
5
|
|
6
|
+
import json
|
6
7
|
from typing import List, Union
|
7
8
|
|
8
9
|
from ads.aqua.common.decorator import handle_exceptions
|
9
10
|
from ads.aqua.evaluation import AquaEvaluationApp
|
10
11
|
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
|
11
12
|
from ads.aqua.extension.models.ws_models import (
|
12
|
-
|
13
|
+
EvaluationDetailsResponse,
|
13
14
|
ListEvaluationsResponse,
|
14
15
|
RequestResponseType,
|
15
16
|
)
|
@@ -19,21 +20,42 @@ from ads.config import COMPARTMENT_OCID
|
|
19
20
|
class AquaEvaluationWSMsgHandler(AquaWSMsgHandler):
    """Websocket message handler for model evaluation requests.

    Serves ``ListEvaluations`` and ``EvaluationDetails`` messages.
    """

    @staticmethod
    def get_message_types() -> List[RequestResponseType]:
        """Return the request kinds routed to this handler."""
        return [
            RequestResponseType.ListEvaluations,
            RequestResponseType.EvaluationDetails,
        ]

    def __init__(self, message: Union[str, bytes]):
        super().__init__(message)

    @handle_exceptions
    def process(self) -> Union[ListEvaluationsResponse, EvaluationDetailsResponse]:
        """Decode the incoming JSON message and dispatch it by ``kind``.

        Returns ``None`` for kinds this handler does not serve.
        """
        request = json.loads(self.message)
        # Use .get() like the sibling WS handlers, so a malformed message does
        # not surface as a bare KeyError.
        kind = request.get("kind")
        if kind == "ListEvaluations":
            return self.list_evaluations(request)
        if kind == "EvaluationDetails":
            return self.evaluation_details(request)
        return None

    @staticmethod
    def list_evaluations(request) -> ListEvaluationsResponse:
        """List evaluations in ``compartment_id`` (default: ``COMPARTMENT_OCID``)."""
        eval_list = AquaEvaluationApp().list(
            request.get("compartment_id") or COMPARTMENT_OCID
        )
        response = ListEvaluationsResponse(
            message_id=request.get("message_id"),
            kind=RequestResponseType.ListEvaluations,
            data=eval_list,
        )
        return response

    @staticmethod
    def evaluation_details(request) -> EvaluationDetailsResponse:
        """Fetch the details of the evaluation identified by ``evaluation_id``."""
        evaluation_details = AquaEvaluationApp().get(
            eval_id=request.get("evaluation_id")
        )
        response = EvaluationDetailsResponse(
            message_id=request.get("message_id"),
            kind=RequestResponseType.EvaluationDetails,
            data=evaluation_details,
        )
        return response
|