oracle-ads 2.12.9__py3-none-any.whl → 2.12.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +4 -3
- ads/aqua/app.py +28 -16
- ads/aqua/client/__init__.py +3 -0
- ads/aqua/client/client.py +799 -0
- ads/aqua/common/enums.py +3 -0
- ads/aqua/common/utils.py +62 -2
- ads/aqua/data.py +2 -19
- ads/aqua/evaluation/evaluation.py +20 -12
- ads/aqua/extension/aqua_ws_msg_handler.py +14 -7
- ads/aqua/extension/base_handler.py +12 -9
- ads/aqua/extension/finetune_handler.py +8 -14
- ads/aqua/extension/model_handler.py +24 -2
- ads/aqua/finetuning/constants.py +5 -2
- ads/aqua/finetuning/entities.py +67 -17
- ads/aqua/finetuning/finetuning.py +69 -54
- ads/aqua/model/entities.py +3 -1
- ads/aqua/model/model.py +196 -98
- ads/aqua/modeldeployment/deployment.py +22 -10
- ads/cli.py +16 -8
- ads/common/auth.py +9 -9
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/{client_v02.py → v02/client.py} +23 -10
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +12 -11
- ads/model/__init__.py +11 -13
- ads/model/artifact.py +47 -8
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/generic_model.py +26 -24
- ads/model/model_metadata.py +8 -7
- ads/opctl/config/merger.py +13 -14
- ads/opctl/operator/common/operator_config.py +4 -4
- ads/opctl/operator/lowcode/common/transformations.py +50 -8
- ads/opctl/operator/lowcode/common/utils.py +22 -6
- ads/opctl/operator/lowcode/forecast/__main__.py +10 -0
- ads/opctl/operator/lowcode/forecast/const.py +2 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +19 -13
- ads/opctl/operator/lowcode/forecast/model/automlx.py +129 -36
- ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +61 -14
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +1 -1
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +10 -3
- ads/opctl/operator/lowcode/forecast/model/prophet.py +25 -18
- ads/opctl/operator/lowcode/forecast/operator_config.py +31 -0
- ads/opctl/operator/lowcode/forecast/schema.yaml +76 -0
- ads/opctl/operator/lowcode/forecast/utils.py +4 -3
- ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +233 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/score.py +238 -0
- ads/telemetry/base.py +18 -11
- ads/telemetry/client.py +33 -13
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/METADATA +9 -8
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/RECORD +74 -48
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/WHEEL +0 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/entry_points.txt +0 -0
ads/aqua/common/enums.py
CHANGED
@@ -52,6 +52,9 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
|
|
52
52
|
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
|
53
53
|
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
|
54
54
|
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
|
55
|
+
|
56
|
+
|
57
|
+
class CustomInferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
|
55
58
|
AQUA_TEI_CONTAINER_FAMILY = "odsc-tei-serving"
|
56
59
|
|
57
60
|
|
ads/aqua/common/utils.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
"""AQUA utils and constants."""
|
5
5
|
|
@@ -11,6 +11,7 @@ import os
|
|
11
11
|
import random
|
12
12
|
import re
|
13
13
|
import shlex
|
14
|
+
import shutil
|
14
15
|
import subprocess
|
15
16
|
from datetime import datetime, timedelta
|
16
17
|
from functools import wraps
|
@@ -21,6 +22,8 @@ from typing import List, Union
|
|
21
22
|
import fsspec
|
22
23
|
import oci
|
23
24
|
from cachetools import TTLCache, cached
|
25
|
+
from huggingface_hub.constants import HF_HUB_CACHE
|
26
|
+
from huggingface_hub.file_download import repo_folder_name
|
24
27
|
from huggingface_hub.hf_api import HfApi, ModelInfo
|
25
28
|
from huggingface_hub.utils import (
|
26
29
|
GatedRepoError,
|
@@ -30,6 +33,7 @@ from huggingface_hub.utils import (
|
|
30
33
|
)
|
31
34
|
from oci.data_science.models import JobRun, Model
|
32
35
|
from oci.object_storage.models import ObjectSummary
|
36
|
+
from pydantic import ValidationError
|
33
37
|
|
34
38
|
from ads.aqua.common.enums import (
|
35
39
|
InferenceContainerParamType,
|
@@ -788,7 +792,9 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
|
|
788
792
|
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
|
789
793
|
|
790
794
|
|
791
|
-
def upload_folder(
|
795
|
+
def upload_folder(
|
796
|
+
os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None
|
797
|
+
) -> str:
|
792
798
|
"""Upload the local folder to the object storage
|
793
799
|
|
794
800
|
Args:
|
@@ -818,6 +824,48 @@ def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern
|
|
818
824
|
return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path
|
819
825
|
|
820
826
|
|
827
|
+
def cleanup_local_hf_model_artifact(
    model_name: str,
    local_dir: str = None,
):
    """
    Helper function that deletes local artifacts downloaded from Hugging Face to free up disk space.

    Two locations are cleaned, both on a best-effort basis: failures are only
    logged at debug level and never raised.

    1. ``<local_dir>/<model_name>``, when ``local_dir`` is provided and exists.
       If ``model_name`` contains a path separator (``org/name``), the parent
       ``<local_dir>/<org>`` folder is removed as well, but only when it is
       empty, so that other models stored under the same organisation folder
       are not deleted as a side effect.
    2. The folder for this model inside the Hugging Face hub cache
       (``HF_HUB_CACHE``).

    Parameters
    ----------
    model_name (str): Name of the huggingface model
    local_dir (str): Local directory where the object is downloaded

    """
    if local_dir and os.path.exists(local_dir):
        model_dir = os.path.join(local_dir, model_name)
        shutil.rmtree(model_dir, ignore_errors=True)
        if os.path.exists(model_dir):
            logger.debug(
                f"Could not delete local model artifact directory: {model_dir}"
            )
        else:
            logger.debug(f"Deleted local model artifact directory: {model_dir}.")
            if "/" in model_name or os.sep in model_name:
                # Only drop the organisation folder when nothing else lives in
                # it; os.rmdir refuses to remove a non-empty directory.
                parent_dir = os.path.dirname(model_dir)
                try:
                    os.rmdir(parent_dir)
                except OSError:
                    logger.debug(
                        f"Kept parent directory {parent_dir}: it is missing or not empty."
                    )

    hf_local_path = os.path.join(
        HF_HUB_CACHE, repo_folder_name(repo_id=model_name, repo_type="model")
    )
    shutil.rmtree(hf_local_path, ignore_errors=True)

    if os.path.exists(hf_local_path):
        logger.debug(
            f"Could not clear the local Hugging Face cache directory {hf_local_path} for the model {model_name}."
        )
    else:
        logger.debug(
            f"Cleared contents of local Hugging Face cache directory {hf_local_path} for the model {model_name}."
        )
|
867
|
+
|
868
|
+
|
821
869
|
def is_service_managed_container(container):
|
822
870
|
return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
|
823
871
|
|
@@ -1159,3 +1207,15 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
|
|
1159
1207
|
|
1160
1208
|
combined_cmd_var = cmd_var + overrides
|
1161
1209
|
return combined_cmd_var
|
1210
|
+
|
1211
|
+
|
1212
|
+
def build_pydantic_error_message(ex: ValidationError):
    """Build a readable message from a pydantic validation error.

    Every error entry that carries a non-empty ``loc`` is reported as
    ``"<dotted.loc>": "<msg>"`` in a dict. When no entry has a usable ``loc``,
    the ``msg`` values of all entries are joined with ``"; "`` into one string.

    Returns
    -------
    dict or str
        A mapping of field location to message, or a single joined message.
    """
    messages_by_field = {}
    for detail in ex.errors():
        location = detail.get("loc")
        if location:
            field_path = ".".join(str(part) for part in location)
            messages_by_field[field_path] = detail["msg"]

    if messages_by_field:
        return messages_by_field

    # No field-level errors: fall back to the plain messages.
    return "; ".join(detail["msg"] for detail in ex.errors())
|
ads/aqua/data.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
#
|
3
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
|
6
|
-
from dataclasses import dataclass
|
5
|
+
from dataclasses import dataclass
|
7
6
|
|
8
7
|
from ads.common.serializer import DataClassSerializable
|
9
8
|
|
@@ -13,19 +12,3 @@ class AquaResourceIdentifier(DataClassSerializable):
|
|
13
12
|
id: str = ""
|
14
13
|
name: str = ""
|
15
14
|
url: str = ""
|
16
|
-
|
17
|
-
|
18
|
-
@dataclass(repr=False)
|
19
|
-
class AquaJobSummary(DataClassSerializable):
|
20
|
-
"""Represents an Aqua job summary."""
|
21
|
-
|
22
|
-
id: str
|
23
|
-
name: str
|
24
|
-
console_url: str
|
25
|
-
lifecycle_state: str
|
26
|
-
lifecycle_details: str
|
27
|
-
time_created: str
|
28
|
-
tags: dict
|
29
|
-
experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
|
30
|
-
source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
|
31
|
-
job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
|
@@ -1,5 +1,5 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
import base64
|
5
5
|
import json
|
@@ -199,11 +199,11 @@ class AquaEvaluationApp(AquaApp):
|
|
199
199
|
eval_inference_configuration = (
|
200
200
|
container.spec.evaluation_configuration
|
201
201
|
)
|
202
|
-
except Exception:
|
202
|
+
except Exception as ex:
|
203
203
|
logger.debug(
|
204
204
|
f"Could not load inference config details for the evaluation source id: "
|
205
205
|
f"{create_aqua_evaluation_details.evaluation_source_id}. Please check if the container"
|
206
|
-
f" runtime has the correct SMC image information
|
206
|
+
f" runtime has the correct SMC image information.\nError: {str(ex)}"
|
207
207
|
)
|
208
208
|
elif (
|
209
209
|
DataScienceResource.MODEL
|
@@ -289,7 +289,7 @@ class AquaEvaluationApp(AquaApp):
|
|
289
289
|
f"Invalid experiment name. Please provide an experiment with `{Tags.AQUA_EVALUATION}` in tags."
|
290
290
|
)
|
291
291
|
except Exception:
|
292
|
-
logger.
|
292
|
+
logger.info(
|
293
293
|
f"Model version set {experiment_model_version_set_name} doesn't exist. "
|
294
294
|
"Creating new model version set."
|
295
295
|
)
|
@@ -711,21 +711,27 @@ class AquaEvaluationApp(AquaApp):
|
|
711
711
|
try:
|
712
712
|
log = utils.query_resource(log_id, return_all=False)
|
713
713
|
log_name = log.display_name if log else ""
|
714
|
-
except Exception:
|
714
|
+
except Exception as ex:
|
715
|
+
logger.debug(f"Failed to get associated log name. Error: {ex}")
|
715
716
|
pass
|
716
717
|
|
717
718
|
if loggroup_id:
|
718
719
|
try:
|
719
720
|
loggroup = utils.query_resource(loggroup_id, return_all=False)
|
720
721
|
loggroup_name = loggroup.display_name if loggroup else ""
|
721
|
-
except Exception:
|
722
|
+
except Exception as ex:
|
723
|
+
logger.debug(f"Failed to get associated loggroup name. Error: {ex}")
|
722
724
|
pass
|
723
725
|
|
724
726
|
try:
|
725
727
|
introspection = json.loads(
|
726
728
|
self._get_attribute_from_model_metadata(resource, "ArtifactTestResults")
|
727
729
|
)
|
728
|
-
except Exception:
|
730
|
+
except Exception as ex:
|
731
|
+
logger.debug(
|
732
|
+
f"There was an issue loading the model attribute as json object for evaluation {eval_id}. "
|
733
|
+
f"Setting introspection to empty.\n Error:{ex}"
|
734
|
+
)
|
729
735
|
introspection = {}
|
730
736
|
|
731
737
|
summary = AquaEvaluationDetail(
|
@@ -878,13 +884,13 @@ class AquaEvaluationApp(AquaApp):
|
|
878
884
|
try:
|
879
885
|
log_id = job_run_details.log_details.log_id
|
880
886
|
except Exception as e:
|
881
|
-
logger.debug(f"Failed to get associated log
|
887
|
+
logger.debug(f"Failed to get associated log.\nError: {str(e)}")
|
882
888
|
log_id = ""
|
883
889
|
|
884
890
|
try:
|
885
891
|
loggroup_id = job_run_details.log_details.log_group_id
|
886
892
|
except Exception as e:
|
887
|
-
logger.debug(f"Failed to get associated log
|
893
|
+
logger.debug(f"Failed to get associated log.\nError: {str(e)}")
|
888
894
|
loggroup_id = ""
|
889
895
|
|
890
896
|
loggroup_url = get_log_links(region=self.region, log_group_id=loggroup_id)
|
@@ -958,7 +964,7 @@ class AquaEvaluationApp(AquaApp):
|
|
958
964
|
)
|
959
965
|
except Exception as e:
|
960
966
|
logger.debug(
|
961
|
-
"Failed to load `report.json` from evaluation artifact
|
967
|
+
f"Failed to load `report.json` from evaluation artifact.\nError: {str(e)}"
|
962
968
|
)
|
963
969
|
json_report = {}
|
964
970
|
|
@@ -1047,6 +1053,7 @@ class AquaEvaluationApp(AquaApp):
|
|
1047
1053
|
return report
|
1048
1054
|
|
1049
1055
|
with tempfile.TemporaryDirectory() as temp_dir:
|
1056
|
+
logger.info(f"Downloading evaluation artifact for {eval_id}.")
|
1050
1057
|
DataScienceModel.from_id(eval_id).download_artifact(
|
1051
1058
|
temp_dir,
|
1052
1059
|
auth=self._auth,
|
@@ -1200,6 +1207,7 @@ class AquaEvaluationApp(AquaApp):
|
|
1200
1207
|
def load_evaluation_config(self, container: Optional[str] = None) -> Dict:
|
1201
1208
|
"""Loads evaluation config."""
|
1202
1209
|
|
1210
|
+
logger.info("Loading evaluation container config.")
|
1203
1211
|
# retrieve the evaluation config by container family name
|
1204
1212
|
evaluation_config = get_evaluation_service_config(container)
|
1205
1213
|
|
@@ -1279,9 +1287,9 @@ class AquaEvaluationApp(AquaApp):
|
|
1279
1287
|
raise AquaRuntimeError(
|
1280
1288
|
f"Not supported source type: {resource_type}"
|
1281
1289
|
)
|
1282
|
-
except Exception:
|
1290
|
+
except Exception as ex:
|
1283
1291
|
logger.debug(
|
1284
|
-
f"Failed to retrieve source information for evaluation {evaluation.identifier}
|
1292
|
+
f"Failed to retrieve source information for evaluation {evaluation.identifier}.\nError: {str(ex)}"
|
1285
1293
|
)
|
1286
1294
|
source_name = ""
|
1287
1295
|
|
@@ -1,10 +1,10 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
3
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
6
|
import traceback
|
7
|
+
import uuid
|
8
8
|
from abc import abstractmethod
|
9
9
|
from http.client import responses
|
10
10
|
from typing import List
|
@@ -34,7 +34,7 @@ class AquaWSMsgHandler:
|
|
34
34
|
self.telemetry = TelemetryClient(
|
35
35
|
bucket=AQUA_TELEMETRY_BUCKET, namespace=AQUA_TELEMETRY_BUCKET_NS
|
36
36
|
)
|
37
|
-
except:
|
37
|
+
except Exception:
|
38
38
|
pass
|
39
39
|
|
40
40
|
@staticmethod
|
@@ -66,16 +66,23 @@ class AquaWSMsgHandler:
|
|
66
66
|
"message": message,
|
67
67
|
"service_payload": service_payload,
|
68
68
|
"reason": reason,
|
69
|
+
"request_id": str(uuid.uuid4()),
|
69
70
|
}
|
70
71
|
exc_info = kwargs.get("exc_info")
|
71
72
|
if exc_info:
|
72
|
-
logger.error(
|
73
|
+
logger.error(
|
74
|
+
f"Error Request ID: {reply['request_id']}\n"
|
75
|
+
f"Error: {''.join(traceback.format_exception(*exc_info))}"
|
76
|
+
)
|
73
77
|
e = exc_info[1]
|
74
78
|
if isinstance(e, HTTPError):
|
75
79
|
reply["message"] = e.log_message or message
|
76
80
|
reply["reason"] = e.reason
|
77
|
-
|
78
|
-
|
81
|
+
|
82
|
+
logger.error(
|
83
|
+
f"Error Request ID: {reply['request_id']}\n"
|
84
|
+
f"Error: {reply['message']} {reply['reason']}"
|
85
|
+
)
|
79
86
|
# telemetry may not be present if there is an error while initializing
|
80
87
|
if hasattr(self, "telemetry"):
|
81
88
|
aqua_api_details = kwargs.get("aqua_api_details", {})
|
@@ -83,7 +90,7 @@ class AquaWSMsgHandler:
|
|
83
90
|
category="aqua/error",
|
84
91
|
action=str(status_code),
|
85
92
|
value=reason,
|
86
|
-
**aqua_api_details
|
93
|
+
**aqua_api_details,
|
87
94
|
)
|
88
95
|
response = AquaWsError(
|
89
96
|
status=status_code,
|
@@ -1,6 +1,5 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
#
|
3
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
|
6
5
|
|
@@ -35,7 +34,7 @@ class AquaAPIhandler(APIHandler):
|
|
35
34
|
self.telemetry = TelemetryClient(
|
36
35
|
bucket=AQUA_TELEMETRY_BUCKET, namespace=AQUA_TELEMETRY_BUCKET_NS
|
37
36
|
)
|
38
|
-
except:
|
37
|
+
except Exception:
|
39
38
|
pass
|
40
39
|
|
41
40
|
@staticmethod
|
@@ -82,19 +81,23 @@ class AquaAPIhandler(APIHandler):
|
|
82
81
|
"message": message,
|
83
82
|
"service_payload": service_payload,
|
84
83
|
"reason": reason,
|
84
|
+
"request_id": str(uuid.uuid4()),
|
85
85
|
}
|
86
86
|
exc_info = kwargs.get("exc_info")
|
87
87
|
if exc_info:
|
88
|
-
logger.error(
|
88
|
+
logger.error(
|
89
|
+
f"Error Request ID: {reply['request_id']}\n"
|
90
|
+
f"Error: {''.join(traceback.format_exception(*exc_info))}"
|
91
|
+
)
|
89
92
|
e = exc_info[1]
|
90
93
|
if isinstance(e, HTTPError):
|
91
94
|
reply["message"] = e.log_message or message
|
92
95
|
reply["reason"] = e.reason if e.reason else reply["reason"]
|
93
|
-
reply["request_id"] = str(uuid.uuid4())
|
94
|
-
else:
|
95
|
-
reply["request_id"] = str(uuid.uuid4())
|
96
96
|
|
97
|
-
logger.
|
97
|
+
logger.error(
|
98
|
+
f"Error Request ID: {reply['request_id']}\n"
|
99
|
+
f"Error: {reply['message']} {reply['reason']}"
|
100
|
+
)
|
98
101
|
|
99
102
|
# telemetry may not be present if there is an error while initializing
|
100
103
|
if hasattr(self, "telemetry"):
|
@@ -103,7 +106,7 @@ class AquaAPIhandler(APIHandler):
|
|
103
106
|
category="aqua/error",
|
104
107
|
action=str(status_code),
|
105
108
|
value=reason,
|
106
|
-
**aqua_api_details
|
109
|
+
**aqua_api_details,
|
107
110
|
)
|
108
111
|
|
109
112
|
self.finish(json.dumps(reply))
|
@@ -1,5 +1,5 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
|
5
5
|
|
@@ -10,9 +10,7 @@ from tornado.web import HTTPError
|
|
10
10
|
from ads.aqua.common.decorator import handle_exceptions
|
11
11
|
from ads.aqua.extension.base_handler import AquaAPIhandler
|
12
12
|
from ads.aqua.extension.errors import Errors
|
13
|
-
from ads.aqua.extension.utils import validate_function_parameters
|
14
13
|
from ads.aqua.finetuning import AquaFineTuningApp
|
15
|
-
from ads.aqua.finetuning.entities import CreateFineTuningDetails
|
16
14
|
|
17
15
|
|
18
16
|
class AquaFineTuneHandler(AquaAPIhandler):
|
@@ -33,7 +31,7 @@ class AquaFineTuneHandler(AquaAPIhandler):
|
|
33
31
|
raise HTTPError(400, f"The request {self.request.path} is invalid.")
|
34
32
|
|
35
33
|
@handle_exceptions
|
36
|
-
def post(self, *args, **kwargs):
|
34
|
+
def post(self, *args, **kwargs): # noqa: ARG002
|
37
35
|
"""Handles post request for the fine-tuning API
|
38
36
|
|
39
37
|
Raises
|
@@ -43,17 +41,13 @@ class AquaFineTuneHandler(AquaAPIhandler):
|
|
43
41
|
"""
|
44
42
|
try:
|
45
43
|
input_data = self.get_json_body()
|
46
|
-
except Exception:
|
47
|
-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
|
44
|
+
except Exception as ex:
|
45
|
+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
|
48
46
|
|
49
47
|
if not input_data:
|
50
48
|
raise HTTPError(400, Errors.NO_INPUT_DATA)
|
51
49
|
|
52
|
-
|
53
|
-
data_class=CreateFineTuningDetails, input_data=input_data
|
54
|
-
)
|
55
|
-
|
56
|
-
self.finish(AquaFineTuningApp().create(CreateFineTuningDetails(**input_data)))
|
50
|
+
self.finish(AquaFineTuningApp().create(**input_data))
|
57
51
|
|
58
52
|
def get_finetuning_config(self, model_id):
|
59
53
|
"""Gets the finetuning config for Aqua model."""
|
@@ -71,7 +65,7 @@ class AquaFineTuneParamsHandler(AquaAPIhandler):
|
|
71
65
|
)
|
72
66
|
|
73
67
|
@handle_exceptions
|
74
|
-
def post(self, *args, **kwargs):
|
68
|
+
def post(self, *args, **kwargs): # noqa: ARG002
|
75
69
|
"""Handles post request for the finetuning param handler API.
|
76
70
|
|
77
71
|
Raises
|
@@ -81,8 +75,8 @@ class AquaFineTuneParamsHandler(AquaAPIhandler):
|
|
81
75
|
"""
|
82
76
|
try:
|
83
77
|
input_data = self.get_json_body()
|
84
|
-
except Exception:
|
85
|
-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
|
78
|
+
except Exception as ex:
|
79
|
+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
|
86
80
|
|
87
81
|
if not input_data:
|
88
82
|
raise HTTPError(400, Errors.NO_INPUT_DATA)
|
@@ -1,5 +1,5 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
|
5
5
|
from typing import Optional
|
@@ -8,6 +8,9 @@ from urllib.parse import urlparse
|
|
8
8
|
from tornado.web import HTTPError
|
9
9
|
|
10
10
|
from ads.aqua.common.decorator import handle_exceptions
|
11
|
+
from ads.aqua.common.enums import (
|
12
|
+
CustomInferenceContainerTypeFamily,
|
13
|
+
)
|
11
14
|
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
|
12
15
|
from ads.aqua.common.utils import (
|
13
16
|
get_hf_model_info,
|
@@ -128,17 +131,27 @@ class AquaModelHandler(AquaAPIhandler):
|
|
128
131
|
download_from_hf = (
|
129
132
|
str(input_data.get("download_from_hf", "false")).lower() == "true"
|
130
133
|
)
|
134
|
+
local_dir = input_data.get("local_dir")
|
135
|
+
cleanup_model_cache = (
|
136
|
+
str(input_data.get("cleanup_model_cache", "false")).lower() == "true"
|
137
|
+
)
|
131
138
|
inference_container_uri = input_data.get("inference_container_uri")
|
132
139
|
allow_patterns = input_data.get("allow_patterns")
|
133
140
|
ignore_patterns = input_data.get("ignore_patterns")
|
134
141
|
freeform_tags = input_data.get("freeform_tags")
|
135
142
|
defined_tags = input_data.get("defined_tags")
|
143
|
+
ignore_model_artifact_check = (
|
144
|
+
str(input_data.get("ignore_model_artifact_check", "false")).lower()
|
145
|
+
== "true"
|
146
|
+
)
|
136
147
|
|
137
148
|
return self.finish(
|
138
149
|
AquaModelApp().register(
|
139
150
|
model=model,
|
140
151
|
os_path=os_path,
|
141
152
|
download_from_hf=download_from_hf,
|
153
|
+
local_dir=local_dir,
|
154
|
+
cleanup_model_cache=cleanup_model_cache,
|
142
155
|
inference_container=inference_container,
|
143
156
|
finetuning_container=finetuning_container,
|
144
157
|
compartment_id=compartment_id,
|
@@ -149,6 +162,7 @@ class AquaModelHandler(AquaAPIhandler):
|
|
149
162
|
ignore_patterns=ignore_patterns,
|
150
163
|
freeform_tags=freeform_tags,
|
151
164
|
defined_tags=defined_tags,
|
165
|
+
ignore_model_artifact_check=ignore_model_artifact_check,
|
152
166
|
)
|
153
167
|
)
|
154
168
|
|
@@ -163,7 +177,9 @@ class AquaModelHandler(AquaAPIhandler):
|
|
163
177
|
raise HTTPError(400, Errors.NO_INPUT_DATA)
|
164
178
|
|
165
179
|
inference_container = input_data.get("inference_container")
|
180
|
+
inference_container_uri = input_data.get("inference_container_uri")
|
166
181
|
inference_containers = AquaModelApp.list_valid_inference_containers()
|
182
|
+
inference_containers.extend(CustomInferenceContainerTypeFamily.values())
|
167
183
|
if (
|
168
184
|
inference_container is not None
|
169
185
|
and inference_container not in inference_containers
|
@@ -176,7 +192,13 @@ class AquaModelHandler(AquaAPIhandler):
|
|
176
192
|
task = input_data.get("task")
|
177
193
|
app = AquaModelApp()
|
178
194
|
self.finish(
|
179
|
-
app.edit_registered_model(
|
195
|
+
app.edit_registered_model(
|
196
|
+
id,
|
197
|
+
inference_container,
|
198
|
+
inference_container_uri,
|
199
|
+
enable_finetuning,
|
200
|
+
task,
|
201
|
+
)
|
180
202
|
)
|
181
203
|
app.clear_model_details_cache(model_id=id)
|
182
204
|
|
ads/aqua/finetuning/constants.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
#
|
3
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
|
6
5
|
from ads.common.extended_enum import ExtendedEnumMeta
|
@@ -17,4 +16,8 @@ class FineTuneCustomMetadata(str, metaclass=ExtendedEnumMeta):
|
|
17
16
|
SERVICE_MODEL_FINE_TUNE_CONTAINER = "finetune-container"
|
18
17
|
|
19
18
|
|
19
|
+
class FineTuningRestrictedParams(str, metaclass=ExtendedEnumMeta):
|
20
|
+
OPTIMIZER = "optimizer"
|
21
|
+
|
22
|
+
|
20
23
|
ENV_AQUA_FINE_TUNING_CONTAINER = "AQUA_FINE_TUNING_CONTAINER"
|
ads/aqua/finetuning/entities.py
CHANGED
@@ -1,18 +1,24 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
|
-
from dataclasses import dataclass, field
|
5
|
-
from typing import List, Optional
|
6
4
|
|
7
|
-
|
8
|
-
from
|
5
|
+
import json
|
6
|
+
from typing import List, Literal, Optional, Union
|
9
7
|
|
8
|
+
from pydantic import Field, model_validator
|
10
9
|
|
11
|
-
|
12
|
-
|
13
|
-
|
10
|
+
from ads.aqua.common.errors import AquaValueError
|
11
|
+
from ads.aqua.config.utils.serializer import Serializable
|
12
|
+
from ads.aqua.data import AquaResourceIdentifier
|
13
|
+
from ads.aqua.finetuning.constants import FineTuningRestrictedParams
|
14
|
+
|
15
|
+
|
16
|
+
class AquaFineTuningParams(Serializable):
|
17
|
+
"""Class for maintaining aqua fine-tuning model parameters"""
|
18
|
+
|
19
|
+
epochs: Optional[int] = None
|
14
20
|
learning_rate: Optional[float] = None
|
15
|
-
sample_packing:
|
21
|
+
sample_packing: Union[bool, None, Literal["auto"]] = "auto"
|
16
22
|
batch_size: Optional[int] = (
|
17
23
|
None # make it batch_size for user, but internally this is micro_batch_size
|
18
24
|
)
|
@@ -22,21 +28,59 @@ class AquaFineTuningParams(DataClassSerializable):
|
|
22
28
|
lora_alpha: Optional[int] = None
|
23
29
|
lora_dropout: Optional[float] = None
|
24
30
|
lora_target_linear: Optional[bool] = None
|
25
|
-
lora_target_modules: Optional[List] = None
|
31
|
+
lora_target_modules: Optional[List[str]] = None
|
26
32
|
early_stopping_patience: Optional[int] = None
|
27
33
|
early_stopping_threshold: Optional[float] = None
|
28
34
|
|
35
|
+
class Config:
|
36
|
+
extra = "allow"
|
37
|
+
|
38
|
+
def to_dict(self) -> dict:
    """Return the parameters as a plain dict.

    The object is first serialized through the parent ``to_json`` with
    ``exclude_none=True`` and the resulting JSON text is parsed back.
    """
    serialized = super().to_json(exclude_none=True)
    return json.loads(serialized)
|
40
|
+
|
41
|
+
@model_validator(mode="before")
@classmethod
def validate_restricted_fields(cls, data: dict):
    """Reject user-supplied parameters whose names are restricted.

    Runs before field validation. A ``_validate`` key in the input controls
    the check and is always stripped from the returned data, so it never
    becomes a model field.

    Parameters
    ----------
    data:
        Raw input handed to the model. Anything that is not a dict is
        returned untouched.

    Returns
    -------
    The input data without the ``_validate`` key.

    Raises
    ------
    AquaValueError
        If validation is enabled and any key is listed in
        ``FineTuningRestrictedParams``.
    """
    # A "before" validator receives the raw input, which is not guaranteed to
    # be a dict; check the type before calling dict methods on it.
    if not isinstance(data, dict):
        return data

    # Work on a shallow copy so the pop below does not mutate the caller's dict.
    data = dict(data)

    # we may want to skip validation if loading data from config files instead of user entered parameters
    validate = data.pop("_validate", True)
    if not validate:
        return data

    restricted_names = FineTuningRestrictedParams.values()
    restricted_params = [param for param in data if param in restricted_names]
    if restricted_params:
        raise AquaValueError(
            f"Found restricted parameter name: {restricted_params}"
        )
    return data
|
29
56
|
|
30
|
-
@dataclass(repr=False)
|
31
|
-
class AquaFineTuningSummary(AquaJobSummary, DataClassSerializable):
|
32
|
-
parameters: AquaFineTuningParams = field(default_factory=AquaFineTuningParams)
|
33
57
|
|
58
|
+
class AquaFineTuningSummary(Serializable):
    """Summary record describing a single Aqua fine-tuning job."""

    # Required descriptive fields of the job.
    id: str
    name: str
    console_url: str
    lifecycle_state: str
    lifecycle_details: str
    time_created: str
    tags: dict

    # Related resources; each defaults to a freshly built AquaResourceIdentifier.
    experiment: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
    source: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
    job: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)

    # Fine-tuning parameters; defaults to a freshly built AquaFineTuningParams.
    parameters: AquaFineTuningParams = Field(default_factory=AquaFineTuningParams)

    class Config:
        extra = "ignore"

    def to_dict(self) -> dict:
        """Return the summary as a plain dict.

        The object is first serialized through the parent ``to_json`` with
        ``exclude_none=True`` and the resulting JSON text is parsed back.
        """
        serialized = super().to_json(exclude_none=True)
        return json.loads(serialized)
|
78
|
+
|
79
|
+
|
80
|
+
class CreateFineTuningDetails(Serializable):
|
81
|
+
"""Class to create aqua model fine-tuning instance.
|
82
|
+
|
83
|
+
Properties
|
40
84
|
------
|
41
85
|
ft_source_id: str
|
42
86
|
The fine tuning source id. Must be model ocid.
|
@@ -78,6 +122,8 @@ class CreateFineTuningDetails(DataClassSerializable):
|
|
78
122
|
The log group id for fine tuning job infrastructure.
|
79
123
|
log_id: (str, optional). Defaults to `None`.
|
80
124
|
The log id for fine tuning job infrastructure.
|
125
|
+
watch_logs: (bool, optional). Defaults to `False`.
|
126
|
+
The flag to watch the job run logs when a fine-tuning job is created.
|
81
127
|
force_overwrite: (bool, optional). Defaults to `False`.
|
82
128
|
Whether to force overwrite the existing file in object storage.
|
83
129
|
freeform_tags: (dict, optional)
|
@@ -104,6 +150,10 @@ class CreateFineTuningDetails(DataClassSerializable):
|
|
104
150
|
subnet_id: Optional[str] = None
|
105
151
|
log_id: Optional[str] = None
|
106
152
|
log_group_id: Optional[str] = None
|
153
|
+
watch_logs: Optional[bool] = False
|
107
154
|
force_overwrite: Optional[bool] = False
|
108
155
|
freeform_tags: Optional[dict] = None
|
109
156
|
defined_tags: Optional[dict] = None
|
157
|
+
|
158
|
+
class Config:
|
159
|
+
extra = "ignore"
|