oracle-ads 2.11.15__py3-none-any.whl → 2.11.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/app.py +5 -6
- ads/aqua/common/entities.py +17 -0
- ads/aqua/common/enums.py +14 -1
- ads/aqua/common/utils.py +160 -3
- ads/aqua/config/config.py +1 -1
- ads/aqua/config/deployment_config_defaults.json +29 -1
- ads/aqua/config/resource_limit_names.json +1 -0
- ads/aqua/constants.py +6 -1
- ads/aqua/evaluation/entities.py +0 -1
- ads/aqua/evaluation/evaluation.py +47 -14
- ads/aqua/extension/common_handler.py +75 -5
- ads/aqua/extension/common_ws_msg_handler.py +57 -0
- ads/aqua/extension/deployment_handler.py +16 -13
- ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
- ads/aqua/extension/errors.py +1 -1
- ads/aqua/extension/evaluation_ws_msg_handler.py +28 -6
- ads/aqua/extension/model_handler.py +134 -8
- ads/aqua/extension/models/ws_models.py +78 -3
- ads/aqua/extension/models_ws_msg_handler.py +49 -0
- ads/aqua/extension/ui_websocket_handler.py +7 -1
- ads/aqua/model/entities.py +28 -0
- ads/aqua/model/model.py +544 -129
- ads/aqua/modeldeployment/deployment.py +102 -43
- ads/aqua/modeldeployment/entities.py +9 -20
- ads/aqua/ui.py +152 -28
- ads/common/object_storage_details.py +2 -5
- ads/common/serializer.py +2 -3
- ads/jobs/builders/infrastructure/dsc_job.py +41 -12
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +74 -27
- ads/jobs/builders/runtimes/container_runtime.py +83 -4
- ads/opctl/operator/lowcode/anomaly/const.py +1 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +23 -7
- ads/opctl/operator/lowcode/anomaly/operator_config.py +1 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +4 -0
- ads/opctl/operator/lowcode/common/errors.py +6 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +3 -1
- ads/opctl/operator/lowcode/forecast/model/base_model.py +21 -13
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +11 -2
- ads/pipeline/ads_pipeline_run.py +13 -2
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/METADATA +2 -1
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/RECORD +44 -40
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/WHEEL +0 -0
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/entry_points.txt +0 -0
ads/aqua/app.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*-
|
3
2
|
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
|
@@ -175,7 +174,7 @@ class AquaApp:
|
|
175
174
|
f"Invalid model version set name. Please provide a model version set with `{tag}` in tags."
|
176
175
|
)
|
177
176
|
|
178
|
-
except:
|
177
|
+
except Exception:
|
179
178
|
logger.debug(
|
180
179
|
f"Model version set {model_version_set_name} doesn't exist. "
|
181
180
|
"Creating new model version set."
|
@@ -254,7 +253,7 @@ class AquaApp:
|
|
254
253
|
|
255
254
|
try:
|
256
255
|
response = self.ds_client.head_model_artifact(model_id=model_id, **kwargs)
|
257
|
-
return
|
256
|
+
return response.status == 200
|
258
257
|
except oci.exceptions.ServiceError as ex:
|
259
258
|
if ex.status == 404:
|
260
259
|
logger.info(f"Artifact not found in model {model_id}.")
|
@@ -302,7 +301,7 @@ class AquaApp:
|
|
302
301
|
config_path,
|
303
302
|
config_file_name=config_file_name,
|
304
303
|
)
|
305
|
-
except:
|
304
|
+
except Exception:
|
306
305
|
# todo: temp fix for issue related to config load for byom models, update logic to choose the right path
|
307
306
|
try:
|
308
307
|
config_path = f"{artifact_path.rstrip('/')}/config/"
|
@@ -310,7 +309,7 @@ class AquaApp:
|
|
310
309
|
config_path,
|
311
310
|
config_file_name=config_file_name,
|
312
311
|
)
|
313
|
-
except:
|
312
|
+
except Exception:
|
314
313
|
pass
|
315
314
|
|
316
315
|
if not config:
|
@@ -343,7 +342,7 @@ class CLIBuilderMixin:
|
|
343
342
|
params = [
|
344
343
|
f"--{field.name} {getattr(self,field.name)}"
|
345
344
|
for field in fields(self.__class__)
|
346
|
-
if getattr(self, field.name)
|
345
|
+
if getattr(self, field.name) is not None
|
347
346
|
]
|
348
347
|
cmd = f"{cmd} {' '.join(params)}"
|
349
348
|
return cmd
|
@@ -0,0 +1,17 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
3
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
|
+
|
5
|
+
|
6
|
+
class ContainerSpec:
|
7
|
+
"""
|
8
|
+
Class to hold to hold keys within the container spec.
|
9
|
+
"""
|
10
|
+
|
11
|
+
CONTAINER_SPEC = "containerSpec"
|
12
|
+
CLI_PARM = "cliParam"
|
13
|
+
SERVER_PORT = "serverPort"
|
14
|
+
HEALTH_CHECK_PORT = "healthCheckPort"
|
15
|
+
ENV_VARS = "envVars"
|
16
|
+
RESTRICTED_PARAMS = "restrictedParams"
|
17
|
+
EVALUATION_CONFIGURATION = "evaluationConfiguration"
|
ads/aqua/common/enums.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*-
|
3
2
|
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
|
@@ -8,6 +7,7 @@ aqua.common.enums
|
|
8
7
|
~~~~~~~~~~~~~~
|
9
8
|
This module contains the set of enums used in AQUA.
|
10
9
|
"""
|
10
|
+
|
11
11
|
from ads.common.extended_enum import ExtendedEnumMeta
|
12
12
|
|
13
13
|
|
@@ -38,21 +38,34 @@ class Tags(str, metaclass=ExtendedEnumMeta):
|
|
38
38
|
READY_TO_IMPORT = "ready_to_import"
|
39
39
|
BASE_MODEL_CUSTOM = "aqua_custom_base_model"
|
40
40
|
AQUA_EVALUATION_MODEL_ID = "evaluation_model_id"
|
41
|
+
MODEL_FORMAT = "model_format"
|
42
|
+
MODEL_ARTIFACT_FILE = "model_file"
|
41
43
|
|
42
44
|
|
43
45
|
class InferenceContainerType(str, metaclass=ExtendedEnumMeta):
|
44
46
|
CONTAINER_TYPE_VLLM = "vllm"
|
45
47
|
CONTAINER_TYPE_TGI = "tgi"
|
48
|
+
CONTAINER_TYPE_LLAMA_CPP = "llama-cpp"
|
46
49
|
|
47
50
|
|
48
51
|
class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
|
49
52
|
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
|
50
53
|
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
|
54
|
+
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
|
51
55
|
|
52
56
|
|
53
57
|
class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
|
54
58
|
PARAM_TYPE_VLLM = "VLLM_PARAMS"
|
55
59
|
PARAM_TYPE_TGI = "TGI_PARAMS"
|
60
|
+
PARAM_TYPE_LLAMA_CPP = "LLAMA_CPP_PARAMS"
|
61
|
+
|
62
|
+
|
63
|
+
class EvaluationContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
|
64
|
+
AQUA_EVALUATION_CONTAINER_FAMILY = "odsc-llm-evaluate"
|
65
|
+
|
66
|
+
|
67
|
+
class FineTuningContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
|
68
|
+
AQUA_FINETUNING_CONTAINER_FAMILY = "odsc-llm-fine-tuning"
|
56
69
|
|
57
70
|
|
58
71
|
class HuggingFaceTags(str, metaclass=ExtendedEnumMeta):
|
ads/aqua/common/utils.py
CHANGED
@@ -10,6 +10,9 @@ import logging
|
|
10
10
|
import os
|
11
11
|
import random
|
12
12
|
import re
|
13
|
+
import shlex
|
14
|
+
import subprocess
|
15
|
+
from datetime import datetime, timedelta
|
13
16
|
from functools import wraps
|
14
17
|
from pathlib import Path
|
15
18
|
from string import Template
|
@@ -17,7 +20,16 @@ from typing import List, Union
|
|
17
20
|
|
18
21
|
import fsspec
|
19
22
|
import oci
|
23
|
+
from cachetools import TTLCache, cached
|
24
|
+
from huggingface_hub.hf_api import HfApi, ModelInfo
|
25
|
+
from huggingface_hub.utils import (
|
26
|
+
GatedRepoError,
|
27
|
+
HfHubHTTPError,
|
28
|
+
RepositoryNotFoundError,
|
29
|
+
RevisionNotFoundError,
|
30
|
+
)
|
20
31
|
from oci.data_science.models import JobRun, Model
|
32
|
+
from oci.object_storage.models import ObjectSummary
|
21
33
|
|
22
34
|
from ads.aqua.common.enums import (
|
23
35
|
InferenceContainerParamType,
|
@@ -34,6 +46,7 @@ from ads.aqua.constants import (
|
|
34
46
|
COMPARTMENT_MAPPING_KEY,
|
35
47
|
CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
|
36
48
|
CONTAINER_INDEX,
|
49
|
+
HF_LOGIN_DEFAULT_TIMEOUT,
|
37
50
|
MAXIMUM_ALLOWED_DATASET_IN_BYTE,
|
38
51
|
MODEL_BY_REFERENCE_OSS_PATH_KEY,
|
39
52
|
SERVICE_MANAGED_CONTAINER_URI_SCHEME,
|
@@ -44,8 +57,7 @@ from ads.aqua.constants import (
|
|
44
57
|
VLLM_INFERENCE_RESTRICTED_PARAMS,
|
45
58
|
)
|
46
59
|
from ads.aqua.data import AquaResourceIdentifier
|
47
|
-
from ads.common.auth import default_signer
|
48
|
-
from ads.common.decorator.threaded import threaded
|
60
|
+
from ads.common.auth import AuthState, default_signer
|
49
61
|
from ads.common.extended_enum import ExtendedEnumMeta
|
50
62
|
from ads.common.object_storage_details import ObjectStorageDetails
|
51
63
|
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
|
@@ -213,7 +225,6 @@ def read_file(file_path: str, **kwargs) -> str:
|
|
213
225
|
return UNKNOWN
|
214
226
|
|
215
227
|
|
216
|
-
@threaded()
|
217
228
|
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
|
218
229
|
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
|
219
230
|
signer = default_signer() if artifact_path.startswith("oci://") else {}
|
@@ -228,6 +239,32 @@ def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
|
|
228
239
|
return config
|
229
240
|
|
230
241
|
|
242
|
+
def list_os_files_with_extension(oss_path: str, extension: str) -> [str]:
|
243
|
+
"""
|
244
|
+
List files in the specified directory with the given extension.
|
245
|
+
|
246
|
+
Parameters:
|
247
|
+
- oss_path: The path to the directory where files are located.
|
248
|
+
- extension: The file extension to filter by (e.g., 'txt' for text files).
|
249
|
+
|
250
|
+
Returns:
|
251
|
+
- A list of file paths matching the specified extension.
|
252
|
+
"""
|
253
|
+
|
254
|
+
oss_client = ObjectStorageDetails.from_path(oss_path)
|
255
|
+
|
256
|
+
# Ensure the extension is prefixed with a dot if not already
|
257
|
+
if not extension.startswith("."):
|
258
|
+
extension = "." + extension
|
259
|
+
files: List[ObjectSummary] = oss_client.list_objects().objects
|
260
|
+
|
261
|
+
return [
|
262
|
+
file.name[len(oss_client.filepath) :].lstrip("/")
|
263
|
+
for file in files
|
264
|
+
if file.name.endswith(extension)
|
265
|
+
]
|
266
|
+
|
267
|
+
|
231
268
|
def is_valid_ocid(ocid: str) -> bool:
|
232
269
|
"""Checks if the given ocid is valid.
|
233
270
|
|
@@ -503,6 +540,7 @@ def container_config_path():
|
|
503
540
|
return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
|
504
541
|
|
505
542
|
|
543
|
+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
|
506
544
|
def get_container_config():
|
507
545
|
config = load_config(
|
508
546
|
file_path=container_config_path(),
|
@@ -743,6 +781,33 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
|
|
743
781
|
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
|
744
782
|
|
745
783
|
|
784
|
+
def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
|
785
|
+
"""Upload the local folder to the object storage
|
786
|
+
|
787
|
+
Args:
|
788
|
+
os_path (str): object storage URI with prefix. This is the path to upload
|
789
|
+
local_dir (str): Local directory where the object is downloaded
|
790
|
+
model_name (str): Name of the huggingface model
|
791
|
+
Retuns:
|
792
|
+
str: Object name inside the bucket
|
793
|
+
"""
|
794
|
+
os_details: ObjectStorageDetails = ObjectStorageDetails.from_path(os_path)
|
795
|
+
if not os_details.is_bucket_versioned():
|
796
|
+
raise ValueError(f"Version is not enabled at object storage location {os_path}")
|
797
|
+
auth_state = AuthState()
|
798
|
+
object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/"
|
799
|
+
command = f"oci os object bulk-upload --src-dir {local_dir} --prefix {object_path} -bn {os_details.bucket} -ns {os_details.namespace} --auth {auth_state.oci_iam_type} --profile {auth_state.oci_key_profile} --no-overwrite"
|
800
|
+
try:
|
801
|
+
logger.info(f"Running: {command}")
|
802
|
+
subprocess.check_call(shlex.split(command))
|
803
|
+
except subprocess.CalledProcessError as e:
|
804
|
+
logger.error(
|
805
|
+
f"Error uploading the object. Exit code: {e.returncode} with error {e.stdout}"
|
806
|
+
)
|
807
|
+
|
808
|
+
return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path
|
809
|
+
|
810
|
+
|
746
811
|
def is_service_managed_container(container):
|
747
812
|
return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
|
748
813
|
|
@@ -881,6 +946,8 @@ def get_container_params_type(container_type_name: str) -> str:
|
|
881
946
|
return InferenceContainerParamType.PARAM_TYPE_VLLM
|
882
947
|
elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_name.lower():
|
883
948
|
return InferenceContainerParamType.PARAM_TYPE_TGI
|
949
|
+
elif InferenceContainerType.CONTAINER_TYPE_LLAMA_CPP in container_type_name.lower():
|
950
|
+
return InferenceContainerParamType.PARAM_TYPE_LLAMA_CPP
|
884
951
|
else:
|
885
952
|
return UNKNOWN
|
886
953
|
|
@@ -905,3 +972,93 @@ def get_restricted_params_by_container(container_type_name: str) -> set:
|
|
905
972
|
return TGI_INFERENCE_RESTRICTED_PARAMS
|
906
973
|
else:
|
907
974
|
return set()
|
975
|
+
|
976
|
+
|
977
|
+
def get_huggingface_login_timeout() -> int:
|
978
|
+
"""This helper function returns the huggingface login timeout, returns default if not set via
|
979
|
+
env var.
|
980
|
+
Returns
|
981
|
+
-------
|
982
|
+
timeout: int
|
983
|
+
huggingface login timeout.
|
984
|
+
|
985
|
+
"""
|
986
|
+
timeout = HF_LOGIN_DEFAULT_TIMEOUT
|
987
|
+
try:
|
988
|
+
timeout = int(
|
989
|
+
os.environ.get("HF_LOGIN_DEFAULT_TIMEOUT", HF_LOGIN_DEFAULT_TIMEOUT)
|
990
|
+
)
|
991
|
+
except ValueError:
|
992
|
+
pass
|
993
|
+
return timeout
|
994
|
+
|
995
|
+
|
996
|
+
def format_hf_custom_error_message(error: HfHubHTTPError):
|
997
|
+
"""
|
998
|
+
Formats a custom error message based on the Hugging Face error response.
|
999
|
+
|
1000
|
+
Parameters
|
1001
|
+
----------
|
1002
|
+
error (HfHubHTTPError): The caught exception.
|
1003
|
+
|
1004
|
+
Raises
|
1005
|
+
------
|
1006
|
+
AquaRuntimeError: A user-friendly error message.
|
1007
|
+
"""
|
1008
|
+
# Extract the repository URL from the error message if present
|
1009
|
+
match = re.search(r"(https://huggingface.co/[^\s]+)", str(error))
|
1010
|
+
url = match.group(1) if match else "the requested Hugging Face URL."
|
1011
|
+
|
1012
|
+
if isinstance(error, RepositoryNotFoundError):
|
1013
|
+
raise AquaRuntimeError(
|
1014
|
+
reason=f"Failed to access `{url}`. Please check if the provided repository name is correct. "
|
1015
|
+
"If the repo is private, make sure you are authenticated and have a valid HF token registered. "
|
1016
|
+
"To register your token, run this command in your terminal: `huggingface-cli login`",
|
1017
|
+
service_payload={"error": "RepositoryNotFoundError"},
|
1018
|
+
)
|
1019
|
+
|
1020
|
+
if isinstance(error, GatedRepoError):
|
1021
|
+
raise AquaRuntimeError(
|
1022
|
+
reason=f"Access denied to `{url}` "
|
1023
|
+
"This repository is gated. Access is restricted to authorized users. "
|
1024
|
+
"Please request access or check with the repository administrator. "
|
1025
|
+
"If you are trying to access a gated repository, ensure you have a valid HF token registered. "
|
1026
|
+
"To register your token, run this command in your terminal: `huggingface-cli login`",
|
1027
|
+
service_payload={"error": "GatedRepoError"},
|
1028
|
+
)
|
1029
|
+
|
1030
|
+
if isinstance(error, RevisionNotFoundError):
|
1031
|
+
raise AquaRuntimeError(
|
1032
|
+
reason=f"The specified revision could not be found at `{url}` "
|
1033
|
+
"Please check the revision identifier and try again.",
|
1034
|
+
service_payload={"error": "RevisionNotFoundError"},
|
1035
|
+
)
|
1036
|
+
|
1037
|
+
raise AquaRuntimeError(
|
1038
|
+
reason=f"An error occurred while accessing `{url}` "
|
1039
|
+
"Please check your network connection and try again. "
|
1040
|
+
"If you are trying to access a gated repository, ensure you have a valid HF token registered. "
|
1041
|
+
"To register your token, run this command in your terminal: `huggingface-cli login`",
|
1042
|
+
service_payload={"error": "Error"},
|
1043
|
+
)
|
1044
|
+
|
1045
|
+
|
1046
|
+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
|
1047
|
+
def get_hf_model_info(repo_id: str) -> ModelInfo:
|
1048
|
+
"""Gets the model information object for the given model repository name. For models that requires a token,
|
1049
|
+
this method assumes that the token validation is already done.
|
1050
|
+
|
1051
|
+
Parameters
|
1052
|
+
----------
|
1053
|
+
repo_id: str
|
1054
|
+
hugging face model repository name
|
1055
|
+
|
1056
|
+
Returns
|
1057
|
+
-------
|
1058
|
+
instance of ModelInfo object
|
1059
|
+
|
1060
|
+
"""
|
1061
|
+
try:
|
1062
|
+
return HfApi().model_info(repo_id=repo_id)
|
1063
|
+
except HfHubHTTPError as err:
|
1064
|
+
raise format_hf_custom_error_message(err) from err
|
ads/aqua/config/config.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*-
|
3
2
|
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
|
@@ -14,5 +13,6 @@ def get_finetuning_config_defaults():
|
|
14
13
|
"BM.GPU.A10.4": {"batch_size": 1, "replica": 1},
|
15
14
|
"BM.GPU4.8": {"batch_size": 4, "replica": 1},
|
16
15
|
"BM.GPU.A100-v2.8": {"batch_size": 6, "replica": 1},
|
16
|
+
"BM.GPU.H100.8": {"batch_size": 6, "replica": 1},
|
17
17
|
}
|
18
18
|
}
|
@@ -1,9 +1,37 @@
|
|
1
1
|
{
|
2
|
+
"configuration": {
|
3
|
+
"VM.Standard.A1.Flex": {
|
4
|
+
"parameters": {},
|
5
|
+
"shape_info": {
|
6
|
+
"configs": [
|
7
|
+
{
|
8
|
+
"memory_in_gbs": 128,
|
9
|
+
"ocpu": 20
|
10
|
+
},
|
11
|
+
{
|
12
|
+
"memory_in_gbs": 256,
|
13
|
+
"ocpu": 40
|
14
|
+
},
|
15
|
+
{
|
16
|
+
"memory_in_gbs": 384,
|
17
|
+
"ocpu": 60
|
18
|
+
},
|
19
|
+
{
|
20
|
+
"memory_in_gbs": 512,
|
21
|
+
"ocpu": 80
|
22
|
+
}
|
23
|
+
],
|
24
|
+
"type": "CPU"
|
25
|
+
}
|
26
|
+
}
|
27
|
+
},
|
2
28
|
"shape": [
|
3
29
|
"VM.GPU.A10.1",
|
4
30
|
"VM.GPU.A10.2",
|
5
31
|
"BM.GPU.A10.4",
|
6
32
|
"BM.GPU4.8",
|
7
|
-
"BM.GPU.A100-v2.8"
|
33
|
+
"BM.GPU.A100-v2.8",
|
34
|
+
"BM.GPU.H100.8",
|
35
|
+
"VM.Standard.A1.Flex"
|
8
36
|
]
|
9
37
|
}
|
ads/aqua/constants.py
CHANGED
@@ -21,7 +21,6 @@ DEFAULT_FT_BLOCK_STORAGE_SIZE = 750
|
|
21
21
|
DEFAULT_FT_REPLICA = 1
|
22
22
|
DEFAULT_FT_BATCH_SIZE = 1
|
23
23
|
DEFAULT_FT_VALIDATION_SET_SIZE = 0.1
|
24
|
-
|
25
24
|
MAXIMUM_ALLOWED_DATASET_IN_BYTE = 52428800 # 1024 x 1024 x 50 = 50MB
|
26
25
|
JOB_INFRASTRUCTURE_TYPE_DEFAULT_NETWORKING = "ME_STANDALONE"
|
27
26
|
NB_SESSION_IDENTIFIER = "NB_SESSION_OCID"
|
@@ -34,6 +33,8 @@ AQUA_MODEL_TYPE_CUSTOM = "custom"
|
|
34
33
|
AQUA_MODEL_ARTIFACT_CONFIG = "config.json"
|
35
34
|
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME = "_name_or_path"
|
36
35
|
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE = "model_type"
|
36
|
+
AQUA_MODEL_ARTIFACT_FILE = "model_file"
|
37
|
+
HF_LOGIN_DEFAULT_TIMEOUT = 2
|
37
38
|
|
38
39
|
TRAINING_METRICS_FINAL = "training_metrics_final"
|
39
40
|
VALIDATION_METRICS_FINAL = "validation_metrics_final"
|
@@ -74,3 +75,7 @@ TGI_INFERENCE_RESTRICTED_PARAMS = {
|
|
74
75
|
"--sharded",
|
75
76
|
"--trust-remote-code",
|
76
77
|
}
|
78
|
+
LLAMA_CPP_INFERENCE_RESTRICTED_PARAMS = {
|
79
|
+
"--port",
|
80
|
+
"--host",
|
81
|
+
}
|
ads/aqua/evaluation/entities.py
CHANGED
@@ -7,7 +7,7 @@ import os
|
|
7
7
|
import re
|
8
8
|
import tempfile
|
9
9
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
10
|
-
from dataclasses import asdict
|
10
|
+
from dataclasses import asdict, fields
|
11
11
|
from datetime import datetime, timedelta
|
12
12
|
from pathlib import Path
|
13
13
|
from threading import Lock
|
@@ -76,6 +76,7 @@ from ads.aqua.evaluation.entities import (
|
|
76
76
|
ModelParams,
|
77
77
|
)
|
78
78
|
from ads.aqua.evaluation.errors import EVALUATION_JOB_EXIT_CODE_MESSAGE
|
79
|
+
from ads.aqua.ui import AquaContainerConfig
|
79
80
|
from ads.common.auth import default_signer
|
80
81
|
from ads.common.object_storage_details import ObjectStorageDetails
|
81
82
|
from ads.common.utils import get_console_link, get_files, get_log_links
|
@@ -90,7 +91,9 @@ from ads.jobs.builders.infrastructure.dsc_job import DataScienceJob
|
|
90
91
|
from ads.jobs.builders.runtimes.base import Runtime
|
91
92
|
from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
|
92
93
|
from ads.model.datascience_model import DataScienceModel
|
94
|
+
from ads.model.deployment import ModelDeploymentContainerRuntime
|
93
95
|
from ads.model.deployment.model_deployment import ModelDeployment
|
96
|
+
from ads.model.generic_model import ModelDeploymentRuntimeType
|
94
97
|
from ads.model.model_metadata import (
|
95
98
|
MetadataTaxonomyKeys,
|
96
99
|
ModelCustomMetadata,
|
@@ -157,8 +160,9 @@ class AquaEvaluationApp(AquaApp):
|
|
157
160
|
create_aqua_evaluation_details = CreateAquaEvaluationDetails(**kwargs)
|
158
161
|
except Exception as ex:
|
159
162
|
raise AquaValueError(
|
160
|
-
"Invalid create evaluation parameters.
|
161
|
-
|
163
|
+
"Invalid create evaluation parameters. "
|
164
|
+
"Allowable parameters are: "
|
165
|
+
f"{', '.join([field.name for field in fields(CreateAquaEvaluationDetails)])}."
|
162
166
|
) from ex
|
163
167
|
|
164
168
|
if not is_valid_ocid(create_aqua_evaluation_details.evaluation_source_id):
|
@@ -166,8 +170,8 @@ class AquaEvaluationApp(AquaApp):
|
|
166
170
|
f"Invalid evaluation source {create_aqua_evaluation_details.evaluation_source_id}. "
|
167
171
|
"Specify either a model or model deployment id."
|
168
172
|
)
|
169
|
-
|
170
173
|
evaluation_source = None
|
174
|
+
eval_inference_configuration = None
|
171
175
|
if (
|
172
176
|
DataScienceResource.MODEL_DEPLOYMENT
|
173
177
|
in create_aqua_evaluation_details.evaluation_source_id
|
@@ -175,6 +179,28 @@ class AquaEvaluationApp(AquaApp):
|
|
175
179
|
evaluation_source = ModelDeployment.from_id(
|
176
180
|
create_aqua_evaluation_details.evaluation_source_id
|
177
181
|
)
|
182
|
+
try:
|
183
|
+
if (
|
184
|
+
evaluation_source.runtime.type
|
185
|
+
== ModelDeploymentRuntimeType.CONTAINER
|
186
|
+
):
|
187
|
+
runtime = ModelDeploymentContainerRuntime.from_dict(
|
188
|
+
evaluation_source.runtime.to_dict()
|
189
|
+
)
|
190
|
+
inference_config = AquaContainerConfig.from_container_index_json(
|
191
|
+
enable_spec=True
|
192
|
+
).inference
|
193
|
+
for container in inference_config.values():
|
194
|
+
if container.name == runtime.image[:runtime.image.rfind(":")]:
|
195
|
+
eval_inference_configuration = (
|
196
|
+
container.spec.evaluation_configuration
|
197
|
+
)
|
198
|
+
except Exception:
|
199
|
+
logger.debug(
|
200
|
+
f"Could not load inference config details for the evaluation id: "
|
201
|
+
f"{create_aqua_evaluation_details.evaluation_source_id}. Please check if the container"
|
202
|
+
f" runtime has the correct SMC image information."
|
203
|
+
)
|
178
204
|
elif (
|
179
205
|
DataScienceResource.MODEL
|
180
206
|
in create_aqua_evaluation_details.evaluation_source_id
|
@@ -390,6 +416,9 @@ class AquaEvaluationApp(AquaApp):
|
|
390
416
|
report_path=create_aqua_evaluation_details.report_path,
|
391
417
|
model_parameters=create_aqua_evaluation_details.model_parameters,
|
392
418
|
metrics=create_aqua_evaluation_details.metrics,
|
419
|
+
inference_configuration=eval_inference_configuration.to_filtered_dict()
|
420
|
+
if eval_inference_configuration
|
421
|
+
else {},
|
393
422
|
)
|
394
423
|
).create(**kwargs) ## TODO: decide what parameters will be needed
|
395
424
|
logger.debug(
|
@@ -511,6 +540,7 @@ class AquaEvaluationApp(AquaApp):
|
|
511
540
|
report_path: str,
|
512
541
|
model_parameters: dict,
|
513
542
|
metrics: List = None,
|
543
|
+
inference_configuration: dict = None,
|
514
544
|
) -> Runtime:
|
515
545
|
"""Builds evaluation runtime for Job."""
|
516
546
|
# TODO the image name needs to be extracted from the mapping index.json file.
|
@@ -520,16 +550,19 @@ class AquaEvaluationApp(AquaApp):
|
|
520
550
|
.with_environment_variable(
|
521
551
|
**{
|
522
552
|
"AIP_SMC_EVALUATION_ARGUMENTS": json.dumps(
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
553
|
+
{
|
554
|
+
**asdict(
|
555
|
+
self._build_launch_cmd(
|
556
|
+
evaluation_id=evaluation_id,
|
557
|
+
evaluation_source_id=evaluation_source_id,
|
558
|
+
dataset_path=dataset_path,
|
559
|
+
report_path=report_path,
|
560
|
+
model_parameters=model_parameters,
|
561
|
+
metrics=metrics,
|
562
|
+
),
|
563
|
+
),
|
564
|
+
**(inference_configuration or {}),
|
565
|
+
},
|
533
566
|
),
|
534
567
|
"CONDA_BUCKET_NS": CONDA_BUCKET_NS,
|
535
568
|
},
|
@@ -1,18 +1,24 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*-
|
3
2
|
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
|
6
5
|
|
7
6
|
from importlib import metadata
|
8
7
|
|
8
|
+
import huggingface_hub
|
9
9
|
import requests
|
10
|
+
from huggingface_hub import HfApi
|
11
|
+
from huggingface_hub.utils import LocalTokenNotFoundError
|
10
12
|
from tornado.web import HTTPError
|
11
13
|
|
12
14
|
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
|
13
15
|
from ads.aqua.common.decorator import handle_exceptions
|
14
16
|
from ads.aqua.common.errors import AquaResourceAccessError, AquaRuntimeError
|
15
|
-
from ads.aqua.common.utils import
|
17
|
+
from ads.aqua.common.utils import (
|
18
|
+
fetch_service_compartment,
|
19
|
+
get_huggingface_login_timeout,
|
20
|
+
known_realm,
|
21
|
+
)
|
16
22
|
from ads.aqua.extension.base_handler import AquaAPIhandler
|
17
23
|
from ads.aqua.extension.errors import Errors
|
18
24
|
|
@@ -46,16 +52,80 @@ class CompatibilityCheckHandler(AquaAPIhandler):
|
|
46
52
|
|
47
53
|
"""
|
48
54
|
if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
|
49
|
-
return self.finish(
|
55
|
+
return self.finish({"status": "ok"})
|
50
56
|
elif known_realm():
|
51
|
-
return self.finish(
|
57
|
+
return self.finish({"status": "compatible"})
|
52
58
|
else:
|
53
59
|
raise AquaResourceAccessError(
|
54
|
-
|
60
|
+
"The AI Quick actions extension is not compatible in the given region."
|
55
61
|
)
|
56
62
|
|
57
63
|
|
64
|
+
class NetworkStatusHandler(AquaAPIhandler):
|
65
|
+
"""Handler to check internet connection."""
|
66
|
+
|
67
|
+
@handle_exceptions
|
68
|
+
def get(self):
|
69
|
+
requests.get("https://huggingface.com", timeout=get_huggingface_login_timeout())
|
70
|
+
return self.finish({"status": 200, "message": "success"})
|
71
|
+
|
72
|
+
|
73
|
+
class HFLoginHandler(AquaAPIhandler):
|
74
|
+
"""Handler to login to HF."""
|
75
|
+
|
76
|
+
@handle_exceptions
|
77
|
+
def post(self, *args, **kwargs):
|
78
|
+
"""Handles post request for the HF login.
|
79
|
+
|
80
|
+
Raises
|
81
|
+
------
|
82
|
+
HTTPError
|
83
|
+
Raises HTTPError if inputs are missing or are invalid.
|
84
|
+
"""
|
85
|
+
try:
|
86
|
+
input_data = self.get_json_body()
|
87
|
+
except Exception as ex:
|
88
|
+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
|
89
|
+
|
90
|
+
if not input_data:
|
91
|
+
raise HTTPError(400, Errors.NO_INPUT_DATA)
|
92
|
+
|
93
|
+
token = input_data.get("token")
|
94
|
+
|
95
|
+
if not token:
|
96
|
+
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("token"))
|
97
|
+
|
98
|
+
# Login to HF
|
99
|
+
try:
|
100
|
+
huggingface_hub.login(token=token, new_session=False)
|
101
|
+
except Exception as ex:
|
102
|
+
raise AquaRuntimeError(
|
103
|
+
reason=str(ex), service_payload={"error": type(ex).__name__}
|
104
|
+
) from ex
|
105
|
+
|
106
|
+
return self.finish({"status": 200, "message": "login successful"})
|
107
|
+
|
108
|
+
|
109
|
+
class HFUserStatusHandler(AquaAPIhandler):
|
110
|
+
"""Handler to check if user logged in to the HF."""
|
111
|
+
|
112
|
+
@handle_exceptions
|
113
|
+
def get(self):
|
114
|
+
try:
|
115
|
+
HfApi().whoami()
|
116
|
+
except LocalTokenNotFoundError as err:
|
117
|
+
raise AquaRuntimeError(
|
118
|
+
"You are not logged in. Please log in to Hugging Face using the `huggingface-cli login` command."
|
119
|
+
"See https://huggingface.co/settings/tokens.",
|
120
|
+
) from err
|
121
|
+
|
122
|
+
return self.finish({"status": 200, "message": "logged in"})
|
123
|
+
|
124
|
+
|
58
125
|
__handlers__ = [
|
59
126
|
("ads_version", ADSVersionHandler),
|
60
127
|
("hello", CompatibilityCheckHandler),
|
128
|
+
("network_status", NetworkStatusHandler),
|
129
|
+
("hf_login", HFLoginHandler),
|
130
|
+
("hf_logged_in", HFUserStatusHandler),
|
61
131
|
]
|