oracle-ads 2.11.15__py3-none-any.whl → 2.11.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. ads/aqua/app.py +5 -6
  2. ads/aqua/common/entities.py +17 -0
  3. ads/aqua/common/enums.py +14 -1
  4. ads/aqua/common/utils.py +160 -3
  5. ads/aqua/config/config.py +1 -1
  6. ads/aqua/config/deployment_config_defaults.json +29 -1
  7. ads/aqua/config/resource_limit_names.json +1 -0
  8. ads/aqua/constants.py +6 -1
  9. ads/aqua/evaluation/entities.py +0 -1
  10. ads/aqua/evaluation/evaluation.py +47 -14
  11. ads/aqua/extension/common_handler.py +75 -5
  12. ads/aqua/extension/common_ws_msg_handler.py +57 -0
  13. ads/aqua/extension/deployment_handler.py +16 -13
  14. ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
  15. ads/aqua/extension/errors.py +1 -1
  16. ads/aqua/extension/evaluation_ws_msg_handler.py +28 -6
  17. ads/aqua/extension/model_handler.py +134 -8
  18. ads/aqua/extension/models/ws_models.py +78 -3
  19. ads/aqua/extension/models_ws_msg_handler.py +49 -0
  20. ads/aqua/extension/ui_websocket_handler.py +7 -1
  21. ads/aqua/model/entities.py +28 -0
  22. ads/aqua/model/model.py +544 -129
  23. ads/aqua/modeldeployment/deployment.py +102 -43
  24. ads/aqua/modeldeployment/entities.py +9 -20
  25. ads/aqua/ui.py +152 -28
  26. ads/common/object_storage_details.py +2 -5
  27. ads/common/serializer.py +2 -3
  28. ads/jobs/builders/infrastructure/dsc_job.py +41 -12
  29. ads/jobs/builders/infrastructure/dsc_job_runtime.py +74 -27
  30. ads/jobs/builders/runtimes/container_runtime.py +83 -4
  31. ads/opctl/operator/lowcode/anomaly/const.py +1 -0
  32. ads/opctl/operator/lowcode/anomaly/model/base_model.py +23 -7
  33. ads/opctl/operator/lowcode/anomaly/operator_config.py +1 -0
  34. ads/opctl/operator/lowcode/anomaly/schema.yaml +4 -0
  35. ads/opctl/operator/lowcode/common/errors.py +6 -0
  36. ads/opctl/operator/lowcode/forecast/model/arima.py +3 -1
  37. ads/opctl/operator/lowcode/forecast/model/base_model.py +21 -13
  38. ads/opctl/operator/lowcode/forecast/model_evaluator.py +11 -2
  39. ads/pipeline/ads_pipeline_run.py +13 -2
  40. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/METADATA +2 -1
  41. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/RECORD +44 -40
  42. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/LICENSE.txt +0 -0
  43. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/WHEEL +0 -0
  44. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/entry_points.txt +0 -0
ads/aqua/app.py CHANGED
@@ -1,5 +1,4 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
2
  # Copyright (c) 2024 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
@@ -175,7 +174,7 @@ class AquaApp:
175
174
  f"Invalid model version set name. Please provide a model version set with `{tag}` in tags."
176
175
  )
177
176
 
178
- except:
177
+ except Exception:
179
178
  logger.debug(
180
179
  f"Model version set {model_version_set_name} doesn't exist. "
181
180
  "Creating new model version set."
@@ -254,7 +253,7 @@ class AquaApp:
254
253
 
255
254
  try:
256
255
  response = self.ds_client.head_model_artifact(model_id=model_id, **kwargs)
257
- return True if response.status == 200 else False
256
+ return response.status == 200
258
257
  except oci.exceptions.ServiceError as ex:
259
258
  if ex.status == 404:
260
259
  logger.info(f"Artifact not found in model {model_id}.")
@@ -302,7 +301,7 @@ class AquaApp:
302
301
  config_path,
303
302
  config_file_name=config_file_name,
304
303
  )
