oracle-ads 2.11.9__py3-none-any.whl → 2.11.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +1 -1
- ads/aqua/{base.py → app.py} +27 -7
- ads/aqua/cli.py +59 -17
- ads/aqua/common/__init__.py +5 -0
- ads/aqua/{decorator.py → common/decorator.py} +14 -8
- ads/aqua/common/enums.py +69 -0
- ads/aqua/{exception.py → common/errors.py} +28 -0
- ads/aqua/{utils.py → common/utils.py} +193 -95
- ads/aqua/config/config.py +18 -0
- ads/aqua/constants.py +51 -33
- ads/aqua/data.py +15 -26
- ads/aqua/evaluation/__init__.py +8 -0
- ads/aqua/evaluation/constants.py +53 -0
- ads/aqua/evaluation/entities.py +170 -0
- ads/aqua/evaluation/errors.py +71 -0
- ads/aqua/{evaluation.py → evaluation/evaluation.py} +122 -370
- ads/aqua/extension/__init__.py +2 -0
- ads/aqua/extension/aqua_ws_msg_handler.py +97 -0
- ads/aqua/extension/base_handler.py +0 -7
- ads/aqua/extension/common_handler.py +12 -6
- ads/aqua/extension/deployment_handler.py +70 -4
- ads/aqua/extension/errors.py +10 -0
- ads/aqua/extension/evaluation_handler.py +5 -3
- ads/aqua/extension/evaluation_ws_msg_handler.py +43 -0
- ads/aqua/extension/finetune_handler.py +41 -3
- ads/aqua/extension/model_handler.py +56 -4
- ads/aqua/extension/models/__init__.py +0 -0
- ads/aqua/extension/models/ws_models.py +69 -0
- ads/aqua/extension/ui_handler.py +65 -4
- ads/aqua/extension/ui_websocket_handler.py +124 -0
- ads/aqua/extension/utils.py +1 -1
- ads/aqua/finetuning/__init__.py +7 -0
- ads/aqua/finetuning/constants.py +17 -0
- ads/aqua/finetuning/entities.py +102 -0
- ads/aqua/{finetune.py → finetuning/finetuning.py} +170 -141
- ads/aqua/model/__init__.py +8 -0
- ads/aqua/model/constants.py +46 -0
- ads/aqua/model/entities.py +266 -0
- ads/aqua/model/enums.py +26 -0
- ads/aqua/{model.py → model/model.py} +405 -309
- ads/aqua/modeldeployment/__init__.py +8 -0
- ads/aqua/modeldeployment/constants.py +26 -0
- ads/aqua/{deployment.py → modeldeployment/deployment.py} +288 -227
- ads/aqua/modeldeployment/entities.py +142 -0
- ads/aqua/modeldeployment/inference.py +75 -0
- ads/aqua/ui.py +88 -8
- ads/cli.py +55 -7
- ads/common/decorator/threaded.py +97 -0
- ads/common/serializer.py +2 -2
- ads/config.py +5 -1
- ads/jobs/builders/infrastructure/dsc_job.py +49 -6
- ads/model/datascience_model.py +1 -1
- ads/model/deployment/model_deployment.py +11 -0
- ads/model/model_metadata.py +17 -6
- ads/opctl/operator/lowcode/anomaly/README.md +0 -2
- ads/opctl/operator/lowcode/anomaly/__main__.py +3 -3
- ads/opctl/operator/lowcode/anomaly/environment.yaml +0 -2
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +2 -2
- ads/opctl/operator/lowcode/anomaly/model/autots.py +1 -1
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +13 -17
- ads/opctl/operator/lowcode/anomaly/operator_config.py +2 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +1 -2
- ads/opctl/operator/lowcode/anomaly/utils.py +3 -2
- ads/opctl/operator/lowcode/common/transformations.py +2 -1
- ads/opctl/operator/lowcode/common/utils.py +1 -1
- ads/opctl/operator/lowcode/forecast/README.md +1 -3
- ads/opctl/operator/lowcode/forecast/__main__.py +3 -18
- ads/opctl/operator/lowcode/forecast/const.py +2 -0
- ads/opctl/operator/lowcode/forecast/environment.yaml +1 -2
- ads/opctl/operator/lowcode/forecast/model/arima.py +1 -0
- ads/opctl/operator/lowcode/forecast/model/automlx.py +7 -4
- ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +38 -22
- ads/opctl/operator/lowcode/forecast/model/factory.py +33 -4
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +15 -1
- ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +234 -0
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +9 -1
- ads/opctl/operator/lowcode/forecast/model/prophet.py +1 -0
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +147 -0
- ads/opctl/operator/lowcode/forecast/operator_config.py +2 -1
- ads/opctl/operator/lowcode/forecast/schema.yaml +7 -2
- ads/opctl/operator/lowcode/forecast/utils.py +18 -44
- {oracle_ads-2.11.9.dist-info → oracle_ads-2.11.11.dist-info}/METADATA +9 -12
- {oracle_ads-2.11.9.dist-info → oracle_ads-2.11.11.dist-info}/RECORD +87 -61
- ads/aqua/job.py +0 -29
- {oracle_ads-2.11.9.dist-info → oracle_ads-2.11.11.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.11.9.dist-info → oracle_ads-2.11.11.dist-info}/WHEEL +0 -0
- {oracle_ads-2.11.9.dist-info → oracle_ads-2.11.11.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,142 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
from dataclasses import dataclass, field
|
7
|
+
from typing import Union
|
8
|
+
|
9
|
+
from oci.data_science.models import ModelDeployment, ModelDeploymentSummary
|
10
|
+
|
11
|
+
from ads.aqua.common.enums import Tags
|
12
|
+
from ads.aqua.constants import UNKNOWN, UNKNOWN_DICT
|
13
|
+
from ads.aqua.data import AquaResourceIdentifier
|
14
|
+
from ads.common.serializer import DataClassSerializable
|
15
|
+
from ads.common.utils import get_console_link
|
16
|
+
|
17
|
+
|
18
|
+
@dataclass
|
19
|
+
class ModelParams:
|
20
|
+
max_tokens: int = None
|
21
|
+
temperature: float = None
|
22
|
+
top_k: float = None
|
23
|
+
top_p: float = None
|
24
|
+
model: str = None
|
25
|
+
|
26
|
+
|
27
|
+
class ContainerSpec:
|
28
|
+
"""
|
29
|
+
Class to hold keys within the container spec.
|
30
|
+
"""
|
31
|
+
|
32
|
+
CONTAINER_SPEC = "containerSpec"
|
33
|
+
CLI_PARM = "cliParam"
|
34
|
+
SERVER_PORT = "serverPort"
|
35
|
+
HEALTH_CHECK_PORT = "healthCheckPort"
|
36
|
+
ENV_VARS = "envVars"
|
37
|
+
|
38
|
+
|
39
|
+
@dataclass
|
40
|
+
class ShapeInfo:
|
41
|
+
instance_shape: str = None
|
42
|
+
instance_count: int = None
|
43
|
+
ocpus: float = None
|
44
|
+
memory_in_gbs: float = None
|
45
|
+
|
46
|
+
|
47
|
+
@dataclass(repr=False)
|
48
|
+
class AquaDeployment(DataClassSerializable):
|
49
|
+
"""Represents an Aqua Model Deployment"""
|
50
|
+
|
51
|
+
id: str = None
|
52
|
+
display_name: str = None
|
53
|
+
aqua_service_model: bool = None
|
54
|
+
aqua_model_name: str = None
|
55
|
+
state: str = None
|
56
|
+
description: str = None
|
57
|
+
created_on: str = None
|
58
|
+
created_by: str = None
|
59
|
+
endpoint: str = None
|
60
|
+
console_link: str = None
|
61
|
+
lifecycle_details: str = None
|
62
|
+
shape_info: field(default_factory=ShapeInfo) = None
|
63
|
+
tags: dict = None
|
64
|
+
|
65
|
+
@classmethod
|
66
|
+
def from_oci_model_deployment(
|
67
|
+
cls,
|
68
|
+
oci_model_deployment: Union[ModelDeploymentSummary, ModelDeployment],
|
69
|
+
region: str,
|
70
|
+
) -> "AquaDeployment":
|
71
|
+
"""Converts oci model deployment response to AquaDeployment instance.
|
72
|
+
|
73
|
+
Parameters
|
74
|
+
----------
|
75
|
+
oci_model_deployment: Union[ModelDeploymentSummary, ModelDeployment]
|
76
|
+
The instance of either oci.data_science.models.ModelDeployment or
|
77
|
+
oci.data_science.models.ModelDeploymentSummary class.
|
78
|
+
region: str
|
79
|
+
The region of this model deployment.
|
80
|
+
|
81
|
+
Returns
|
82
|
+
-------
|
83
|
+
AquaDeployment:
|
84
|
+
The instance of the Aqua model deployment.
|
85
|
+
"""
|
86
|
+
instance_configuration = (
|
87
|
+
oci_model_deployment.model_deployment_configuration_details.model_configuration_details.instance_configuration
|
88
|
+
)
|
89
|
+
instance_shape_config_details = (
|
90
|
+
instance_configuration.model_deployment_instance_shape_config_details
|
91
|
+
)
|
92
|
+
instance_count = (
|
93
|
+
oci_model_deployment.model_deployment_configuration_details.model_configuration_details.scaling_policy.instance_count
|
94
|
+
)
|
95
|
+
shape_info = ShapeInfo(
|
96
|
+
instance_shape=instance_configuration.instance_shape_name,
|
97
|
+
instance_count=instance_count,
|
98
|
+
ocpus=(
|
99
|
+
instance_shape_config_details.ocpus
|
100
|
+
if instance_shape_config_details
|
101
|
+
else None
|
102
|
+
),
|
103
|
+
memory_in_gbs=(
|
104
|
+
instance_shape_config_details.memory_in_gbs
|
105
|
+
if instance_shape_config_details
|
106
|
+
else None
|
107
|
+
),
|
108
|
+
)
|
109
|
+
|
110
|
+
freeform_tags = oci_model_deployment.freeform_tags or UNKNOWN_DICT
|
111
|
+
aqua_service_model_tag = freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None)
|
112
|
+
aqua_model_name = freeform_tags.get(Tags.AQUA_MODEL_NAME_TAG, UNKNOWN)
|
113
|
+
|
114
|
+
return AquaDeployment(
|
115
|
+
id=oci_model_deployment.id,
|
116
|
+
display_name=oci_model_deployment.display_name,
|
117
|
+
aqua_service_model=aqua_service_model_tag is not None,
|
118
|
+
aqua_model_name=aqua_model_name,
|
119
|
+
shape_info=shape_info,
|
120
|
+
state=oci_model_deployment.lifecycle_state,
|
121
|
+
lifecycle_details=getattr(
|
122
|
+
oci_model_deployment, "lifecycle_details", UNKNOWN
|
123
|
+
),
|
124
|
+
description=oci_model_deployment.description,
|
125
|
+
created_on=str(oci_model_deployment.time_created),
|
126
|
+
created_by=oci_model_deployment.created_by,
|
127
|
+
endpoint=oci_model_deployment.model_deployment_url,
|
128
|
+
console_link=get_console_link(
|
129
|
+
resource="model-deployments",
|
130
|
+
ocid=oci_model_deployment.id,
|
131
|
+
region=region,
|
132
|
+
),
|
133
|
+
tags=freeform_tags,
|
134
|
+
)
|
135
|
+
|
136
|
+
|
137
|
+
@dataclass(repr=False)
|
138
|
+
class AquaDeploymentDetail(AquaDeployment, DataClassSerializable):
|
139
|
+
"""Represents the details of an Aqua deployment."""
|
140
|
+
|
141
|
+
log_group: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
|
142
|
+
log: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
|
@@ -0,0 +1,75 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
|
4
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
5
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
|
+
|
7
|
+
import json
|
8
|
+
from dataclasses import asdict, dataclass, field
|
9
|
+
|
10
|
+
import requests
|
11
|
+
|
12
|
+
from ads.aqua.app import AquaApp, logger
|
13
|
+
from ads.aqua.modeldeployment.entities import ModelParams
|
14
|
+
from ads.common.auth import default_signer
|
15
|
+
from ads.telemetry import telemetry
|
16
|
+
|
17
|
+
|
18
|
+
@dataclass
|
19
|
+
class MDInferenceResponse(AquaApp):
|
20
|
+
"""Contains APIs for Aqua Model deployments Inference.
|
21
|
+
|
22
|
+
Attributes
|
23
|
+
----------
|
24
|
+
|
25
|
+
model_params: Dict
|
26
|
+
prompt: string
|
27
|
+
|
28
|
+
Methods
|
29
|
+
-------
|
30
|
+
get_model_deployment_response(self, **kwargs) -> "String"
|
31
|
+
Sends the prompt to the model deployment endpoint and returns the inference response.
|
32
|
+
"""
|
33
|
+
|
34
|
+
prompt: str = None
|
35
|
+
model_params: field(default_factory=ModelParams) = None
|
36
|
+
|
37
|
+
@telemetry(entry_point="plugin=inference&action=get_response", name="aqua")
|
38
|
+
def get_model_deployment_response(self, endpoint):
|
39
|
+
"""
|
40
|
+
Returns MD inference response
|
41
|
+
|
42
|
+
Parameters
|
43
|
+
----------
|
44
|
+
endpoint: str
|
45
|
+
MD predict url
|
46
|
+
prompt: str
|
47
|
+
User prompt.
|
48
|
+
|
49
|
+
model_params: (Dict, optional)
|
50
|
+
Model parameters to be associated with the message.
|
51
|
+
Currently supported VLLM+OpenAI parameters.
|
52
|
+
|
53
|
+
--model-params '{
|
54
|
+
"max_tokens":500,
|
55
|
+
"temperature": 0.5,
|
56
|
+
"top_k": 10,
|
57
|
+
"top_p": 0.5,
|
58
|
+
"model": "/opt/ds/model/deployed_model",
|
59
|
+
...}'
|
60
|
+
|
61
|
+
Returns
|
62
|
+
-------
|
63
|
+
model_response_content
|
64
|
+
"""
|
65
|
+
|
66
|
+
params_dict = asdict(self.model_params)
|
67
|
+
params_dict = {
|
68
|
+
key: value for key, value in params_dict.items() if value is not None
|
69
|
+
}
|
70
|
+
body = {"prompt": self.prompt, **params_dict}
|
71
|
+
request_kwargs = {"json": body, "headers": {"Content-Type": "application/json"}}
|
72
|
+
response = requests.post(
|
73
|
+
endpoint, auth=default_signer()["signer"], **request_kwargs
|
74
|
+
)
|
75
|
+
return json.loads(response.content)
|
ads/aqua/ui.py
CHANGED
@@ -3,21 +3,24 @@
|
|
3
3
|
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
5
|
import concurrent.futures
|
6
|
+
from dataclasses import dataclass, field
|
6
7
|
from datetime import datetime, timedelta
|
7
8
|
from threading import Lock
|
9
|
+
from typing import Dict, List
|
8
10
|
|
9
11
|
from cachetools import TTLCache
|
10
12
|
from oci.exceptions import ServiceError
|
11
13
|
from oci.identity.models import Compartment
|
12
14
|
|
13
15
|
from ads.aqua import logger
|
14
|
-
from ads.aqua.
|
15
|
-
from ads.aqua.
|
16
|
-
from ads.aqua.
|
17
|
-
from ads.aqua.utils import load_config, sanitize_response
|
16
|
+
from ads.aqua.app import AquaApp
|
17
|
+
from ads.aqua.common.enums import Tags
|
18
|
+
from ads.aqua.common.errors import AquaResourceAccessError, AquaValueError
|
19
|
+
from ads.aqua.common.utils import get_container_config, load_config, sanitize_response
|
18
20
|
from ads.common import oci_client as oc
|
19
21
|
from ads.common.auth import default_signer
|
20
22
|
from ads.common.object_storage_details import ObjectStorageDetails
|
23
|
+
from ads.common.serializer import DataClassSerializable
|
21
24
|
from ads.config import (
|
22
25
|
AQUA_CONFIG_FOLDER,
|
23
26
|
AQUA_RESOURCE_LIMIT_NAMES_CONFIG,
|
@@ -28,6 +31,70 @@ from ads.config import (
|
|
28
31
|
from ads.telemetry import telemetry
|
29
32
|
|
30
33
|
|
34
|
+
@dataclass(repr=False)
|
35
|
+
class AquaContainerConfigItem(DataClassSerializable):
|
36
|
+
"""Represents an item of the AQUA container configuration."""
|
37
|
+
|
38
|
+
name: str = None
|
39
|
+
version: str = None
|
40
|
+
display_name: str = None
|
41
|
+
family: str = None
|
42
|
+
|
43
|
+
|
44
|
+
@dataclass(repr=False)
|
45
|
+
class AquaContainerConfig(DataClassSerializable):
|
46
|
+
"""
|
47
|
+
Represents a configuration with AQUA containers to be returned to the client.
|
48
|
+
"""
|
49
|
+
|
50
|
+
inference: List[AquaContainerConfigItem] = field(default_factory=list)
|
51
|
+
finetune: List[AquaContainerConfigItem] = field(default_factory=list)
|
52
|
+
evaluate: List[AquaContainerConfigItem] = field(default_factory=list)
|
53
|
+
|
54
|
+
@classmethod
|
55
|
+
def from_container_index_json(cls, config: Dict) -> "AquaContainerConfig":
|
56
|
+
"""
|
57
|
+
Create an AquaContainerConfig instance from a container index JSON.
|
58
|
+
|
59
|
+
Parameters
|
60
|
+
----------
|
61
|
+
config : Dict
|
62
|
+
The container index JSON.
|
63
|
+
|
64
|
+
Returns
|
65
|
+
-------
|
66
|
+
AquaContainerConfig
|
67
|
+
The container configuration instance.
|
68
|
+
"""
|
69
|
+
config = config or {}
|
70
|
+
inference_items = []
|
71
|
+
finetune_items = []
|
72
|
+
evaluate_items = []
|
73
|
+
|
74
|
+
# extract inference, fine-tuning and evaluation containers
|
75
|
+
for container_type, containers in config.items():
|
76
|
+
if isinstance(containers, list):
|
77
|
+
for container in containers:
|
78
|
+
container_item = AquaContainerConfigItem(
|
79
|
+
name=container.get("name", ""),
|
80
|
+
version=container.get("version", ""),
|
81
|
+
display_name=container.get(
|
82
|
+
"displayName", container.get("version", "")
|
83
|
+
),
|
84
|
+
family=container_type,
|
85
|
+
)
|
86
|
+
if container.get("type") == "inference":
|
87
|
+
inference_items.append(container_item)
|
88
|
+
elif container_type == "odsc-llm-fine-tuning":
|
89
|
+
finetune_items.append(container_item)
|
90
|
+
elif container_type == "odsc-llm-evaluate":
|
91
|
+
evaluate_items.append(container_item)
|
92
|
+
|
93
|
+
return AquaContainerConfig(
|
94
|
+
inference=inference_items, finetune=finetune_items, evaluate=evaluate_items
|
95
|
+
)
|
96
|
+
|
97
|
+
|
31
98
|
class AquaUIApp(AquaApp):
|
32
99
|
"""Contains APIs for supporting Aqua UI.
|
33
100
|
|
@@ -42,7 +109,8 @@ class AquaUIApp(AquaApp):
|
|
42
109
|
Lists the specified log group's log objects.
|
43
110
|
list_compartments(self, **kwargs) -> List[Dict]
|
44
111
|
Lists the compartments in a specified compartment.
|
45
|
-
|
112
|
+
list_containers(self, **kwargs) -> AquaContainerConfig
|
113
|
+
Containers config to be returned to the client.
|
46
114
|
"""
|
47
115
|
|
48
116
|
_compartments_cache = TTLCache(
|
@@ -219,9 +287,7 @@ class AquaUIApp(AquaApp):
|
|
219
287
|
"""
|
220
288
|
compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID)
|
221
289
|
target_resource = (
|
222
|
-
"experiments"
|
223
|
-
if target_tag == Tags.AQUA_EVALUATION.value
|
224
|
-
else "modelversionsets"
|
290
|
+
"experiments" if target_tag == Tags.AQUA_EVALUATION else "modelversionsets"
|
225
291
|
)
|
226
292
|
logger.info(f"Loading {target_resource} from compartment: {compartment_id}")
|
227
293
|
|
@@ -451,3 +517,17 @@ class AquaUIApp(AquaApp):
|
|
451
517
|
message = f"Model artifact bucket {bucket_uri} is not versioned. Check if the path exists and enable versioning on the bucket to proceed with model creation."
|
452
518
|
|
453
519
|
return dict(is_versioned=is_versioned, message=message)
|
520
|
+
|
521
|
+
@telemetry(entry_point="plugin=ui&action=list_containers", name="aqua")
|
522
|
+
def list_containers(self) -> AquaContainerConfig:
|
523
|
+
"""
|
524
|
+
Lists the AQUA containers.
|
525
|
+
|
526
|
+
Returns
|
527
|
+
-------
|
528
|
+
AquaContainerConfig
|
529
|
+
The AQUA containers configuration.
|
530
|
+
"""
|
531
|
+
return AquaContainerConfig.from_container_index_json(
|
532
|
+
config=get_container_config()
|
533
|
+
)
|
ads/cli.py
CHANGED
@@ -4,19 +4,21 @@
|
|
4
4
|
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
|
5
5
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
6
|
|
7
|
-
import traceback
|
8
7
|
import sys
|
8
|
+
import traceback
|
9
|
+
from dataclasses import is_dataclass
|
9
10
|
|
10
11
|
import fire
|
11
|
-
|
12
|
+
|
12
13
|
from ads.common import logger
|
13
14
|
|
14
15
|
try:
|
15
16
|
import click
|
16
|
-
|
17
|
+
|
17
18
|
import ads.jobs.cli
|
18
|
-
import ads.
|
19
|
+
import ads.opctl.cli
|
19
20
|
import ads.opctl.operator.cli
|
21
|
+
import ads.pipeline.cli
|
20
22
|
except Exception as ex:
|
21
23
|
print(
|
22
24
|
"Please run `pip install oracle-ads[opctl]` to install "
|
@@ -33,6 +35,7 @@ if sys.version_info >= (3, 8):
|
|
33
35
|
else:
|
34
36
|
import importlib_metadata as metadata
|
35
37
|
|
38
|
+
|
36
39
|
ADS_VERSION = metadata.version("oracle_ads")
|
37
40
|
|
38
41
|
|
@@ -86,13 +89,58 @@ def serialize(data):
|
|
86
89
|
print(str(data))
|
87
90
|
|
88
91
|
|
92
|
+
def exit_program(ex: Exception, logger: "logging.Logger") -> None:
|
93
|
+
"""
|
94
|
+
Logs the exception and exits the program with a specific exit code.
|
95
|
+
|
96
|
+
This function logs the full traceback and the exception message, then terminates
|
97
|
+
the program with an exit code. If the exception object has an 'exit_code' attribute,
|
98
|
+
it uses that as the exit code; otherwise, it defaults to 1.
|
99
|
+
|
100
|
+
Parameters
|
101
|
+
----------
|
102
|
+
ex (Exception):
|
103
|
+
The exception that triggered the program exit. This exception
|
104
|
+
should ideally contain an 'exit_code' attribute, but it is not mandatory.
|
105
|
+
logger (Logger):
|
106
|
+
A logging.Logger instance used to log the traceback and the error message.
|
107
|
+
|
108
|
+
Returns
|
109
|
+
-------
|
110
|
+
None:
|
111
|
+
This function does not return anything because it calls sys.exit,
|
112
|
+
terminating the process.
|
113
|
+
|
114
|
+
Examples
|
115
|
+
--------
|
116
|
+
|
117
|
+
>>> import logging
|
118
|
+
>>> logger = logging.getLogger('ExampleLogger')
|
119
|
+
>>> try:
|
120
|
+
... raise ValueError("An error occurred")
|
121
|
+
... except Exception as e:
|
122
|
+
... exit_program(e, logger)
|
123
|
+
"""
|
124
|
+
|
125
|
+
logger.debug(traceback.format_exc())
|
126
|
+
logger.error(str(ex))
|
127
|
+
|
128
|
+
exit_code = getattr(ex, "exit_code", 1)
|
129
|
+
logger.error(f"Exit code: {exit_code}")
|
130
|
+
sys.exit(exit_code)
|
131
|
+
|
132
|
+
|
89
133
|
def cli():
|
90
134
|
if len(sys.argv) > 1 and sys.argv[1] == "aqua":
|
135
|
+
from ads.aqua import logger as aqua_logger
|
91
136
|
from ads.aqua.cli import AquaCommand
|
92
137
|
|
93
|
-
|
94
|
-
|
95
|
-
|
138
|
+
try:
|
139
|
+
fire.Fire(
|
140
|
+
AquaCommand, command=sys.argv[2:], name="ads aqua", serialize=serialize
|
141
|
+
)
|
142
|
+
except Exception as err:
|
143
|
+
exit_program(err, aqua_logger)
|
96
144
|
else:
|
97
145
|
click_cli()
|
98
146
|
|
@@ -0,0 +1,97 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8; -*-
|
3
|
+
|
4
|
+
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
|
5
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
|
+
|
7
|
+
|
8
|
+
import concurrent.futures
|
9
|
+
import functools
|
10
|
+
import logging
|
11
|
+
from typing import Optional
|
12
|
+
|
13
|
+
from git import Optional
|
14
|
+
|
15
|
+
from ads.config import THREADED_DEFAULT_TIMEOUT
|
16
|
+
|
17
|
+
logger = logging.getLogger(__name__)
|
18
|
+
|
19
|
+
# Create a global thread pool with a maximum of 10 threads
|
20
|
+
thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=10)
|
21
|
+
|
22
|
+
|
23
|
+
class TimeoutError(Exception):
|
24
|
+
"""
|
25
|
+
Custom exception to be raised when a function times out.
|
26
|
+
|
27
|
+
Attributes
|
28
|
+
----------
|
29
|
+
message : str
|
30
|
+
The error message describing what went wrong.
|
31
|
+
|
32
|
+
Parameters
|
33
|
+
----------
|
34
|
+
message : str
|
35
|
+
The error message.
|
36
|
+
"""
|
37
|
+
|
38
|
+
def __init__(
|
39
|
+
self,
|
40
|
+
func_name: str,
|
41
|
+
timeout: int,
|
42
|
+
message: Optional[str] = "The operation could not be completed in time.",
|
43
|
+
):
|
44
|
+
super().__init__(
|
45
|
+
f"{message} The function '{func_name}' exceeded the timeout of {timeout} seconds."
|
46
|
+
)
|
47
|
+
|
48
|
+
|
49
|
+
def threaded(timeout: Optional[int] = THREADED_DEFAULT_TIMEOUT):
|
50
|
+
"""
|
51
|
+
Decorator to run a function in a separate thread using a global thread pool.
|
52
|
+
|
53
|
+
Parameters
|
54
|
+
----------
|
55
|
+
timeout (int, optional)
|
56
|
+
The maximum time in seconds to wait for the function to complete.
|
57
|
+
If the function does not complete within this time, a TimeoutError is raised.
|
58
|
+
|
59
|
+
Returns
|
60
|
+
-------
|
61
|
+
function: The wrapped function that will run in a separate thread with the specified timeout.
|
62
|
+
"""
|
63
|
+
|
64
|
+
def decorator(func):
|
65
|
+
@functools.wraps(func)
|
66
|
+
def wrapper(*args, **kwargs):
|
67
|
+
"""
|
68
|
+
Wrapper function to submit the decorated function to the thread pool and handle timeout.
|
69
|
+
|
70
|
+
Parameters
|
71
|
+
----------
|
72
|
+
*args: Positional arguments to pass to the decorated function.
|
73
|
+
**kwargs: Keyword arguments to pass to the decorated function.
|
74
|
+
|
75
|
+
Returns
|
76
|
+
-------
|
77
|
+
Any: The result of the decorated function if it completes within the timeout.
|
78
|
+
|
79
|
+
Raises
|
80
|
+
------
|
81
|
+
TimeoutError
|
82
|
+
If the function exceeds the timeout.
|
83
|
+
"""
|
84
|
+
future = thread_pool.submit(func, *args, **kwargs)
|
85
|
+
try:
|
86
|
+
return future.result(timeout=timeout)
|
87
|
+
except concurrent.futures.TimeoutError as ex:
|
88
|
+
logger.debug(
|
89
|
+
f"The function '{func.__name__}' "
|
90
|
+
f"exceeded the timeout of {timeout} seconds. "
|
91
|
+
f"{ex}"
|
92
|
+
)
|
93
|
+
raise TimeoutError(func.__name__, timeout)
|
94
|
+
|
95
|
+
return wrapper
|
96
|
+
|
97
|
+
return decorator
|
ads/common/serializer.py
CHANGED
@@ -79,7 +79,7 @@ class Serializable(ABC):
|
|
79
79
|
|
80
80
|
@classmethod
|
81
81
|
@abstractmethod
|
82
|
-
def from_dict(cls, obj_dict: dict) -> "Serializable":
|
82
|
+
def from_dict(cls, obj_dict: dict, **kwargs) -> "Serializable":
|
83
83
|
"""Returns an instance of the class instantiated by the dictionary provided.
|
84
84
|
|
85
85
|
Parameters
|
@@ -239,7 +239,7 @@ class Serializable(ABC):
|
|
239
239
|
Returns instance of the class
|
240
240
|
"""
|
241
241
|
if json_string:
|
242
|
-
return cls.from_dict(json.loads(json_string, cls=decoder))
|
242
|
+
return cls.from_dict(json.loads(json_string, cls=decoder), **kwargs)
|
243
243
|
if uri:
|
244
244
|
json_dict = json.loads(cls._read_from_file(uri, **kwargs), cls=decoder)
|
245
245
|
return cls.from_dict(json_dict)
|
ads/config.py
CHANGED
@@ -8,6 +8,7 @@ import contextlib
|
|
8
8
|
import inspect
|
9
9
|
import os
|
10
10
|
from typing import Dict, Optional
|
11
|
+
|
11
12
|
from ads.common.config import DEFAULT_CONFIG_PATH, DEFAULT_CONFIG_PROFILE, Config, Mode
|
12
13
|
|
13
14
|
OCI_ODSC_SERVICE_ENDPOINT = os.environ.get("OCI_ODSC_SERVICE_ENDPOINT")
|
@@ -41,7 +42,6 @@ COMPARTMENT_OCID = (
|
|
41
42
|
)
|
42
43
|
MD_OCID = os.environ.get("MD_OCID")
|
43
44
|
DATAFLOW_RUN_OCID = os.environ.get("DATAFLOW_RUN_ID")
|
44
|
-
|
45
45
|
RESOURCE_OCID = (
|
46
46
|
NB_SESSION_OCID or JOB_RUN_OCID or MD_OCID or PIPELINE_RUN_OCID or DATAFLOW_RUN_OCID
|
47
47
|
)
|
@@ -66,6 +66,8 @@ AQUA_RESOURCE_LIMIT_NAMES_CONFIG = os.environ.get(
|
|
66
66
|
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME = "deployment-container"
|
67
67
|
AQUA_FINETUNING_CONTAINER_METADATA_NAME = "finetune-container"
|
68
68
|
AQUA_EVALUATION_CONTAINER_METADATA_NAME = "evaluation-container"
|
69
|
+
AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME = "deployment-container-custom"
|
70
|
+
AQUA_FINETUNING_CONTAINER_OVERRIDE_FLAG_METADATA_NAME = "finetune-container-custom"
|
69
71
|
AQUA_MODEL_DEPLOYMENT_FOLDER = "/opt/ds/model/deployed_model/"
|
70
72
|
AQUA_SERVED_MODEL_NAME = "odsc-llm"
|
71
73
|
AQUA_CONFIG_FOLDER = os.path.join(
|
@@ -84,6 +86,8 @@ DEBUG_TELEMETRY = os.environ.get("DEBUG_TELEMETRY", None)
|
|
84
86
|
AQUA_SERVICE_NAME = "aqua"
|
85
87
|
DATA_SCIENCE_SERVICE_NAME = "data-science"
|
86
88
|
|
89
|
+
THREADED_DEFAULT_TIMEOUT = os.environ.get("THREADED_DEFAULT_TIMEOUT", 5)
|
90
|
+
|
87
91
|
|
88
92
|
def export(
|
89
93
|
uri: Optional[str] = DEFAULT_CONFIG_PATH,
|
@@ -9,6 +9,7 @@ import datetime
|
|
9
9
|
import inspect
|
10
10
|
import logging
|
11
11
|
import os
|
12
|
+
import re
|
12
13
|
import time
|
13
14
|
import traceback
|
14
15
|
import uuid
|
@@ -375,12 +376,13 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
|
|
375
376
|
"""
|
376
377
|
runs = self.run_list()
|
377
378
|
for run in runs:
|
378
|
-
if
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
379
|
+
if force_delete:
|
380
|
+
if run.lifecycle_state in [
|
381
|
+
DataScienceJobRun.LIFECYCLE_STATE_ACCEPTED,
|
382
|
+
DataScienceJobRun.LIFECYCLE_STATE_IN_PROGRESS,
|
383
|
+
DataScienceJobRun.LIFECYCLE_STATE_NEEDS_ATTENTION,
|
384
|
+
]:
|
385
|
+
run.cancel(wait_for_completion=True)
|
384
386
|
run.delete()
|
385
387
|
self.client.delete_job(self.id)
|
386
388
|
return self
|
@@ -582,6 +584,25 @@ class DataScienceJobRun(
|
|
582
584
|
id=self.log_id, log_group_id=self.log_details.log_group_id, **auth
|
583
585
|
)
|
584
586
|
|
587
|
+
@property
|
588
|
+
def exit_code(self):
|
589
|
+
"""The exit code of the job run from the lifecycle details.
|
590
|
+
Note that,
|
591
|
+
None will be returned if the job run is not finished or failed without exit code.
|
592
|
+
0 will be returned if job run succeeded.
|
593
|
+
"""
|
594
|
+
if self.lifecycle_state == self.LIFECYCLE_STATE_SUCCEEDED:
|
595
|
+
return 0
|
596
|
+
if not self.lifecycle_details:
|
597
|
+
return None
|
598
|
+
match = re.search(r"exit code (\d+)", self.lifecycle_details)
|
599
|
+
if not match:
|
600
|
+
return None
|
601
|
+
try:
|
602
|
+
return int(match.group(1))
|
603
|
+
except Exception:
|
604
|
+
return None
|
605
|
+
|
585
606
|
@staticmethod
|
586
607
|
def _format_log(message: str, date_time: datetime.datetime) -> dict:
|
587
608
|
"""Formats a message as log record with datetime.
|
@@ -655,6 +676,22 @@ class DataScienceJobRun(
|
|
655
676
|
print(f"{timestamp} - {status}")
|
656
677
|
return status
|
657
678
|
|
679
|
+
def wait(self, interval: float = SLEEP_INTERVAL):
|
680
|
+
"""Waits for the job run until it finishes.
|
681
|
+
|
682
|
+
Parameters
|
683
|
+
----------
|
684
|
+
interval : float
|
685
|
+
Time interval in seconds between each status check of the job run.
|
686
|
+
Defaults to 3 (seconds).
|
687
|
+
|
688
|
+
"""
|
689
|
+
self.sync()
|
690
|
+
while self.status not in self.TERMINAL_STATES:
|
691
|
+
time.sleep(interval)
|
692
|
+
self.sync()
|
693
|
+
return self
|
694
|
+
|
658
695
|
def watch(
|
659
696
|
self,
|
660
697
|
interval: float = SLEEP_INTERVAL,
|
@@ -830,6 +867,12 @@ class DataScienceJobRun(
|
|
830
867
|
self.job.download(to_dir)
|
831
868
|
return self
|
832
869
|
|
870
|
+
def delete(self, force_delete: bool = False):
|
871
|
+
if force_delete:
|
872
|
+
self.cancel(wait_for_completion=True)
|
873
|
+
super().delete()
|
874
|
+
return
|
875
|
+
|
833
876
|
|
834
877
|
# This is for backward compatibility
|
835
878
|
DSCJobRun = DataScienceJobRun
|
ads/model/datascience_model.py
CHANGED
@@ -238,7 +238,7 @@ class DataScienceModel(Builder):
|
|
238
238
|
CONST_MODEL_VERSION_ID: "version_id",
|
239
239
|
CONST_TIME_CREATED: "time_created",
|
240
240
|
CONST_LIFECYCLE_STATE: "lifecycle_state",
|
241
|
-
CONST_MODEL_FILE_DESCRIPTION: "
|
241
|
+
CONST_MODEL_FILE_DESCRIPTION: "model_description",
|
242
242
|
}
|
243
243
|
|
244
244
|
def __init__(self, spec: Dict = None, **kwargs) -> None:
|