oracle-ads 2.13.1rc0__py3-none-any.whl → 2.13.2rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +7 -1
- ads/aqua/app.py +24 -23
- ads/aqua/client/client.py +48 -11
- ads/aqua/common/entities.py +28 -1
- ads/aqua/common/enums.py +13 -7
- ads/aqua/common/utils.py +8 -13
- ads/aqua/config/container_config.py +203 -0
- ads/aqua/config/evaluation/evaluation_service_config.py +5 -181
- ads/aqua/constants.py +0 -1
- ads/aqua/evaluation/evaluation.py +4 -4
- ads/aqua/extension/base_handler.py +4 -0
- ads/aqua/extension/model_handler.py +19 -28
- ads/aqua/finetuning/finetuning.py +2 -3
- ads/aqua/model/entities.py +2 -3
- ads/aqua/model/model.py +25 -30
- ads/aqua/modeldeployment/deployment.py +6 -14
- ads/aqua/modeldeployment/entities.py +2 -2
- ads/aqua/server/__init__.py +4 -0
- ads/aqua/server/__main__.py +24 -0
- ads/aqua/server/app.py +47 -0
- ads/aqua/server/aqua_spec.yml +1291 -0
- ads/aqua/ui.py +5 -199
- ads/common/auth.py +20 -11
- ads/common/utils.py +91 -11
- ads/config.py +3 -0
- ads/llm/__init__.py +1 -0
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +32 -23
- ads/model/artifact_downloader.py +4 -1
- ads/model/common/utils.py +15 -3
- ads/model/datascience_model.py +339 -8
- ads/model/model_metadata.py +54 -14
- ads/model/model_version_set.py +5 -3
- ads/model/service/oci_datascience_model.py +477 -5
- ads/opctl/operator/common/utils.py +16 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +3 -3
- ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +1 -1
- ads/opctl/operator/lowcode/anomaly/utils.py +1 -1
- ads/opctl/operator/lowcode/common/data.py +5 -2
- ads/opctl/operator/lowcode/common/transformations.py +7 -13
- ads/opctl/operator/lowcode/common/utils.py +7 -2
- ads/opctl/operator/lowcode/forecast/model/arima.py +15 -10
- ads/opctl/operator/lowcode/forecast/model/automlx.py +39 -9
- ads/opctl/operator/lowcode/forecast/model/autots.py +7 -5
- ads/opctl/operator/lowcode/forecast/model/base_model.py +135 -110
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +30 -14
- ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +2 -2
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +46 -32
- ads/opctl/operator/lowcode/forecast/model/prophet.py +82 -29
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +142 -62
- ads/opctl/operator/lowcode/forecast/operator_config.py +29 -3
- ads/opctl/operator/lowcode/forecast/schema.yaml +1 -1
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +108 -56
- {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2rc1.dist-info}/METADATA +15 -12
- {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2rc1.dist-info}/RECORD +57 -53
- {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2rc1.dist-info}/WHEEL +1 -1
- ads/aqua/config/evaluation/evaluation_service_model_config.py +0 -8
- {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2rc1.dist-info}/entry_points.txt +0 -0
- {oracle_ads-2.13.1rc0.dist-info → oracle_ads-2.13.2rc1.dist-info/licenses}/LICENSE.txt +0 -0
ads/aqua/__init__.py
CHANGED
@@ -7,7 +7,13 @@ import os
|
|
7
7
|
from logging import getLogger
|
8
8
|
|
9
9
|
from ads import logger, set_auth
|
10
|
-
from ads.aqua.client.client import
|
10
|
+
from ads.aqua.client.client import (
|
11
|
+
AsyncClient,
|
12
|
+
Client,
|
13
|
+
HttpxOCIAuth,
|
14
|
+
get_async_httpx_client,
|
15
|
+
get_httpx_client,
|
16
|
+
)
|
11
17
|
from ads.aqua.common.utils import fetch_service_compartment
|
12
18
|
from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION
|
13
19
|
|
ads/aqua/app.py
CHANGED
@@ -6,13 +6,14 @@ import json
|
|
6
6
|
import os
|
7
7
|
import traceback
|
8
8
|
from dataclasses import fields
|
9
|
-
from typing import Dict, Optional, Union
|
9
|
+
from typing import Any, Dict, Optional, Union
|
10
10
|
|
11
11
|
import oci
|
12
12
|
from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails
|
13
13
|
|
14
14
|
from ads import set_auth
|
15
15
|
from ads.aqua import logger
|
16
|
+
from ads.aqua.common.entities import ModelConfigResult
|
16
17
|
from ads.aqua.common.enums import ConfigFolder, Tags
|
17
18
|
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
|
18
19
|
from ads.aqua.common.utils import (
|
@@ -21,10 +22,9 @@ from ads.aqua.common.utils import (
|
|
21
22
|
is_valid_ocid,
|
22
23
|
load_config,
|
23
24
|
)
|
24
|
-
from ads.aqua.constants import UNKNOWN
|
25
25
|
from ads.common import oci_client as oc
|
26
26
|
from ads.common.auth import default_signer
|
27
|
-
from ads.common.utils import extract_region, is_path_exists
|
27
|
+
from ads.common.utils import UNKNOWN, extract_region, is_path_exists
|
28
28
|
from ads.config import (
|
29
29
|
AQUA_TELEMETRY_BUCKET,
|
30
30
|
AQUA_TELEMETRY_BUCKET_NS,
|
@@ -273,24 +273,24 @@ class AquaApp:
|
|
273
273
|
model_id: str,
|
274
274
|
config_file_name: str,
|
275
275
|
config_folder: Optional[str] = ConfigFolder.CONFIG,
|
276
|
-
) ->
|
277
|
-
"""
|
276
|
+
) -> ModelConfigResult:
|
277
|
+
"""
|
278
|
+
Gets the configuration for the given Aqua model along with the model details.
|
278
279
|
|
279
280
|
Parameters
|
280
281
|
----------
|
281
|
-
model_id: str
|
282
|
+
model_id : str
|
282
283
|
The OCID of the Aqua model.
|
283
|
-
config_file_name: str
|
284
|
-
name of the
|
285
|
-
config_folder:
|
286
|
-
subfolder path where config_file_name
|
287
|
-
|
288
|
-
When searching inside model artifact directory , the value is ConfigFolder.ARTIFACT`
|
284
|
+
config_file_name : str
|
285
|
+
The name of the configuration file.
|
286
|
+
config_folder : Optional[str]
|
287
|
+
The subfolder path where config_file_name is searched.
|
288
|
+
Defaults to ConfigFolder.CONFIG. For model artifact directories, use ConfigFolder.ARTIFACT.
|
289
289
|
|
290
290
|
Returns
|
291
291
|
-------
|
292
|
-
|
293
|
-
A
|
292
|
+
ModelConfigResult
|
293
|
+
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
|
294
294
|
"""
|
295
295
|
config_folder = config_folder or ConfigFolder.CONFIG
|
296
296
|
oci_model = self.ds_client.get_model(model_id).data
|
@@ -302,11 +302,11 @@ class AquaApp:
|
|
302
302
|
if oci_model.freeform_tags
|
303
303
|
else False
|
304
304
|
)
|
305
|
-
|
306
305
|
if not oci_aqua:
|
307
|
-
raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.")
|
306
|
+
raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.")
|
307
|
+
|
308
|
+
config: Dict[str, Any] = {}
|
308
309
|
|
309
|
-
config = {}
|
310
310
|
# if the current model has a service model tag, then
|
311
311
|
if Tags.AQUA_SERVICE_MODEL_TAG in oci_model.freeform_tags:
|
312
312
|
base_model_ocid = oci_model.freeform_tags[Tags.AQUA_SERVICE_MODEL_TAG]
|
@@ -326,7 +326,7 @@ class AquaApp:
|
|
326
326
|
logger.debug(
|
327
327
|
f"Failed to get artifact path from custom metadata for the model: {model_id}"
|
328
328
|
)
|
329
|
-
return config
|
329
|
+
return ModelConfigResult(config=config, model_details=oci_model)
|
330
330
|
|
331
331
|
config_path = os.path.join(os.path.dirname(artifact_path), config_folder)
|
332
332
|
if not is_path_exists(config_path):
|
@@ -351,9 +351,8 @@ class AquaApp:
|
|
351
351
|
f"{config_file_name} is not available for the model: {model_id}. "
|
352
352
|
f"Check if the custom metadata has the artifact path set."
|
353
353
|
)
|
354
|
-
return config
|
355
354
|
|
356
|
-
return config
|
355
|
+
return ModelConfigResult(config=config, model_details=oci_model)
|
357
356
|
|
358
357
|
@property
|
359
358
|
def telemetry(self):
|
@@ -375,9 +374,11 @@ class CLIBuilderMixin:
|
|
375
374
|
"""
|
376
375
|
cmd = f"ads aqua {self._command}"
|
377
376
|
params = [
|
378
|
-
|
379
|
-
|
380
|
-
|
377
|
+
(
|
378
|
+
f"--{field.name} {json.dumps(getattr(self, field.name))}"
|
379
|
+
if isinstance(getattr(self, field.name), dict)
|
380
|
+
else f"--{field.name} {getattr(self, field.name)}"
|
381
|
+
)
|
381
382
|
for field in fields(self.__class__)
|
382
383
|
if getattr(self, field.name) is not None
|
383
384
|
]
|
ads/aqua/client/client.py
CHANGED
@@ -51,7 +51,7 @@ _T = TypeVar("_T", bound="BaseClient")
|
|
51
51
|
logger = logging.getLogger(__name__)
|
52
52
|
|
53
53
|
|
54
|
-
class
|
54
|
+
class HttpxOCIAuth(httpx.Auth):
|
55
55
|
"""
|
56
56
|
Custom HTTPX authentication class that uses the OCI Signer for request signing.
|
57
57
|
|
@@ -59,14 +59,15 @@ class OCIAuth(httpx.Auth):
|
|
59
59
|
signer (oci.signer.Signer): The OCI signer used to sign requests.
|
60
60
|
"""
|
61
61
|
|
62
|
-
def __init__(self, signer: oci.signer.Signer):
|
62
|
+
def __init__(self, signer: Optional[oci.signer.Signer] = None):
|
63
63
|
"""
|
64
|
-
Initialize the
|
64
|
+
Initialize the HttpxOCIAuth instance.
|
65
65
|
|
66
66
|
Args:
|
67
67
|
signer (oci.signer.Signer): The OCI signer to use for signing requests.
|
68
68
|
"""
|
69
|
-
|
69
|
+
|
70
|
+
self.signer = signer or authutil.default_signer().get("signer")
|
70
71
|
|
71
72
|
def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]:
|
72
73
|
"""
|
@@ -256,7 +257,7 @@ class BaseClient:
|
|
256
257
|
auth = auth or authutil.default_signer()
|
257
258
|
if not callable(auth.get("signer")):
|
258
259
|
raise ValueError("Auth object must have a 'signer' callable attribute.")
|
259
|
-
self.auth =
|
260
|
+
self.auth = HttpxOCIAuth(auth["signer"])
|
260
261
|
|
261
262
|
logger.debug(
|
262
263
|
f"Initialized {self.__class__.__name__} with endpoint={self.endpoint}, "
|
@@ -352,7 +353,7 @@ class Client(BaseClient):
|
|
352
353
|
**kwargs: Keyword arguments forwarded to BaseClient.
|
353
354
|
"""
|
354
355
|
super().__init__(*args, **kwargs)
|
355
|
-
self._client = httpx.Client(timeout=self.timeout)
|
356
|
+
self._client = httpx.Client(timeout=self.timeout, auth=self.auth)
|
356
357
|
|
357
358
|
def is_closed(self) -> bool:
|
358
359
|
return self._client.is_closed
|
@@ -400,7 +401,6 @@ class Client(BaseClient):
|
|
400
401
|
response = self._client.post(
|
401
402
|
self.endpoint,
|
402
403
|
headers=self._prepare_headers(stream=False, headers=headers),
|
403
|
-
auth=self.auth,
|
404
404
|
json=payload,
|
405
405
|
)
|
406
406
|
logger.debug(f"Received response with status code: {response.status_code}")
|
@@ -447,7 +447,6 @@ class Client(BaseClient):
|
|
447
447
|
"POST",
|
448
448
|
self.endpoint,
|
449
449
|
headers=self._prepare_headers(stream=True, headers=headers),
|
450
|
-
auth=self.auth,
|
451
450
|
json={**payload, "stream": True},
|
452
451
|
) as response:
|
453
452
|
try:
|
@@ -581,7 +580,7 @@ class AsyncClient(BaseClient):
|
|
581
580
|
**kwargs: Keyword arguments forwarded to BaseClient.
|
582
581
|
"""
|
583
582
|
super().__init__(*args, **kwargs)
|
584
|
-
self._client = httpx.AsyncClient(timeout=self.timeout)
|
583
|
+
self._client = httpx.AsyncClient(timeout=self.timeout, auth=self.auth)
|
585
584
|
|
586
585
|
def is_closed(self) -> bool:
|
587
586
|
return self._client.is_closed
|
@@ -637,7 +636,6 @@ class AsyncClient(BaseClient):
|
|
637
636
|
response = await self._client.post(
|
638
637
|
self.endpoint,
|
639
638
|
headers=self._prepare_headers(stream=False, headers=headers),
|
640
|
-
auth=self.auth,
|
641
639
|
json=payload,
|
642
640
|
)
|
643
641
|
logger.debug(f"Received response with status code: {response.status_code}")
|
@@ -683,7 +681,6 @@ class AsyncClient(BaseClient):
|
|
683
681
|
"POST",
|
684
682
|
self.endpoint,
|
685
683
|
headers=self._prepare_headers(stream=True, headers=headers),
|
686
|
-
auth=self.auth,
|
687
684
|
json={**payload, "stream": True},
|
688
685
|
) as response:
|
689
686
|
try:
|
@@ -797,3 +794,43 @@ class AsyncClient(BaseClient):
|
|
797
794
|
logger.debug(f"Generating embeddings with input: {input}, payload: {payload}")
|
798
795
|
payload = {**(payload or {}), "input": input}
|
799
796
|
return await self._request(payload=payload, headers=headers)
|
797
|
+
|
798
|
+
|
799
|
+
def get_httpx_client(**kwargs: Any) -> httpx.Client:
|
800
|
+
"""
|
801
|
+
Creates and returns a synchronous httpx Client configured with OCI authentication signer based
|
802
|
+
the authentication type setup using ads.set_auth method or env variable OCI_IAM_TYPE.
|
803
|
+
More information - https://accelerated-data-science.readthedocs.io/en/stable/user_guide/cli/authentication.html
|
804
|
+
|
805
|
+
Parameters
|
806
|
+
----------
|
807
|
+
**kwargs : Any
|
808
|
+
Keyword arguments supported by httpx.Client
|
809
|
+
|
810
|
+
Returns
|
811
|
+
-------
|
812
|
+
Client
|
813
|
+
A configured synchronous httpx Client instance.
|
814
|
+
"""
|
815
|
+
kwargs["auth"] = kwargs.get("auth") or HttpxOCIAuth()
|
816
|
+
return httpx.Client(**kwargs)
|
817
|
+
|
818
|
+
|
819
|
+
def get_async_httpx_client(**kwargs: Any) -> httpx.AsyncClient:
|
820
|
+
"""
|
821
|
+
Creates and returns a synchronous httpx Client configured with OCI authentication signer based
|
822
|
+
the authentication type setup using ads.set_auth method or env variable OCI_IAM_TYPE.
|
823
|
+
More information - https://accelerated-data-science.readthedocs.io/en/stable/user_guide/cli/authentication.html
|
824
|
+
|
825
|
+
Parameters
|
826
|
+
----------
|
827
|
+
**kwargs : Any
|
828
|
+
Keyword arguments supported by httpx.Client
|
829
|
+
|
830
|
+
Returns
|
831
|
+
-------
|
832
|
+
AsyncClient
|
833
|
+
A configured asynchronous httpx AsyncClient instance.
|
834
|
+
"""
|
835
|
+
kwargs["auth"] = kwargs.get("auth") or HttpxOCIAuth()
|
836
|
+
return httpx.AsyncClient(**kwargs)
|
ads/aqua/common/entities.py
CHANGED
@@ -1,7 +1,12 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
|
5
|
+
from typing import Any, Dict, Optional
|
6
|
+
|
7
|
+
from oci.data_science.models import Model
|
8
|
+
from pydantic import BaseModel, Field
|
9
|
+
|
5
10
|
|
6
11
|
class ContainerSpec:
|
7
12
|
"""
|
@@ -15,3 +20,25 @@ class ContainerSpec:
|
|
15
20
|
ENV_VARS = "envVars"
|
16
21
|
RESTRICTED_PARAMS = "restrictedParams"
|
17
22
|
EVALUATION_CONFIGURATION = "evaluationConfiguration"
|
23
|
+
|
24
|
+
|
25
|
+
class ModelConfigResult(BaseModel):
|
26
|
+
"""
|
27
|
+
Represents the result of getting the AQUA model configuration.
|
28
|
+
|
29
|
+
Attributes:
|
30
|
+
model_details (Dict[str, Any]): A dictionary containing model details extracted from OCI.
|
31
|
+
config (Dict[str, Any]): A dictionary of the loaded configuration.
|
32
|
+
"""
|
33
|
+
|
34
|
+
config: Optional[Dict[str, Any]] = Field(
|
35
|
+
None, description="Loaded configuration dictionary."
|
36
|
+
)
|
37
|
+
model_details: Optional[Model] = Field(
|
38
|
+
None, description="Details of the model from OCI."
|
39
|
+
)
|
40
|
+
|
41
|
+
class Config:
|
42
|
+
extra = "ignore"
|
43
|
+
arbitrary_types_allowed = True
|
44
|
+
protected_namespaces = ()
|
ads/aqua/common/enums.py
CHANGED
@@ -2,12 +2,6 @@
|
|
2
2
|
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
|
5
|
-
"""
|
6
|
-
aqua.common.enums
|
7
|
-
~~~~~~~~~~~~~~
|
8
|
-
This module contains the set of enums used in AQUA.
|
9
|
-
"""
|
10
|
-
|
11
5
|
from ads.common.extended_enum import ExtendedEnum
|
12
6
|
|
13
7
|
|
@@ -88,7 +82,8 @@ class RqsAdditionalDetails(ExtendedEnum):
|
|
88
82
|
|
89
83
|
class TextEmbeddingInferenceContainerParams(ExtendedEnum):
|
90
84
|
"""Contains a subset of params that are required for enabling model deployment in OCI Data Science. More options
|
91
|
-
are available at https://huggingface.co/docs/text-embeddings-inference/en/cli_arguments
|
85
|
+
are available at https://huggingface.co/docs/text-embeddings-inference/en/cli_arguments
|
86
|
+
"""
|
92
87
|
|
93
88
|
MODEL_ID = "model-id"
|
94
89
|
PORT = "port"
|
@@ -97,3 +92,14 @@ class TextEmbeddingInferenceContainerParams(ExtendedEnum):
|
|
97
92
|
class ConfigFolder(ExtendedEnum):
|
98
93
|
CONFIG = "config"
|
99
94
|
ARTIFACT = "artifact"
|
95
|
+
|
96
|
+
|
97
|
+
class ModelFormat(ExtendedEnum):
|
98
|
+
GGUF = "GGUF"
|
99
|
+
SAFETENSORS = "SAFETENSORS"
|
100
|
+
UNKNOWN = "UNKNOWN"
|
101
|
+
|
102
|
+
|
103
|
+
class Platform(ExtendedEnum):
|
104
|
+
ARM_CPU = "ARM_CPU"
|
105
|
+
NVIDIA_GPU = "NVIDIA_GPU"
|
ads/aqua/common/utils.py
CHANGED
@@ -19,7 +19,6 @@ from pathlib import Path
|
|
19
19
|
from string import Template
|
20
20
|
from typing import List, Union
|
21
21
|
|
22
|
-
import fsspec
|
23
22
|
import oci
|
24
23
|
from cachetools import TTLCache, cached
|
25
24
|
from huggingface_hub.constants import HF_HUB_CACHE
|
@@ -58,7 +57,6 @@ from ads.aqua.constants import (
|
|
58
57
|
SUPPORTED_FILE_FORMATS,
|
59
58
|
TEI_CONTAINER_DEFAULT_HOST,
|
60
59
|
TGI_INFERENCE_RESTRICTED_PARAMS,
|
61
|
-
UNKNOWN,
|
62
60
|
UNKNOWN_JSON_STR,
|
63
61
|
VLLM_INFERENCE_RESTRICTED_PARAMS,
|
64
62
|
)
|
@@ -68,7 +66,13 @@ from ads.common.decorator.threaded import threaded
|
|
68
66
|
from ads.common.extended_enum import ExtendedEnum
|
69
67
|
from ads.common.object_storage_details import ObjectStorageDetails
|
70
68
|
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
|
71
|
-
from ads.common.utils import
|
69
|
+
from ads.common.utils import (
|
70
|
+
UNKNOWN,
|
71
|
+
copy_file,
|
72
|
+
get_console_link,
|
73
|
+
read_file,
|
74
|
+
upload_to_os,
|
75
|
+
)
|
72
76
|
from ads.config import (
|
73
77
|
AQUA_MODEL_DEPLOYMENT_FOLDER,
|
74
78
|
AQUA_SERVICE_MODELS_BUCKET,
|
@@ -228,15 +232,6 @@ def get_artifact_path(custom_metadata_list: List) -> str:
|
|
228
232
|
return UNKNOWN
|
229
233
|
|
230
234
|
|
231
|
-
def read_file(file_path: str, **kwargs) -> str:
|
232
|
-
try:
|
233
|
-
with fsspec.open(file_path, "r", **kwargs.get("auth", {})) as f:
|
234
|
-
return f.read()
|
235
|
-
except Exception as e:
|
236
|
-
logger.debug(f"Failed to read file {file_path}. {e}")
|
237
|
-
return UNKNOWN
|
238
|
-
|
239
|
-
|
240
235
|
@threaded()
|
241
236
|
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
|
242
237
|
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
|
@@ -553,7 +548,7 @@ def service_config_path():
|
|
553
548
|
return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
|
554
549
|
|
555
550
|
|
556
|
-
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(
|
551
|
+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=10), timer=datetime.now))
|
557
552
|
def get_container_config():
|
558
553
|
config = load_config(
|
559
554
|
file_path=service_config_path(),
|
@@ -0,0 +1,203 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# Copyright (c) 2025 Oracle and/or its affiliates.
|
3
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
|
+
|
5
|
+
from typing import Dict, List, Optional
|
6
|
+
|
7
|
+
from pydantic import Field
|
8
|
+
|
9
|
+
from ads.aqua.common.entities import ContainerSpec
|
10
|
+
from ads.aqua.config.utils.serializer import Serializable
|
11
|
+
|
12
|
+
|
13
|
+
class AquaContainerConfigSpec(Serializable):
|
14
|
+
"""
|
15
|
+
Represents container specification details.
|
16
|
+
|
17
|
+
Attributes
|
18
|
+
----------
|
19
|
+
cli_param (Optional[str]): CLI parameter for container configuration.
|
20
|
+
server_port (Optional[str]): The server port for the container.
|
21
|
+
health_check_port (Optional[str]): The health check port for the container.
|
22
|
+
env_vars (Optional[List[Dict]]): Environment variables for the container.
|
23
|
+
restricted_params (Optional[List[str]]): Restricted parameters for container configuration.
|
24
|
+
"""
|
25
|
+
|
26
|
+
cli_param: Optional[str] = Field(
|
27
|
+
default=None, description="CLI parameter for container configuration."
|
28
|
+
)
|
29
|
+
server_port: Optional[str] = Field(
|
30
|
+
default=None, description="Server port for the container."
|
31
|
+
)
|
32
|
+
health_check_port: Optional[str] = Field(
|
33
|
+
default=None, description="Health check port for the container."
|
34
|
+
)
|
35
|
+
env_vars: Optional[List[Dict]] = Field(
|
36
|
+
default_factory=list, description="List of environment variables."
|
37
|
+
)
|
38
|
+
restricted_params: Optional[List[str]] = Field(
|
39
|
+
default_factory=list, description="List of restricted parameters."
|
40
|
+
)
|
41
|
+
|
42
|
+
class Config:
|
43
|
+
extra = "allow"
|
44
|
+
|
45
|
+
|
46
|
+
class AquaContainerConfigItem(Serializable):
|
47
|
+
"""
|
48
|
+
Represents an item of the AQUA container configuration.
|
49
|
+
|
50
|
+
Attributes
|
51
|
+
----------
|
52
|
+
name (Optional[str]): Name of the container configuration item.
|
53
|
+
version (Optional[str]): Version of the container.
|
54
|
+
display_name (Optional[str]): Display name for UI.
|
55
|
+
family (Optional[str]): Container family or category.
|
56
|
+
platforms (Optional[List[str]]): Supported platforms.
|
57
|
+
model_formats (Optional[List[str]]): Supported model formats.
|
58
|
+
spec (Optional[AquaContainerConfigSpec]): Container specification details.
|
59
|
+
"""
|
60
|
+
|
61
|
+
name: Optional[str] = Field(
|
62
|
+
default=None, description="Name of the container configuration item."
|
63
|
+
)
|
64
|
+
version: Optional[str] = Field(
|
65
|
+
default=None, description="Version of the container."
|
66
|
+
)
|
67
|
+
display_name: Optional[str] = Field(
|
68
|
+
default=None, description="Display name of the container."
|
69
|
+
)
|
70
|
+
family: Optional[str] = Field(
|
71
|
+
default=None, description="Container family or category."
|
72
|
+
)
|
73
|
+
platforms: Optional[List[str]] = Field(
|
74
|
+
default_factory=list, description="Supported platforms."
|
75
|
+
)
|
76
|
+
model_formats: Optional[List[str]] = Field(
|
77
|
+
default_factory=list, description="Supported model formats."
|
78
|
+
)
|
79
|
+
spec: Optional[AquaContainerConfigSpec] = Field(
|
80
|
+
default_factory=AquaContainerConfigSpec,
|
81
|
+
description="Detailed container specification.",
|
82
|
+
)
|
83
|
+
usages: Optional[List[str]] = Field(
|
84
|
+
default_factory=list, description="Supported usages."
|
85
|
+
)
|
86
|
+
|
87
|
+
class Config:
|
88
|
+
extra = "allow"
|
89
|
+
|
90
|
+
|
91
|
+
class AquaContainerConfig(Serializable):
|
92
|
+
"""
|
93
|
+
Represents a configuration of AQUA containers to be returned to the client.
|
94
|
+
|
95
|
+
Attributes
|
96
|
+
----------
|
97
|
+
inference (Dict[str, AquaContainerConfigItem]): Inference container configuration items.
|
98
|
+
finetune (Dict[str, AquaContainerConfigItem]): Fine-tuning container configuration items.
|
99
|
+
evaluate (Dict[str, AquaContainerConfigItem]): Evaluation container configuration items.
|
100
|
+
"""
|
101
|
+
|
102
|
+
inference: Dict[str, AquaContainerConfigItem] = Field(
|
103
|
+
default_factory=dict, description="Inference container configuration items."
|
104
|
+
)
|
105
|
+
finetune: Dict[str, AquaContainerConfigItem] = Field(
|
106
|
+
default_factory=dict, description="Fine-tuning container configuration items."
|
107
|
+
)
|
108
|
+
evaluate: Dict[str, AquaContainerConfigItem] = Field(
|
109
|
+
default_factory=dict, description="Evaluation container configuration items."
|
110
|
+
)
|
111
|
+
|
112
|
+
def to_dict(self):
|
113
|
+
return {
|
114
|
+
"inference": list(self.inference.values()),
|
115
|
+
"finetune": list(self.finetune.values()),
|
116
|
+
"evaluate": list(self.evaluate.values()),
|
117
|
+
}
|
118
|
+
|
119
|
+
@classmethod
|
120
|
+
def from_container_index_json(
|
121
|
+
cls,
|
122
|
+
config: Dict,
|
123
|
+
enable_spec: Optional[bool] = False,
|
124
|
+
) -> "AquaContainerConfig":
|
125
|
+
"""
|
126
|
+
Creates an AquaContainerConfig instance from a container index JSON.
|
127
|
+
|
128
|
+
Parameters
|
129
|
+
----------
|
130
|
+
config (Optional[Dict]): The container index JSON.
|
131
|
+
enable_spec (Optional[bool]): If True, fetch container specification details.
|
132
|
+
|
133
|
+
Returns
|
134
|
+
-------
|
135
|
+
AquaContainerConfig: The constructed container configuration.
|
136
|
+
"""
|
137
|
+
# TODO: Return this logic back if necessary in the next iteraion.
|
138
|
+
# if not config:
|
139
|
+
# config = get_container_config()
|
140
|
+
|
141
|
+
inference_items: Dict[str, AquaContainerConfigItem] = {}
|
142
|
+
finetune_items: Dict[str, AquaContainerConfigItem] = {}
|
143
|
+
evaluate_items: Dict[str, AquaContainerConfigItem] = {}
|
144
|
+
|
145
|
+
for container_type, containers in config.items():
|
146
|
+
if isinstance(containers, list):
|
147
|
+
for container in containers:
|
148
|
+
platforms = container.get("platforms", [])
|
149
|
+
model_formats = container.get("modelFormats", [])
|
150
|
+
usages = container.get("usages", [])
|
151
|
+
container_spec = (
|
152
|
+
config.get(ContainerSpec.CONTAINER_SPEC, {}).get(
|
153
|
+
container_type, {}
|
154
|
+
)
|
155
|
+
if enable_spec
|
156
|
+
else None
|
157
|
+
)
|
158
|
+
container_item = AquaContainerConfigItem(
|
159
|
+
name=container.get("name", ""),
|
160
|
+
version=container.get("version", ""),
|
161
|
+
display_name=container.get(
|
162
|
+
"displayName", container.get("version", "")
|
163
|
+
),
|
164
|
+
family=container_type,
|
165
|
+
platforms=platforms,
|
166
|
+
model_formats=model_formats,
|
167
|
+
usages=usages,
|
168
|
+
spec=(
|
169
|
+
AquaContainerConfigSpec(
|
170
|
+
cli_param=container_spec.get(
|
171
|
+
ContainerSpec.CLI_PARM, ""
|
172
|
+
),
|
173
|
+
server_port=container_spec.get(
|
174
|
+
ContainerSpec.SERVER_PORT, ""
|
175
|
+
),
|
176
|
+
health_check_port=container_spec.get(
|
177
|
+
ContainerSpec.HEALTH_CHECK_PORT, ""
|
178
|
+
),
|
179
|
+
env_vars=container_spec.get(ContainerSpec.ENV_VARS, []),
|
180
|
+
restricted_params=container_spec.get(
|
181
|
+
ContainerSpec.RESTRICTED_PARAMS, []
|
182
|
+
),
|
183
|
+
)
|
184
|
+
if container_spec
|
185
|
+
else None
|
186
|
+
),
|
187
|
+
)
|
188
|
+
if container.get("type") == "inference":
|
189
|
+
inference_items[container_type] = container_item
|
190
|
+
elif (
|
191
|
+
container.get("type") == "fine-tune"
|
192
|
+
or container_type == "odsc-llm-fine-tuning"
|
193
|
+
):
|
194
|
+
finetune_items[container_type] = container_item
|
195
|
+
elif (
|
196
|
+
container.get("type") == "evaluate"
|
197
|
+
or container_type == "odsc-llm-evaluate"
|
198
|
+
):
|
199
|
+
evaluate_items[container_type] = container_item
|
200
|
+
|
201
|
+
return cls(
|
202
|
+
inference=inference_items, finetune=finetune_items, evaluate=evaluate_items
|
203
|
+
)
|