305
- except:
304
+ except Exception:
306
305
  # todo: temp fix for issue related to config load for byom models, update logic to choose the right path
307
306
  try:
308
307
  config_path = f"{artifact_path.rstrip('/')}/config/"
@@ -310,7 +309,7 @@ class AquaApp:
310
309
  config_path,
311
310
  config_file_name=config_file_name,
312
311
  )
313
- except:
312
+ except Exception:
314
313
  pass
315
314
 
316
315
  if not config:
@@ -343,7 +342,7 @@ class CLIBuilderMixin:
343
342
  params = [
344
343
  f"--{field.name} {getattr(self,field.name)}"
345
344
  for field in fields(self.__class__)
346
- if getattr(self, field.name)
345
+ if getattr(self, field.name) is not None
347
346
  ]
348
347
  cmd = f"{cmd} {' '.join(params)}"
349
348
  return cmd
ads/aqua/common/entities.py ADDED
@@ -0,0 +1,17 @@
1
+ #!/usr/bin/env python
2
+ # Copyright (c) 2024 Oracle and/or its affiliates.
3
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
+
5
+
6
+ class ContainerSpec:
7
+ """
8
+ Class to hold to hold keys within the container spec.
9
+ """
10
+
11
+ CONTAINER_SPEC = "containerSpec"
12
+ CLI_PARM = "cliParam"
13
+ SERVER_PORT = "serverPort"
14
+ HEALTH_CHECK_PORT = "healthCheckPort"
15
+ ENV_VARS = "envVars"
16
+ RESTRICTED_PARAMS = "restrictedParams"
17
+ EVALUATION_CONFIGURATION = "evaluationConfiguration"
ads/aqua/common/enums.py CHANGED
@@ -1,5 +1,4 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
2
  # Copyright (c) 2024 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
@@ -8,6 +7,7 @@ aqua.common.enums
8
7
  ~~~~~~~~~~~~~~
9
8
  This module contains the set of enums used in AQUA.
