oracle-ads 2.13.1rc0__py3-none-any.whl → 2.13.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. ads/aqua/__init__.py +7 -1
  2. ads/aqua/app.py +24 -23
  3. ads/aqua/client/client.py +48 -11
  4. ads/aqua/common/entities.py +28 -1
  5. ads/aqua/common/enums.py +13 -7
  6. ads/aqua/common/utils.py +8 -13
  7. ads/aqua/config/container_config.py +203 -0
  8. ads/aqua/config/evaluation/evaluation_service_config.py +5 -181
  9. ads/aqua/constants.py +0 -1
  10. ads/aqua/evaluation/evaluation.py +4 -4
  11. ads/aqua/extension/base_handler.py +4 -0
  12. ads/aqua/extension/model_handler.py +19 -28
  13. ads/aqua/finetuning/finetuning.py +2 -3
  14. ads/aqua/model/entities.py +2 -3
  15. ads/aqua/model/model.py +25 -30
  16. ads/aqua/modeldeployment/deployment.py +6 -14
  17. ads/aqua/modeldeployment/entities.py +2 -2
  18. ads/aqua/server/__init__.py +4 -0
  19. ads/aqua/server/__main__.py +24 -0
  20. ads/aqua/server/app.py +47 -0
  21. ads/aqua/server/aqua_spec.yml +1291 -0
  22. ads/aqua/ui.py +5 -199
  23. ads/common/auth.py +20 -11
  24. ads/common/utils.py +91 -11
  25. ads/config.py +3 -0
  26. ads/llm/__init__.py +1 -0
  27. ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +32 -23
  28. ads/model/artifact_downloader.py +4 -1
  29. ads/model/common/utils.py +15 -3
  30. ads/model/datascience_model.py +339 -8
  31. ads/model/model_metadata.py +54 -14
  32. ads/model/model_version_set.py +5 -3
  33. ads/model/service/oci_datascience_model.py +477 -5
  34. ads/opctl/operator/common/utils.py +16 -0
  35. ads/opctl/operator/lowcode/common/data.py +5 -2
  36. ads/opctl/operator/lowcode/common/transformations.py +2 -12
  37. ads/opctl/operator/lowcode/forecast/model/automlx.py +10 -2
  38. ads/opctl/operator/lowcode/forecast/model/base_model.py +9 -10
  39. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +19 -11
  40. ads/opctl/operator/lowcode/forecast/model_evaluator.py +13 -15
  41. ads/opctl/operator/lowcode/forecast/schema.yaml +1 -1
  42. ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +7 -0
  43. {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2.dist-info}/METADATA +15 -12
  44. {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2.dist-info}/RECORD +47 -43
  45. {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2.dist-info}/WHEEL +1 -1
  46. ads/aqua/config/evaluation/evaluation_service_model_config.py +0 -8
  47. {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2.dist-info}/entry_points.txt +0 -0
  48. {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2.dist-info/licenses}/LICENSE.txt +0 -0
ads/aqua/__init__.py CHANGED
@@ -7,7 +7,13 @@ import os
7
7
  from logging import getLogger
8
8
 
9
9
  from ads import logger, set_auth
10
- from ads.aqua.client.client import AsyncClient, Client
10
+ from ads.aqua.client.client import (
11
+ AsyncClient,
12
+ Client,
13
+ HttpxOCIAuth,
14
+ get_async_httpx_client,
15
+ get_httpx_client,
16
+ )
11
17
  from ads.aqua.common.utils import fetch_service_compartment
12
18
  from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION
13
19
 
ads/aqua/app.py CHANGED
@@ -6,13 +6,14 @@ import json
6
6
  import os
7
7
  import traceback
8
8
  from dataclasses import fields
9
- from typing import Dict, Optional, Union
9
+ from typing import Any, Dict, Optional, Union
10
10
 
11
11
  import oci
12
12
  from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails
13
13
 
14
14
  from ads import set_auth
15
15
  from ads.aqua import logger