10
9
  """
10
+
11
11
  from ads.common.extended_enum import ExtendedEnumMeta
12
12
 
13
13
 
@@ -38,21 +38,34 @@ class Tags(str, metaclass=ExtendedEnumMeta):
38
38
  READY_TO_IMPORT = "ready_to_import"
39
39
  BASE_MODEL_CUSTOM = "aqua_custom_base_model"
40
40
  AQUA_EVALUATION_MODEL_ID = "evaluation_model_id"
41
+ MODEL_FORMAT = "model_format"
42
+ MODEL_ARTIFACT_FILE = "model_file"
41
43
 
42
44
 
43
45
  class InferenceContainerType(str, metaclass=ExtendedEnumMeta):
44
46
  CONTAINER_TYPE_VLLM = "vllm"
45
47
  CONTAINER_TYPE_TGI = "tgi"
48
+ CONTAINER_TYPE_LLAMA_CPP = "llama-cpp"
46
49
 
47
50
 
48
51
  class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
49
52
  AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
50
53
  AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
54
+ AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
51
55
 
52
56
 
53
57
  class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
54
58
  PARAM_TYPE_VLLM = "VLLM_PARAMS"
55
59
  PARAM_TYPE_TGI = "TGI_PARAMS"
60
+ PARAM_TYPE_LLAMA_CPP = "LLAMA_CPP_PARAMS"
61
+
62
+
63
+ class EvaluationContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
64
+ AQUA_EVALUATION_CONTAINER_FAMILY = "odsc-llm-evaluate"
65
+
66
+
67
+ class FineTuningContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
68
+ AQUA_FINETUNING_CONTAINER_FAMILY = "odsc-llm-fine-tuning"
56
69
 
57
70
 
58
71
  class HuggingFaceTags(str, metaclass=ExtendedEnumMeta):
ads/aqua/common/utils.py CHANGED
@@ -10,6 +10,9 @@ import logging
10
10
  import os
11
11
  import random
12
12
  import re
13
+ import shlex
14
+ import subprocess
15
+ from datetime import datetime, timedelta
13
16
  from functools import wraps
14
17
  from pathlib import Path
15
18
  from string import Template
@@ -17,7 +20,16 @@ from typing import List, Union
17
20
 
18
21
  import fsspec
19
22
  import oci
23
+ from cachetools import TTLCache, cached
24
+ from huggingface_hub.hf_api import HfApi, ModelInfo
25
+ from huggingface_hub.utils import (
26
+ GatedRepoError,
27
+ HfHubHTTPError,
28
+ RepositoryNotFoundError,
29
+ RevisionNotFoundError,
30
+ )
20
31
  from oci.data_science.models import JobRun, Model
32
+ from oci.object_storage.models import ObjectSummary
21
33
 
22
34
  from ads.aqua.common.enums import (
23
35
  InferenceContainerParamType,
@@ -34,6 +46,7 @@ from ads.aqua.constants import (
34
46
  COMPARTMENT_MAPPING_KEY,
35
47
  CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
36
48
  CONTAINER_INDEX,
49
+ HF_LOGIN_DEFAULT_TIMEOUT,
37
50
  MAXIMUM_ALLOWED_DATASET_IN_BYTE,
38
51
  MODEL_BY_REFERENCE_OSS_PATH_KEY,
39
52
  SERVICE_MANAGED_CONTAINER_URI_SCHEME,
@@ -44,8 +57,7 @@ from ads.aqua.constants import (
44
57
  VLLM_INFERENCE_RESTRICTED_PARAMS,
45
58
  )
46
59
  from ads.aqua.data import AquaResourceIdentifier
47
- from ads.common.auth import default_signer
48
- from ads.common.decorator.threaded import threaded
60
+ from ads.common.auth import AuthState, default_signer
49
61
  from ads.common.extended_enum import ExtendedEnumMeta
50
62
  from ads.common.object_storage_details import ObjectStorageDetails
51
63
  from ads.common.oci_resource import SEARCH_TYPE, OCIResource
@@ -213,7 +225,6 @@ def read_file(file_path: str, **kwargs) -> str:
213
225
  return UNKNOWN
214
226
 
215
227
 
216
- @threaded()
217
228
  def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
218
229
  artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
219
230
  signer = default_signer() if artifact_path.startswith("oci://") else {}
@@ -228,6 +239,32 @@ def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
228
239
  return config
229
240
 
230
241
 
242
+ def list_os_files_with_extension(oss_path: str, extension: str) -> [str]:
243
+ """
244
+ List files in the specified directory with the given extension.
245
+
246
+ Parameters:
247
+ - oss_path: The path to the directory where files are located.
248
+ - extension: The file extension to filter by (e.g., 'txt' for text files).
249
+
250
+ Returns:
251
+ - A list of file paths matching the specified extension.
252
+ """
253
+
254
+ oss_client = ObjectStorageDetails.from_path(oss_path)
255
+
256
+ # Ensure the extension is prefixed with a dot if not already
257
+ if not extension.startswith("."):
258
+ extension = "." + extension
259
+ files: List[ObjectSummary] = oss_client.list_objects().objects
260
+
261
+ return [
262
+ file.name[len(oss_client.filepath) :].lstrip("/")
263
+ for file in files
264
+ if file.name.endswith(extension)
265
+ ]
266
+
267
+
231
268
  def is_valid_ocid(ocid: str) -> bool:
232
269
  """Checks if the given ocid is valid.
233
270
 
@@ -503,6 +540,7 @@ def container_config_path():
503
540
  return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
504
541
 
505
542
 
543
+ @cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
506
544
  def get_container_config():
507
545
  config = load_config(
508
546
  file_path=container_config_path(),
@@ -743,6 +781,33 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
743
781
  return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
744
782
 
745
783
 
784
+ def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
785
+ """Upload the local folder to the object storage
786
+
787
+ Args:
788
+ os_path (str): object storage URI with prefix. This is the path to upload
789
+ local_dir (str): Local directory where the object is downloaded
790
+ model_name (str): Name of the huggingface model
791
+ Retuns:
792
+ str: Object name inside the bucket
793
+ """
794
+ os_details: ObjectStorageDetails = ObjectStorageDetails.from_path(os_path)
795
+ if not os_details.is_bucket_versioned():
796
+ raise ValueError(f"Version is not enabled at object storage location {os_path}")
797
+ auth_state = AuthState()
798
+ object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/"
799
+ command = f"oci os object bulk-upload --src-dir {local_dir} --prefix {object_path} -bn {os_details.bucket} -ns {os_details.namespace} --auth {auth_state.oci_iam_type} --profile {auth_state.oci_key_profile} --no-overwrite"
800
+ try:
801
+ logger.info(f"Running: {command}")
802
+ subprocess.check_call(shlex.split(command))
803
+ except subprocess.CalledProcessError as e:
804
+ logger.error(
805
+ f"Error uploading the object. Exit code: {e.returncode} with error {e.stdout}"
806
+ )
807
+
808
+ return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path
809
+
810
+
746
811
  def is_service_managed_container(container):
747
812
  return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
748
813
 
@@ -881,6 +946,8 @@ def get_container_params_type(container_type_name: str) -> str:
881
946
  return InferenceContainerParamType.PARAM_TYPE_VLLM
882
947
  elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_name.lower():
883
948
  return InferenceContainerParamType.PARAM_TYPE_TGI
949
+ elif InferenceContainerType.CONTAINER_TYPE_LLAMA_CPP in container_type_name.lower():
950
+ return InferenceContainerParamType.PARAM_TYPE_LLAMA_CPP
884
951
  else:
885
952
  return UNKNOWN
886
953
 
@@ -905,3 +972,93 @@ def get_restricted_params_by_container(container_type_name: str) -> set:
905
972
  return TGI_INFERENCE_RESTRICTED_PARAMS
906
973
  else:
907
974
  return set()
975
+
976
+
977
+ def get_huggingface_login_timeout() -> int:
978
+ """This helper function returns the huggingface login timeout, returns default if not set via
979
+ env var.
980
+ Returns
981
+ -------
982
+ timeout: int
983
+ huggingface login timeout.
984
+
985
+ """
986
+ timeout = HF_LOGIN_DEFAULT_TIMEOUT
987
+ try:
988
+ timeout = int(
989
+ os.environ.get("HF_LOGIN_DEFAULT_TIMEOUT", HF_LOGIN_DEFAULT_TIMEOUT)
990
+ )
991
+ except ValueError:
992
+ pass
993
+ return timeout
994
+
995
+
996
+ def format_hf_custom_error_message(error: HfHubHTTPError):
997
+ """
998
+ Formats a custom error message based on the Hugging Face error response.
999
+
1000
+ Parameters
1001
+ ----------
1002
+ error (HfHubHTTPError): The caught exception.
1003
+
1004
+ Raises
1005
+ ------
1006
+ AquaRuntimeError: A user-friendly error message.
1007
+ """
1008
+ # Extract the repository URL from the error message if present
1009
+ match = re.search(r"(https://huggingface.co/[^\s]+)", str(error))
1010
+ url = match.group(1) if match else "the requested Hugging Face URL."
1011
+
1012
+ if isinstance(error, RepositoryNotFoundError):
1013
+ raise AquaRuntimeError(
1014
+ reason=f"Failed to access `{url}`. Please check if the provided repository name is correct. "
1015
+ "If the repo is private, make sure you are authenticated and have a valid HF token registered. "
1016
+ "To register your token, run this command in your terminal: `huggingface-cli login`",
1017
+ service_payload={"error": "RepositoryNotFoundError"},
1018
+ )
1019
+
1020
+ if isinstance(error, GatedRepoError):
1021
+ raise AquaRuntimeError(
1022
+ reason=f"Access denied to `{url}` "
1023
+ "This repository is gated. Access is restricted to authorized users. "
1024
+ "Please request access or check with the repository administrator. "
1025
+ "If you are trying to access a gated repository, ensure you have a valid HF token registered. "
1026
+ "To register your token, run this command in your terminal: `huggingface-cli login`",
1027
+ service_payload={"error": "GatedRepoError"},
1028
+ )
1029
+
1030
+ if isinstance(error, RevisionNotFoundError):
1031
+ raise AquaRuntimeError(
1032
+ reason=f"The specified revision could not be found at `{url}` "
1033
+ "Please check the revision identifier and try again.",
1034
+ service_payload={"error": "RevisionNotFoundError"},
1035
+ )
1036
+
1037
+ raise AquaRuntimeError(
1038
+ reason=f"An error occurred while accessing `{url}` "
1039
+ "Please check your network connection and try again. "
1040
+ "If you are trying to access a gated repository, ensure you have a valid HF token registered. "
1041
+ "To register your token, run this command in your terminal: `huggingface-cli login`",
1042
+ service_payload={"error": "Error"},
1043
+ )
1044
+
1045
+
1046
+ @cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
1047
+ def get_hf_model_info(repo_id: str) -> ModelInfo:
1048
+ """Gets the model information object for the given model repository name. For models that requires a token,
1049
+ this method assumes that the token validation is already done.
1050
+
1051
+ Parameters
1052
+ ----------
1053
+ repo_id: str
1054
+ hugging face model repository name
1055
+
1056
+ Returns
1057
+ -------
1058
+ instance of ModelInfo object
1059
+
1060
+ """
1061
+ try:
1062
+ return HfApi().model_info(repo_id=repo_id)
1063
+ except HfHubHTTPError as err:
1064
+ raise format_hf_custom_error_message(err) from err
ads/aqua/config/config.py CHANGED
@@ -1,5 +1,4 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
2
  # Copyright (c) 2024 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
@@ -14,5 +13,6 @@ def get_finetuning_config_defaults():
14
13
  "BM.GPU.A10.4": {"batch_size": 1, "replica": 1},
15
14
  "BM.GPU4.8": {"batch_size": 4, "replica": 1},
16
15
  "BM.GPU.A100-v2.8": {"batch_size": 6, "replica": 1},
16
+ "BM.GPU.H100.8": {"batch_size": 6, "replica": 1},
17
17
  }