16
+ from ads.aqua.common.entities import ModelConfigResult
16
17
  from ads.aqua.common.enums import ConfigFolder, Tags
17
18
  from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
18
19
  from ads.aqua.common.utils import (
@@ -21,10 +22,9 @@ from ads.aqua.common.utils import (
21
22
  is_valid_ocid,
22
23
  load_config,
23
24
  )
24
- from ads.aqua.constants import UNKNOWN
25
25
  from ads.common import oci_client as oc
26
26
  from ads.common.auth import default_signer
27
- from ads.common.utils import extract_region, is_path_exists
27
+ from ads.common.utils import UNKNOWN, extract_region, is_path_exists
28
28
  from ads.config import (
29
29
  AQUA_TELEMETRY_BUCKET,
30
30
  AQUA_TELEMETRY_BUCKET_NS,
@@ -273,24 +273,24 @@ class AquaApp:
273
273
  model_id: str,
274
274
  config_file_name: str,
275
275
  config_folder: Optional[str] = ConfigFolder.CONFIG,
276
- ) -> Dict:
277
- """Gets the config for the given Aqua model.
276
+ ) -> ModelConfigResult:
277
+ """
278
+ Gets the configuration for the given Aqua model along with the model details.
278
279
 
279
280
  Parameters
280
281
  ----------
281
- model_id: str
282
+ model_id : str
282
283
  The OCID of the Aqua model.
283
- config_file_name: str
284
- name of the config file
285
- config_folder: (str, optional):
286
- subfolder path where config_file_name needs to be searched
287
- Defaults to `ConfigFolder.CONFIG`.
288
- When searching inside model artifact directory , the value is ConfigFolder.ARTIFACT`
284
+ config_file_name : str
285
+ The name of the configuration file.
286
+ config_folder : Optional[str]
287
+ The subfolder path where config_file_name is searched.
288
+ Defaults to ConfigFolder.CONFIG. For model artifact directories, use ConfigFolder.ARTIFACT.
289
289
 
290
290
  Returns
291
291
  -------
292
- Dict:
293
- A dict of allowed configs.
292
+ ModelConfigResult
293
+ A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
294
294
  """
295
295
  config_folder = config_folder or ConfigFolder.CONFIG
296
296
  oci_model = self.ds_client.get_model(model_id).data
@@ -302,11 +302,11 @@ class AquaApp:
302
302
  if oci_model.freeform_tags
303
303
  else False
304
304
  )
305
-
306
305
  if not oci_aqua:
307
- raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.")
306
+ raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.")
307
+
308
+ config: Dict[str, Any] = {}
308
309
 
309
- config = {}
310
310
  # if the current model has a service model tag, then
311
311
  if Tags.AQUA_SERVICE_MODEL_TAG in oci_model.freeform_tags:
312
312
  base_model_ocid = oci_model.freeform_tags[Tags.AQUA_SERVICE_MODEL_TAG]
@@ -326,7 +326,7 @@ class AquaApp:
326
326
  logger.debug(
327
327
  f"Failed to get artifact path from custom metadata for the model: {model_id}"
328
328
  )
329
- return config
329
+ return ModelConfigResult(config=config, model_details=oci_model)
330
330
 
331
331
  config_path = os.path.join(os.path.dirname(artifact_path), config_folder)
332
332
  if not is_path_exists(config_path):
@@ -351,9 +351,8 @@ class AquaApp:
351
351
  f"{config_file_name} is not available for the model: {model_id}. "
352
352
  f"Check if the custom metadata has the artifact path set."
353
353
  )
354
- return config
355
354
 
356
- return config
355
+ return ModelConfigResult(config=config, model_details=oci_model)
357
356
 
358
357
  @property
359
358
  def telemetry(self):
@@ -375,9 +374,11 @@ class CLIBuilderMixin:
375
374
  """
376
375
  cmd = f"ads aqua {self._command}"
377
376
  params = [
378
- f"--{field.name} {json.dumps(getattr(self, field.name))}"
379
- if isinstance(getattr(self, field.name), dict)
380
- else f"--{field.name} {getattr(self, field.name)}"
377
+ (
378
+ f"--{field.name} {json.dumps(getattr(self, field.name))}"
379
+ if isinstance(getattr(self, field.name), dict)
380
+ else f"--{field.name} {getattr(self, field.name)}"
381
+ )
381
382
  for field in fields(self.__class__)
382
383
  if getattr(self, field.name) is not None
383
384
  ]
ads/aqua/client/client.py CHANGED
@@ -51,7 +51,7 @@ _T = TypeVar("_T", bound="BaseClient")
51
51
  logger = logging.getLogger(__name__)
52
52
 
53
53
 
54
- class OCIAuth(httpx.Auth):
54
+ class HttpxOCIAuth(httpx.Auth):
55
55
  """
56
56
  Custom HTTPX authentication class that uses the OCI Signer for request signing.
57
57
 
@@ -59,14 +59,15 @@ class OCIAuth(httpx.Auth):
59
59
  signer (oci.signer.Signer): The OCI signer used to sign requests.
60
60
  """
61
61
 
62
- def __init__(self, signer: oci.signer.Signer):
62
+ def __init__(self, signer: Optional[oci.signer.Signer] = None):
63
63
  """
64
- Initialize the OCIAuth instance.
64
+ Initialize the HttpxOCIAuth instance.
65
65
 
66
66
  Args:
67
67
  signer (oci.signer.Signer): The OCI signer to use for signing requests.
68
68
  """
69
- self.signer = signer
69
+
70
+ self.signer = signer or authutil.default_signer().get("signer")
70
71
 
71
72
  def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]:
72
73
  """
@@ -256,7 +257,7 @@ class BaseClient:
256
257
  auth = auth or authutil.default_signer()
257
258
  if not callable(auth.get("signer")):
258
259
  raise ValueError("Auth object must have a 'signer' callable attribute.")
259
- self.auth = OCIAuth(auth["signer"])
260
+ self.auth = HttpxOCIAuth(auth["signer"])
260
261
 
261
262
  logger.debug(
262
263
  f"Initialized {self.__class__.__name__} with endpoint={self.endpoint}, "
@@ -352,7 +353,7 @@ class Client(BaseClient):
352
353
  **kwargs: Keyword arguments forwarded to BaseClient.
353
354
  """
354
355
  super().__init__(*args, **kwargs)
355
- self._client = httpx.Client(timeout=self.timeout)
356
+ self._client = httpx.Client(timeout=self.timeout, auth=self.auth)
356
357
 
357
358
  def is_closed(self) -> bool:
358
359
  return self._client.is_closed
@@ -400,7 +401,6 @@ class Client(BaseClient):
400
401
  response = self._client.post(
401
402
  self.endpoint,
402
403
  headers=self._prepare_headers(stream=False, headers=headers),
403
- auth=self.auth,
404
404
  json=payload,
405
405
  )
406
406
  logger.debug(f"Received response with status code: {response.status_code}")
@@ -447,7 +447,6 @@ class Client(BaseClient):
447
447
  "POST",
448
448
  self.endpoint,
449
449
  headers=self._prepare_headers(stream=True, headers=headers),
450
- auth=self.auth,
451
450
  json={**payload, "stream": True},
452
451
  ) as response:
453
452
  try:
@@ -581,7 +580,7 @@ class AsyncClient(BaseClient):
581
580
  **kwargs: Keyword arguments forwarded to BaseClient.
582
581
  """
583
582
  super().__init__(*args, **kwargs)
584
- self._client = httpx.AsyncClient(timeout=self.timeout)
583
+ self._client = httpx.AsyncClient(timeout=self.timeout, auth=self.auth)
585
584
 
586
585
  def is_closed(self) -> bool:
587
586
  return self._client.is_closed
@@ -637,7 +636,6 @@ class AsyncClient(BaseClient):
637
636
  response = await self._client.post(
638
637
  self.endpoint,
639
638
  headers=self._prepare_headers(stream=False, headers=headers),
640
- auth=self.auth,
641
639
  json=payload,
642
640
  )
643
641
  logger.debug(f"Received response with status code: {response.status_code}")
@@ -683,7 +681,6 @@ class AsyncClient(BaseClient):
683
681
  "POST",
684
682
  self.endpoint,
685
683
  headers=self._prepare_headers(stream=True, headers=headers),
686
- auth=self.auth,
687
684
  json={**payload, "stream": True},
688
685
  ) as response:
689
686
  try:
@@ -797,3 +794,43 @@ class AsyncClient(BaseClient):
797
794
  logger.debug(f"Generating embeddings with input: {input}, payload: {payload}")
798
795
  payload = {**(payload or {}), "input": input}
799
796
  return await self._request(payload=payload, headers=headers)
797
+
798
+
799
+ def get_httpx_client(**kwargs: Any) -> httpx.Client:
800
+ """
801
+ Creates and returns a synchronous httpx Client configured with OCI authentication signer based
802
+ the authentication type setup using ads.set_auth method or env variable OCI_IAM_TYPE.
803
+ More information - https://accelerated-data-science.readthedocs.io/en/stable/user_guide/cli/authentication.html
804
+
805
+ Parameters
806
+ ----------
807
+ **kwargs : Any
808
+ Keyword arguments supported by httpx.Client
809
+
810
+ Returns
811
+ -------
812
+ Client
813
+ A configured synchronous httpx Client instance.
814
+ """
815
+ kwargs["auth"] = kwargs.get("auth") or HttpxOCIAuth()
816
+ return httpx.Client(**kwargs)
817
+
818
+
819
+ def get_async_httpx_client(**kwargs: Any) -> httpx.AsyncClient:
820
+ """
821
+ Creates and returns a synchronous httpx Client configured with OCI authentication signer based
822
+ the authentication type setup using ads.set_auth method or env variable OCI_IAM_TYPE.
823
+ More information - https://accelerated-data-science.readthedocs.io/en/stable/user_guide/cli/authentication.html
824
+
825
+ Parameters
826
+ ----------
827
+ **kwargs : Any
828
+ Keyword arguments supported by httpx.Client
829
+
830
+ Returns
831
+ -------
832
+ AsyncClient
833
+ A configured asynchronous httpx AsyncClient instance.
834
+ """
835
+ kwargs["auth"] = kwargs.get("auth") or HttpxOCIAuth()
836
+ return httpx.AsyncClient(**kwargs)
@@ -1,7 +1,12 @@
1
1
  #!/usr/bin/env python
2
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
 
5
+ from typing import Any, Dict, Optional
6
+
7
+ from oci.data_science.models import Model
8
+ from pydantic import BaseModel, Field
9
+
5
10
 
6
11
  class ContainerSpec:
7
12
  """
@@ -15,3 +20,25 @@ class ContainerSpec:
15
20
  ENV_VARS = "envVars"
16
21
  RESTRICTED_PARAMS = "restrictedParams"
17
22
  EVALUATION_CONFIGURATION = "evaluationConfiguration"
23
+
24
+
25
+ class ModelConfigResult(BaseModel):
26
+ """
27
+ Represents the result of getting the AQUA model configuration.
28
+
29
+ Attributes:
30
+ model_details (Dict[str, Any]): A dictionary containing model details extracted from OCI.
31
+ config (Dict[str, Any]): A dictionary of the loaded configuration.
32
+ """
33
+
34
+ config: Optional[Dict[str, Any]] = Field(
35
+ None, description="Loaded configuration dictionary."
36
+ )
37
+ model_details: Optional[Model] = Field(
38
+ None, description="Details of the model from OCI."
39
+ )
40
+
41
+ class Config:
42
+ extra = "ignore"
43
+ arbitrary_types_allowed = True
44
+ protected_namespaces = ()
ads/aqua/common/enums.py CHANGED
@@ -2,12 +2,6 @@
2
2
  # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
 
5
- """
6
- aqua.common.enums
7
- ~~~~~~~~~~~~~~
8
- This module contains the set of enums used in AQUA.
9
- """
10
-
11
5
  from ads.common.extended_enum import ExtendedEnum
12
6
 
13
7
 
@@ -88,7 +82,8 @@ class RqsAdditionalDetails(ExtendedEnum):
88
82
 
89
83
  class TextEmbeddingInferenceContainerParams(ExtendedEnum):
90
84
  """Contains a subset of params that are required for enabling model deployment in OCI Data Science. More options
91
- are available at https://huggingface.co/docs/text-embeddings-inference/en/cli_arguments"""
85
+ are available at https://huggingface.co/docs/text-embeddings-inference/en/cli_arguments
86
+ """
92
87
 
93
88
  MODEL_ID = "model-id"
94
89
  PORT = "port"
@@ -97,3 +92,14 @@ class TextEmbeddingInferenceContainerParams(ExtendedEnum):
97
92
  class ConfigFolder(ExtendedEnum):
98
93
  CONFIG = "config"
99
94
  ARTIFACT = "artifact"
95
+
96
+
97
+ class ModelFormat(ExtendedEnum):
98
+ GGUF = "GGUF"
99
+ SAFETENSORS = "SAFETENSORS"
100
+ UNKNOWN = "UNKNOWN"
101
+
102
+
103
+ class Platform(ExtendedEnum):
104
+ ARM_CPU = "ARM_CPU"
105
+ NVIDIA_GPU = "NVIDIA_GPU"
ads/aqua/common/utils.py CHANGED
@@ -19,7 +19,6 @@ from pathlib import Path
19
19
  from string import Template
20
20
  from typing import List, Union
21
21
 
22
- import fsspec
23
22
  import oci
24
23
  from cachetools import TTLCache, cached
25
24
  from huggingface_hub.constants import HF_HUB_CACHE
@@ -58,7 +57,6 @@ from ads.aqua.constants import (
58
57
  SUPPORTED_FILE_FORMATS,
59
58
  TEI_CONTAINER_DEFAULT_HOST,
60
59
  TGI_INFERENCE_RESTRICTED_PARAMS,
61
- UNKNOWN,
62
60
  UNKNOWN_JSON_STR,
63
61
  VLLM_INFERENCE_RESTRICTED_PARAMS,
64
62
  )
@@ -68,7 +66,13 @@ from ads.common.decorator.threaded import threaded
68
66
  from ads.common.extended_enum import ExtendedEnum
69
67
  from ads.common.object_storage_details import ObjectStorageDetails
70
68
  from ads.common.oci_resource import SEARCH_TYPE, OCIResource
71
- from ads.common.utils import copy_file, get_console_link, upload_to_os
69
+ from ads.common.utils import (
70
+ UNKNOWN,
71
+ copy_file,
72
+ get_console_link,
73
+ read_file,
74
+ upload_to_os,
75
+ )
72
76
  from ads.config import (
73
77
  AQUA_MODEL_DEPLOYMENT_FOLDER,
74
78
  AQUA_SERVICE_MODELS_BUCKET,
@@ -228,15 +232,6 @@ def get_artifact_path(custom_metadata_list: List) -> str:
228
232
  return UNKNOWN
229
233
 
230
234
 
231
- def read_file(file_path: str, **kwargs) -> str:
232
- try:
233
- with fsspec.open(file_path, "r", **kwargs.get("auth", {})) as f:
234
- return f.read()
235
- except Exception as e:
236
- logger.debug(f"Failed to read file {file_path}. {e}")
237
- return UNKNOWN
238
-
239
-
240
235
  @threaded()
241
236
  def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
242
237
  artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
@@ -553,7 +548,7 @@ def service_config_path():
553
548
  return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
554
549
 
555
550
 
556
- @cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
551
+ @cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=10), timer=datetime.now))
557
552
  def get_container_config():
558
553
  config = load_config(
559
554
  file_path=service_config_path(),
@@ -0,0 +1,203 @@
1
+ #!/usr/bin/env python
2
+ # Copyright (c) 2025 Oracle and/or its affiliates.
3
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
+
5
+ from typing import Dict, List, Optional
6
+
7
+ from pydantic import Field
8
+
9
+ from ads.aqua.common.entities import ContainerSpec
10
+ from ads.aqua.config.utils.serializer import Serializable
11
+
12
+
13
+ class AquaContainerConfigSpec(Serializable):
14
+ """
15
+ Represents container specification details.
16
+
17
+ Attributes
18
+ ----------
19
+ cli_param (Optional[str]): CLI parameter for container configuration.
20
+ server_port (Optional[str]): The server port for the container.
21
+ health_check_port (Optional[str]): The health check port for the container.
22
+ env_vars (Optional[List[Dict]]): Environment variables for the container.
23
+ restricted_params (Optional[List[str]]): Restricted parameters for container configuration.
24
+ """
25
+
26
+ cli_param: Optional[str] = Field(
27
+ default=None, description="CLI parameter for container configuration."
28
+ )
29
+ server_port: Optional[str] = Field(
30
+ default=None, description="Server port for the container."
31
+ )
32
+ health_check_port: Optional[str] = Field(
33
+ default=None, description="Health check port for the container."
34
+ )
35
+ env_vars: Optional[List[Dict]] = Field(
36
+ default_factory=list, description="List of environment variables."
37
+ )
38
+ restricted_params: Optional[List[str]] = Field(
39
+ default_factory=list, description="List of restricted parameters."
40
+ )
41
+
42
+ class Config:
43
+ extra = "allow"
44
+
45
+
46
+ class AquaContainerConfigItem(Serializable):
47
+ """
48
+ Represents an item of the AQUA container configuration.
49
+
50
+ Attributes
51
+ ----------
52
+ name (Optional[str]): Name of the container configuration item.
53
+ version (Optional[str]): Version of the container.
54
+ display_name (Optional[str]): Display name for UI.
55
+ family (Optional[str]): Container family or category.
56
+ platforms (Optional[List[str]]): Supported platforms.
57
+ model_formats (Optional[List[str]]): Supported model formats.
58
+ spec (Optional[AquaContainerConfigSpec]): Container specification details.
59
+ """
60
+
61
+ name: Optional[str] = Field(
62
+ default=None, description="Name of the container configuration item."
63
+ )
64
+ version: Optional[str] = Field(
65
+ default=None, description="Version of the container."
66
+ )
67
+ display_name: Optional[str] = Field(
68
+ default=None, description="Display name of the container."
69
+ )
70
+ family: Optional[str] = Field(
71
+ default=None, description="Container family or category."
72
+ )
73
+ platforms: Optional[List[str]] = Field(
74
+ default_factory=list, description="Supported platforms."
75
+ )
76
+ model_formats: Optional[List[str]] = Field(
77
+ default_factory=list, description="Supported model formats."
78
+ )
79
+ spec: Optional[AquaContainerConfigSpec] = Field(
80
+ default_factory=AquaContainerConfigSpec,
81
+ description="Detailed container specification.",
82
+ )
83
+ usages: Optional[List[str]] = Field(
84
+ default_factory=list, description="Supported usages."
85
+ )
86
+
87
+ class Config:
88
+ extra = "allow"
89
+
90
+
91
+ class AquaContainerConfig(Serializable):
92
+ """
93
+ Represents a configuration of AQUA containers to be returned to the client.
94
+
95
+ Attributes
96
+ ----------
97
+ inference (Dict[str, AquaContainerConfigItem]): Inference container configuration items.
98
+ finetune (Dict[str, AquaContainerConfigItem]): Fine-tuning container configuration items.
99
+ evaluate (Dict[str, AquaContainerConfigItem]): Evaluation container configuration items.
100
+ """
101
+
102
+ inference: Dict[str, AquaContainerConfigItem] = Field(
103
+ default_factory=dict, description="Inference container configuration items."
104
+ )
105
+ finetune: Dict[str, AquaContainerConfigItem] = Field(
106
+ default_factory=dict, description="Fine-tuning container configuration items."
107
+ )
108
+ evaluate: Dict[str, AquaContainerConfigItem] = Field(
109
+ default_factory=dict, description="Evaluation container configuration items."
110
+ )
111
+
112
+ def to_dict(self):
113
+ return {
114
+ "inference": list(self.inference.values()),
115
+ "finetune": list(self.finetune.values()),
116
+ "evaluate": list(self.evaluate.values()),
117
+ }
118
+
119
+ @classmethod
120
+ def from_container_index_json(
121
+ cls,
122
+ config: Dict,
123
+ enable_spec: Optional[bool] = False,
124
+ ) -> "AquaContainerConfig":
125
+ """
126
+ Creates an AquaContainerConfig instance from a container index JSON.
127
+
128
+ Parameters
129
+ ----------
130
+ config (Optional[Dict]): The container index JSON.
131
+ enable_spec (Optional[bool]): If True, fetch container specification details.
132
+
133
+ Returns
134
+ -------
135
+ AquaContainerConfig: The constructed container configuration.
136
+ """
137
+ # TODO: Return this logic back if necessary in the next iteraion.
138
+ # if not config:
139
+ # config = get_container_config()
140
+
141
+ inference_items: Dict[str, AquaContainerConfigItem] = {}
142
+ finetune_items: Dict[str, AquaContainerConfigItem] = {}
143
+ evaluate_items: Dict[str, AquaContainerConfigItem] = {}
144
+
145
+ for container_type, containers in config.items():
146
+ if isinstance(containers, list):
147
+ for container in containers:
148
+ platforms = container.get("platforms", [])
149
+ model_formats = container.get("modelFormats", [])
150
+ usages = container.get("usages", [])
151
+ container_spec = (
152
+ config.get(ContainerSpec.CONTAINER_SPEC, {}).get(
153
+ container_type, {}
154
+ )
155
+ if enable_spec
156
+ else None
157
+ )
158
+ container_item = AquaContainerConfigItem(
159
+ name=container.get("name", ""),
160
+ version=container.get("version", ""),
161
+ display_name=container.get(
162
+ "displayName", container.get("version", "")
163
+ ),
164
+ family=container_type,
165
+ platforms=platforms,
166
+ model_formats=model_formats,
167
+ usages=usages,
168
+ spec=(
169
+ AquaContainerConfigSpec(
170
+ cli_param=container_spec.get(
171
+ ContainerSpec.CLI_PARM, ""
172
+ ),
173
+ server_port=container_spec.get(
174
+ ContainerSpec.SERVER_PORT, ""
175
+ ),
176
+ health_check_port=container_spec.get(
177
+ ContainerSpec.HEALTH_CHECK_PORT, ""
178
+ ),
179
+ env_vars=container_spec.get(ContainerSpec.ENV_VARS, []),
180
+ restricted_params=container_spec.get(
181
+ ContainerSpec.RESTRICTED_PARAMS, []
182
+ ),
183
+ )
184
+ if container_spec
185
+ else None
186
+ ),
187
+ )
188
+ if container.get("type") == "inference":
189
+ inference_items[container_type] = container_item
190
+ elif (
191
+ container.get("type") == "fine-tune"
192
+ or container_type == "odsc-llm-fine-tuning"
193
+ ):
194
+ finetune_items[container_type] = container_item
195
+ elif (
196
+ container.get("type") == "evaluate"
197
+ or container_type == "odsc-llm-evaluate"
198
+ ):
199
+ evaluate_items[container_type] = container_item
200
+
201
+ return cls(
202
+ inference=inference_items, finetune=finetune_items, evaluate=evaluate_items
203
+ )