18
18
  }
ads/aqua/config/deployment_config_defaults.json CHANGED
@@ -1,9 +1,37 @@
1
1
  {
2
+ "configuration": {
3
+ "VM.Standard.A1.Flex": {
4
+ "parameters": {},
5
+ "shape_info": {
6
+ "configs": [
7
+ {
8
+ "memory_in_gbs": 128,
9
+ "ocpu": 20
10
+ },
11
+ {
12
+ "memory_in_gbs": 256,
13
+ "ocpu": 40
14
+ },
15
+ {
16
+ "memory_in_gbs": 384,
17
+ "ocpu": 60
18
+ },
19
+ {
20
+ "memory_in_gbs": 512,
21
+ "ocpu": 80
22
+ }
23
+ ],
24
+ "type": "CPU"
25
+ }
26
+ }
27
+ },
2
28
  "shape": [
3
29
  "VM.GPU.A10.1",
4
30
  "VM.GPU.A10.2",
5
31
  "BM.GPU.A10.4",
6
32
  "BM.GPU4.8",
7
- "BM.GPU.A100-v2.8"
33
+ "BM.GPU.A100-v2.8",
34
+ "BM.GPU.H100.8",
35
+ "VM.Standard.A1.Flex"
8
36
  ]
9
37
  }
ads/aqua/config/resource_limit_names.json CHANGED
@@ -1,6 +1,7 @@
1
1
  {
2
2
  "BM.GPU.A10.4": "ds-gpu-a10-count",
3
3
  "BM.GPU.A100-v2.8": "ds-gpu-a100-v2-count",
4
+ "BM.GPU.H100.8": "ds-gpu-h100-count",
4
5
  "BM.GPU4.8": "ds-gpu4-count",
5
6
  "VM.GPU.A10.1": "ds-gpu-a10-count",
6
7
  "VM.GPU.A10.2": "ds-gpu-a10-count"
ads/aqua/constants.py CHANGED
@@ -21,7 +21,6 @@ DEFAULT_FT_BLOCK_STORAGE_SIZE = 750
21
21
  DEFAULT_FT_REPLICA = 1
22
22
  DEFAULT_FT_BATCH_SIZE = 1
23
23
  DEFAULT_FT_VALIDATION_SET_SIZE = 0.1
24
-
25
24
  MAXIMUM_ALLOWED_DATASET_IN_BYTE = 52428800 # 1024 x 1024 x 50 = 50MB
26
25
  JOB_INFRASTRUCTURE_TYPE_DEFAULT_NETWORKING = "ME_STANDALONE"
27
26
  NB_SESSION_IDENTIFIER = "NB_SESSION_OCID"
@@ -34,6 +33,8 @@ AQUA_MODEL_TYPE_CUSTOM = "custom"
34
33
  AQUA_MODEL_ARTIFACT_CONFIG = "config.json"
35
34
  AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME = "_name_or_path"
36
35
  AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE = "model_type"
36
+ AQUA_MODEL_ARTIFACT_FILE = "model_file"
37
+ HF_LOGIN_DEFAULT_TIMEOUT = 2
37
38
 
38
39
  TRAINING_METRICS_FINAL = "training_metrics_final"
39
40
  VALIDATION_METRICS_FINAL = "validation_metrics_final"
@@ -74,3 +75,7 @@ TGI_INFERENCE_RESTRICTED_PARAMS = {
74
75
  "--sharded",
75
76
  "--trust-remote-code",
76
77
  }
78
+ LLAMA_CPP_INFERENCE_RESTRICTED_PARAMS = {
79
+ "--port",
80
+ "--host",
81
+ }
ads/aqua/evaluation/entities.py CHANGED
@@ -1,5 +1,4 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
2
  # Copyright (c) 2024 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
ads/aqua/evaluation/evaluation.py CHANGED
@@ -7,7 +7,7 @@ import os
7
7
  import re
8
8
  import tempfile
9
9
  from concurrent.futures import ThreadPoolExecutor, as_completed
10
- from dataclasses import asdict
10
+ from dataclasses import asdict, fields
11
11
  from datetime import datetime, timedelta
12
12
  from pathlib import Path
13
13
  from threading import Lock
@@ -76,6 +76,7 @@ from ads.aqua.evaluation.entities import (
76
76
  ModelParams,
77
77
  )
78
78
  from ads.aqua.evaluation.errors import EVALUATION_JOB_EXIT_CODE_MESSAGE
79
+ from ads.aqua.ui import AquaContainerConfig
79
80
  from ads.common.auth import default_signer
80
81
  from ads.common.object_storage_details import ObjectStorageDetails
81
82
  from ads.common.utils import get_console_link, get_files, get_log_links
@@ -90,7 +91,9 @@ from ads.jobs.builders.infrastructure.dsc_job import DataScienceJob
90
91
  from ads.jobs.builders.runtimes.base import Runtime
91
92
  from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
92
93
  from ads.model.datascience_model import DataScienceModel
94
+ from ads.model.deployment import ModelDeploymentContainerRuntime
93
95
  from ads.model.deployment.model_deployment import ModelDeployment
96
+ from ads.model.generic_model import ModelDeploymentRuntimeType
94
97
  from ads.model.model_metadata import (
95
98
  MetadataTaxonomyKeys,
96
99
  ModelCustomMetadata,
@@ -157,8 +160,9 @@ class AquaEvaluationApp(AquaApp):
157
160
  create_aqua_evaluation_details = CreateAquaEvaluationDetails(**kwargs)
158
161
  except Exception as ex:
159
162
  raise AquaValueError(
160
- "Invalid create evaluation parameters. Allowable parameters are: "
161
- f"{', '.join(list(asdict(CreateAquaEvaluationDetails).keys()))}."
163
+ "Invalid create evaluation parameters. "
164
+ "Allowable parameters are: "
165
+ f"{', '.join([field.name for field in fields(CreateAquaEvaluationDetails)])}."
162
166
  ) from ex
163
167
 
164
168
  if not is_valid_ocid(create_aqua_evaluation_details.evaluation_source_id):
@@ -166,8 +170,8 @@ class AquaEvaluationApp(AquaApp):
166
170
  f"Invalid evaluation source {create_aqua_evaluation_details.evaluation_source_id}. "
167
171
  "Specify either a model or model deployment id."
168
172
  )
169
-
170
173
  evaluation_source = None
174
+ eval_inference_configuration = None
171
175
  if (
172
176
  DataScienceResource.MODEL_DEPLOYMENT
173
177
  in create_aqua_evaluation_details.evaluation_source_id
@@ -175,6 +179,28 @@ class AquaEvaluationApp(AquaApp):
175
179
  evaluation_source = ModelDeployment.from_id(
176
180
  create_aqua_evaluation_details.evaluation_source_id
177
181
  )
182
+ try:
183
+ if (
184
+ evaluation_source.runtime.type
185
+ == ModelDeploymentRuntimeType.CONTAINER
186
+ ):
187
+ runtime = ModelDeploymentContainerRuntime.from_dict(
188
+ evaluation_source.runtime.to_dict()
189
+ )
190
+ inference_config = AquaContainerConfig.from_container_index_json(
191
+ enable_spec=True
192
+ ).inference
193
+ for container in inference_config.values():
194
+ if container.name == runtime.image[:runtime.image.rfind(":")]:
195
+ eval_inference_configuration = (
196
+ container.spec.evaluation_configuration
197
+ )
198
+ except Exception:
199
+ logger.debug(
200
+ f"Could not load inference config details for the evaluation id: "
201
+ f"{create_aqua_evaluation_details.evaluation_source_id}. Please check if the container"
202
+ f" runtime has the correct SMC image information."
203
+ )
178
204
  elif (
179
205
  DataScienceResource.MODEL
180
206
  in create_aqua_evaluation_details.evaluation_source_id
@@ -390,6 +416,9 @@ class AquaEvaluationApp(AquaApp):
390
416
  report_path=create_aqua_evaluation_details.report_path,
391
417
  model_parameters=create_aqua_evaluation_details.model_parameters,
392
418
  metrics=create_aqua_evaluation_details.metrics,
419
+ inference_configuration=eval_inference_configuration.to_filtered_dict()
420
+ if eval_inference_configuration
421
+ else {},
393
422
  )
394
423
  ).create(**kwargs) ## TODO: decide what parameters will be needed
395
424
  logger.debug(
@@ -511,6 +540,7 @@ class AquaEvaluationApp(AquaApp):
511
540
  report_path: str,
512
541
  model_parameters: dict,
513
542
  metrics: List = None,
543
+ inference_configuration: dict = None,
514
544
  ) -> Runtime:
515
545
  """Builds evaluation runtime for Job."""
516
546
  # TODO the image name needs to be extracted from the mapping index.json file.
@@ -520,16 +550,19 @@ class AquaEvaluationApp(AquaApp):
520
550
  .with_environment_variable(
521
551
  **{
522
552
  "AIP_SMC_EVALUATION_ARGUMENTS": json.dumps(
523
- asdict(
524
- self._build_launch_cmd(
525
- evaluation_id=evaluation_id,
526
- evaluation_source_id=evaluation_source_id,
527
- dataset_path=dataset_path,
528
- report_path=report_path,
529
- model_parameters=model_parameters,
530
- metrics=metrics,
531
- )
532
- )
553
+ {
554
+ **asdict(
555
+ self._build_launch_cmd(
556
+ evaluation_id=evaluation_id,
557
+ evaluation_source_id=evaluation_source_id,
558
+ dataset_path=dataset_path,
559
+ report_path=report_path,
560
+ model_parameters=model_parameters,
561
+ metrics=metrics,
562
+ ),
563
+ ),
564
+ **(inference_configuration or {}),
565
+ },
533
566
  ),
534
567
  "CONDA_BUCKET_NS": CONDA_BUCKET_NS,
535
568
  },
ads/aqua/extension/common_handler.py CHANGED
@@ -1,18 +1,24 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
2
  # Copyright (c) 2024 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
6
5
 
7
6
  from importlib import metadata
8
7
 
8
+ import huggingface_hub
9
9
  import requests
10
+ from huggingface_hub import HfApi
11
+ from huggingface_hub.utils import LocalTokenNotFoundError
10
12
  from tornado.web import HTTPError
11
13
 
12
14
  from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
13
15
  from ads.aqua.common.decorator import handle_exceptions
14
16
  from ads.aqua.common.errors import AquaResourceAccessError, AquaRuntimeError
15
- from ads.aqua.common.utils import fetch_service_compartment, known_realm
17
+ from ads.aqua.common.utils import (
18
+ fetch_service_compartment,
19
+ get_huggingface_login_timeout,
20
+ known_realm,
21
+ )
16
22
  from ads.aqua.extension.base_handler import AquaAPIhandler
17
23
  from ads.aqua.extension.errors import Errors
18
24
 
@@ -46,16 +52,80 @@ class CompatibilityCheckHandler(AquaAPIhandler):
46
52
 
47
53
  """
48
54
  if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
49
- return self.finish(dict(status="ok"))
55
+ return self.finish({"status": "ok"})
50
56
  elif known_realm():
51
- return self.finish(dict(status="compatible"))
57
+ return self.finish({"status": "compatible"})
52
58
  else:
53
59
  raise AquaResourceAccessError(
54
- f"The AI Quick actions extension is not compatible in the given region."
60
+ "The AI Quick actions extension is not compatible in the given region."
55
61
  )
56
62
 
57
63
 
64
+ class NetworkStatusHandler(AquaAPIhandler):
65
+ """Handler to check internet connection."""
66
+
67
+ @handle_exceptions
68
+ def get(self):
69
+ requests.get("https://huggingface.com", timeout=get_huggingface_login_timeout())
70
+ return self.finish({"status": 200, "message": "success"})
71
+
72
+
73
+ class HFLoginHandler(AquaAPIhandler):
74
+ """Handler to login to HF."""
75
+
76
+ @handle_exceptions
77
+ def post(self, *args, **kwargs):
78
+ """Handles post request for the HF login.
79
+
80
+ Raises
81
+ ------
82
+ HTTPError
83
+ Raises HTTPError if inputs are missing or are invalid.
84
+ """
85
+ try:
86
+ input_data = self.get_json_body()
87
+ except Exception as ex:
88
+ raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
89
+
90
+ if not input_data:
91
+ raise HTTPError(400, Errors.NO_INPUT_DATA)
92
+
93
+ token = input_data.get("token")
94
+
95
+ if not token:
96
+ raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("token"))
97
+
98
+ # Login to HF
99
+ try:
100
+ huggingface_hub.login(token=token, new_session=False)
101
+ except Exception as ex:
102
+ raise AquaRuntimeError(
103
+ reason=str(ex), service_payload={"error": type(ex).__name__}
104
+ ) from ex
105
+
106
+ return self.finish({"status": 200, "message": "login successful"})
107
+
108
+
109
+ class HFUserStatusHandler(AquaAPIhandler):
110
+ """Handler to check if user logged in to the HF."""
111
+
112
+ @handle_exceptions
113
+ def get(self):
114
+ try:
115
+ HfApi().whoami()
116
+ except LocalTokenNotFoundError as err:
117
+ raise AquaRuntimeError(
118
+ "You are not logged in. Please log in to Hugging Face using the `huggingface-cli login` command."
119
+ "See https://huggingface.co/settings/tokens.",
120
+ ) from err
121
+
122
+ return self.finish({"status": 200, "message": "logged in"})
123
+
124
+
58
125
  __handlers__ = [
59
126
  ("ads_version", ADSVersionHandler),
60
127
  ("hello", CompatibilityCheckHandler),
128
+ ("network_status", NetworkStatusHandler),
129
+ ("hf_login", HFLoginHandler),
130
+ ("hf_logged_in", HFUserStatusHandler),
61
131
  ]