oracle-ads 2.12.9__py3-none-any.whl → 2.12.10rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +4 -4
- ads/aqua/common/enums.py +3 -0
- ads/aqua/common/utils.py +62 -2
- ads/aqua/data.py +2 -19
- ads/aqua/extension/finetune_handler.py +8 -14
- ads/aqua/extension/model_handler.py +19 -2
- ads/aqua/finetuning/constants.py +5 -2
- ads/aqua/finetuning/entities.py +64 -17
- ads/aqua/finetuning/finetuning.py +38 -54
- ads/aqua/model/entities.py +2 -1
- ads/aqua/model/model.py +61 -23
- ads/common/auth.py +9 -9
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/{client_v02.py → v02/client.py} +23 -10
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +12 -11
- ads/model/__init__.py +11 -13
- ads/model/artifact.py +47 -8
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/generic_model.py +26 -24
- ads/model/model_metadata.py +8 -7
- ads/opctl/config/merger.py +13 -14
- ads/opctl/operator/common/operator_config.py +4 -4
- ads/opctl/operator/lowcode/common/transformations.py +12 -5
- ads/opctl/operator/lowcode/common/utils.py +11 -5
- ads/opctl/operator/lowcode/forecast/const.py +2 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +19 -13
- ads/opctl/operator/lowcode/forecast/model/automlx.py +129 -36
- ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +61 -14
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +10 -3
- ads/opctl/operator/lowcode/forecast/model/prophet.py +25 -18
- ads/opctl/operator/lowcode/forecast/schema.yaml +13 -0
- ads/opctl/operator/lowcode/forecast/utils.py +4 -3
- ads/telemetry/base.py +18 -11
- ads/telemetry/client.py +33 -13
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/METADATA +7 -8
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/RECORD +60 -39
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/WHEEL +0 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/entry_points.txt +0 -0
ads/aqua/__init__.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
#
|
3
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
|
6
5
|
|
7
6
|
import os
|
7
|
+
from logging import getLogger
|
8
8
|
|
9
|
-
from ads import logger, set_auth
|
9
|
+
from ads import set_auth
|
10
10
|
from ads.aqua.common.utils import fetch_service_compartment
|
11
11
|
from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION
|
12
12
|
|
@@ -19,6 +19,7 @@ def get_logger_level():
|
|
19
19
|
return level
|
20
20
|
|
21
21
|
|
22
|
+
logger = getLogger(__name__)
|
22
23
|
logger.setLevel(get_logger_level())
|
23
24
|
|
24
25
|
|
@@ -27,7 +28,6 @@ def set_log_level(log_level: str):
|
|
27
28
|
|
28
29
|
log_level = log_level.upper()
|
29
30
|
logger.setLevel(log_level.upper())
|
30
|
-
logger.handlers[0].setLevel(log_level)
|
31
31
|
|
32
32
|
|
33
33
|
if OCI_RESOURCE_PRINCIPAL_VERSION:
|
ads/aqua/common/enums.py
CHANGED
@@ -52,6 +52,9 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
|
|
52
52
|
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
|
53
53
|
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
|
54
54
|
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
|
55
|
+
|
56
|
+
|
57
|
+
class CustomInferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
|
55
58
|
AQUA_TEI_CONTAINER_FAMILY = "odsc-tei-serving"
|
56
59
|
|
57
60
|
|
ads/aqua/common/utils.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
"""AQUA utils and constants."""
|
5
5
|
|
@@ -11,6 +11,7 @@ import os
|
|
11
11
|
import random
|
12
12
|
import re
|
13
13
|
import shlex
|
14
|
+
import shutil
|
14
15
|
import subprocess
|
15
16
|
from datetime import datetime, timedelta
|
16
17
|
from functools import wraps
|
@@ -21,6 +22,8 @@ from typing import List, Union
|
|
21
22
|
import fsspec
|
22
23
|
import oci
|
23
24
|
from cachetools import TTLCache, cached
|
25
|
+
from huggingface_hub.constants import HF_HUB_CACHE
|
26
|
+
from huggingface_hub.file_download import repo_folder_name
|
24
27
|
from huggingface_hub.hf_api import HfApi, ModelInfo
|
25
28
|
from huggingface_hub.utils import (
|
26
29
|
GatedRepoError,
|
@@ -30,6 +33,7 @@ from huggingface_hub.utils import (
|
|
30
33
|
)
|
31
34
|
from oci.data_science.models import JobRun, Model
|
32
35
|
from oci.object_storage.models import ObjectSummary
|
36
|
+
from pydantic import ValidationError
|
33
37
|
|
34
38
|
from ads.aqua.common.enums import (
|
35
39
|
InferenceContainerParamType,
|
@@ -788,7 +792,9 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
|
|
788
792
|
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
|
789
793
|
|
790
794
|
|
791
|
-
def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str:
|
795
|
+
def upload_folder(
|
796
|
+
os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None
|
797
|
+
) -> str:
|
792
798
|
"""Upload the local folder to the object storage
|
793
799
|
|
794
800
|
Args:
|
@@ -818,6 +824,48 @@ def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern
|
|
818
824
|
return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path
|
819
825
|
|
820
826
|
|
827
|
+
def cleanup_local_hf_model_artifact(
|
828
|
+
model_name: str,
|
829
|
+
local_dir: str = None,
|
830
|
+
):
|
831
|
+
"""
|
832
|
+
Helper function that deletes local artifacts downloaded from Hugging Face to free up disk space.
|
833
|
+
Parameters
|
834
|
+
----------
|
835
|
+
model_name (str): Name of the huggingface model
|
836
|
+
local_dir (str): Local directory where the object is downloaded
|
837
|
+
|
838
|
+
"""
|
839
|
+
if local_dir and os.path.exists(local_dir):
|
840
|
+
model_dir = os.path.join(local_dir, model_name)
|
841
|
+
model_dir = (
|
842
|
+
os.path.dirname(model_dir)
|
843
|
+
if "/" in model_name or os.sep in model_name
|
844
|
+
else model_dir
|
845
|
+
)
|
846
|
+
shutil.rmtree(model_dir, ignore_errors=True)
|
847
|
+
if os.path.exists(model_dir):
|
848
|
+
logger.debug(
|
849
|
+
f"Could not delete local model artifact directory: {model_dir}"
|
850
|
+
)
|
851
|
+
else:
|
852
|
+
logger.debug(f"Deleted local model artifact directory: {model_dir}.")
|
853
|
+
|
854
|
+
hf_local_path = os.path.join(
|
855
|
+
HF_HUB_CACHE, repo_folder_name(repo_id=model_name, repo_type="model")
|
856
|
+
)
|
857
|
+
shutil.rmtree(hf_local_path, ignore_errors=True)
|
858
|
+
|
859
|
+
if os.path.exists(hf_local_path):
|
860
|
+
logger.debug(
|
861
|
+
f"Could not clear the local Hugging Face cache directory {hf_local_path} for the model {model_name}."
|
862
|
+
)
|
863
|
+
else:
|
864
|
+
logger.debug(
|
865
|
+
f"Cleared contents of local Hugging Face cache directory {hf_local_path} for the model {model_name}."
|
866
|
+
)
|
867
|
+
|
868
|
+
|
821
869
|
def is_service_managed_container(container):
|
822
870
|
return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
|
823
871
|
|
@@ -1159,3 +1207,15 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
|
|
1159
1207
|
|
1160
1208
|
combined_cmd_var = cmd_var + overrides
|
1161
1209
|
return combined_cmd_var
|
1210
|
+
|
1211
|
+
|
1212
|
+
def build_pydantic_error_message(ex: ValidationError):
|
1213
|
+
"""Added to handle error messages from pydantic model validator.
|
1214
|
+
Combine both loc and msg for errors where loc (field) is present in error details, else only build error
|
1215
|
+
message using msg field."""
|
1216
|
+
|
1217
|
+
return {
|
1218
|
+
".".join(map(str, e["loc"])): e["msg"]
|
1219
|
+
for e in ex.errors()
|
1220
|
+
if "loc" in e and e["loc"]
|
1221
|
+
} or "; ".join(e["msg"] for e in ex.errors())
|
ads/aqua/data.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
#
|
3
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
|
6
|
-
from dataclasses import dataclass
|
5
|
+
from dataclasses import dataclass
|
7
6
|
|
8
7
|
from ads.common.serializer import DataClassSerializable
|
9
8
|
|
@@ -13,19 +12,3 @@ class AquaResourceIdentifier(DataClassSerializable):
|
|
13
12
|
id: str = ""
|
14
13
|
name: str = ""
|
15
14
|
url: str = ""
|
16
|
-
|
17
|
-
|
18
|
-
@dataclass(repr=False)
|
19
|
-
class AquaJobSummary(DataClassSerializable):
|
20
|
-
"""Represents an Aqua job summary."""
|
21
|
-
|
22
|
-
id: str
|
23
|
-
name: str
|
24
|
-
console_url: str
|
25
|
-
lifecycle_state: str
|
26
|
-
lifecycle_details: str
|
27
|
-
time_created: str
|
28
|
-
tags: dict
|
29
|
-
experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
|
30
|
-
source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
|
31
|
-
job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
|
ads/aqua/extension/finetune_handler.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
|
5
5
|
|
@@ -10,9 +10,7 @@ from tornado.web import HTTPError
|
|
10
10
|
from ads.aqua.common.decorator import handle_exceptions
|
11
11
|
from ads.aqua.extension.base_handler import AquaAPIhandler
|
12
12
|
from ads.aqua.extension.errors import Errors
|
13
|
-
from ads.aqua.extension.utils import validate_function_parameters
|
14
13
|
from ads.aqua.finetuning import AquaFineTuningApp
|
15
|
-
from ads.aqua.finetuning.entities import CreateFineTuningDetails
|
16
14
|
|
17
15
|
|
18
16
|
class AquaFineTuneHandler(AquaAPIhandler):
|
@@ -33,7 +31,7 @@ class AquaFineTuneHandler(AquaAPIhandler):
|
|
33
31
|
raise HTTPError(400, f"The request {self.request.path} is invalid.")
|
34
32
|
|
35
33
|
@handle_exceptions
|
36
|
-
def post(self, *args, **kwargs):
|
34
|
+
def post(self, *args, **kwargs): # noqa: ARG002
|
37
35
|
"""Handles post request for the fine-tuning API
|
38
36
|
|
39
37
|
Raises
|
@@ -43,17 +41,13 @@ class AquaFineTuneHandler(AquaAPIhandler):
|
|
43
41
|
"""
|
44
42
|
try:
|
45
43
|
input_data = self.get_json_body()
|
46
|
-
except Exception:
|
47
|
-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
|
44
|
+
except Exception as ex:
|
45
|
+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
|
48
46
|
|
49
47
|
if not input_data:
|
50
48
|
raise HTTPError(400, Errors.NO_INPUT_DATA)
|
51
49
|
|
52
|
-
|
53
|
-
data_class=CreateFineTuningDetails, input_data=input_data
|
54
|
-
)
|
55
|
-
|
56
|
-
self.finish(AquaFineTuningApp().create(CreateFineTuningDetails(**input_data)))
|
50
|
+
self.finish(AquaFineTuningApp().create(**input_data))
|
57
51
|
|
58
52
|
def get_finetuning_config(self, model_id):
|
59
53
|
"""Gets the finetuning config for Aqua model."""
|
@@ -71,7 +65,7 @@ class AquaFineTuneParamsHandler(AquaAPIhandler):
|
|
71
65
|
)
|
72
66
|
|
73
67
|
@handle_exceptions
|
74
|
-
def post(self, *args, **kwargs):
|
68
|
+
def post(self, *args, **kwargs): # noqa: ARG002
|
75
69
|
"""Handles post request for the finetuning param handler API.
|
76
70
|
|
77
71
|
Raises
|
@@ -81,8 +75,8 @@ class AquaFineTuneParamsHandler(AquaAPIhandler):
|
|
81
75
|
"""
|
82
76
|
try:
|
83
77
|
input_data = self.get_json_body()
|
84
|
-
except Exception:
|
85
|
-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
|
78
|
+
except Exception as ex:
|
79
|
+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
|
86
80
|
|
87
81
|
if not input_data:
|
88
82
|
raise HTTPError(400, Errors.NO_INPUT_DATA)
|
ads/aqua/extension/model_handler.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
|
5
5
|
from typing import Optional
|
@@ -8,6 +8,9 @@ from urllib.parse import urlparse
|
|
8
8
|
from tornado.web import HTTPError
|
9
9
|
|
10
10
|
from ads.aqua.common.decorator import handle_exceptions
|
11
|
+
from ads.aqua.common.enums import (
|
12
|
+
CustomInferenceContainerTypeFamily,
|
13
|
+
)
|
11
14
|
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
|
12
15
|
from ads.aqua.common.utils import (
|
13
16
|
get_hf_model_info,
|
@@ -128,6 +131,10 @@ class AquaModelHandler(AquaAPIhandler):
|
|
128
131
|
download_from_hf = (
|
129
132
|
str(input_data.get("download_from_hf", "false")).lower() == "true"
|
130
133
|
)
|
134
|
+
local_dir = input_data.get("local_dir")
|
135
|
+
cleanup_model_cache = (
|
136
|
+
str(input_data.get("cleanup_model_cache", "true")).lower() == "true"
|
137
|
+
)
|
131
138
|
inference_container_uri = input_data.get("inference_container_uri")
|
132
139
|
allow_patterns = input_data.get("allow_patterns")
|
133
140
|
ignore_patterns = input_data.get("ignore_patterns")
|
@@ -139,6 +146,8 @@ class AquaModelHandler(AquaAPIhandler):
|
|
139
146
|
model=model,
|
140
147
|
os_path=os_path,
|
141
148
|
download_from_hf=download_from_hf,
|
149
|
+
local_dir=local_dir,
|
150
|
+
cleanup_model_cache=cleanup_model_cache,
|
142
151
|
inference_container=inference_container,
|
143
152
|
finetuning_container=finetuning_container,
|
144
153
|
compartment_id=compartment_id,
|
@@ -163,7 +172,9 @@ class AquaModelHandler(AquaAPIhandler):
|
|
163
172
|
raise HTTPError(400, Errors.NO_INPUT_DATA)
|
164
173
|
|
165
174
|
inference_container = input_data.get("inference_container")
|
175
|
+
inference_container_uri = input_data.get("inference_container_uri")
|
166
176
|
inference_containers = AquaModelApp.list_valid_inference_containers()
|
177
|
+
inference_containers.extend(CustomInferenceContainerTypeFamily.values())
|
167
178
|
if (
|
168
179
|
inference_container is not None
|
169
180
|
and inference_container not in inference_containers
|
@@ -176,7 +187,13 @@ class AquaModelHandler(AquaAPIhandler):
|
|
176
187
|
task = input_data.get("task")
|
177
188
|
app = AquaModelApp()
|
178
189
|
self.finish(
|
179
|
-
app.edit_registered_model(
|
190
|
+
app.edit_registered_model(
|
191
|
+
id,
|
192
|
+
inference_container,
|
193
|
+
inference_container_uri,
|
194
|
+
enable_finetuning,
|
195
|
+
task,
|
196
|
+
)
|
180
197
|
)
|
181
198
|
app.clear_model_details_cache(model_id=id)
|
182
199
|
|
ads/aqua/finetuning/constants.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
#
|
3
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
|
6
5
|
from ads.common.extended_enum import ExtendedEnumMeta
|
@@ -17,4 +16,8 @@ class FineTuneCustomMetadata(str, metaclass=ExtendedEnumMeta):
|
|
17
16
|
SERVICE_MODEL_FINE_TUNE_CONTAINER = "finetune-container"
|
18
17
|
|
19
18
|
|
19
|
+
class FineTuningRestrictedParams(str, metaclass=ExtendedEnumMeta):
|
20
|
+
OPTIMIZER = "optimizer"
|
21
|
+
|
22
|
+
|
20
23
|
ENV_AQUA_FINE_TUNING_CONTAINER = "AQUA_FINE_TUNING_CONTAINER"
|
ads/aqua/finetuning/entities.py
CHANGED
@@ -1,18 +1,24 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
|
-
from dataclasses import dataclass, field
|
5
|
-
from typing import List, Optional
|
6
4
|
|
7
|
-
|
8
|
-
from
|
5
|
+
import json
|
6
|
+
from typing import List, Literal, Optional, Union
|
9
7
|
|
8
|
+
from pydantic import Field, model_validator
|
10
9
|
|
11
|
-
|
12
|
-
|
13
|
-
|
10
|
+
from ads.aqua.common.errors import AquaValueError
|
11
|
+
from ads.aqua.config.utils.serializer import Serializable
|
12
|
+
from ads.aqua.data import AquaResourceIdentifier
|
13
|
+
from ads.aqua.finetuning.constants import FineTuningRestrictedParams
|
14
|
+
|
15
|
+
|
16
|
+
class AquaFineTuningParams(Serializable):
|
17
|
+
"""Class for maintaining aqua fine-tuning model parameters"""
|
18
|
+
|
19
|
+
epochs: Optional[int] = None
|
14
20
|
learning_rate: Optional[float] = None
|
15
|
-
sample_packing:
|
21
|
+
sample_packing: Union[bool, None, Literal["auto"]] = "auto"
|
16
22
|
batch_size: Optional[int] = (
|
17
23
|
None # make it batch_size for user, but internally this is micro_batch_size
|
18
24
|
)
|
@@ -22,21 +28,59 @@ class AquaFineTuningParams(DataClassSerializable):
|
|
22
28
|
lora_alpha: Optional[int] = None
|
23
29
|
lora_dropout: Optional[float] = None
|
24
30
|
lora_target_linear: Optional[bool] = None
|
25
|
-
lora_target_modules: Optional[List] = None
|
31
|
+
lora_target_modules: Optional[List[str]] = None
|
26
32
|
early_stopping_patience: Optional[int] = None
|
27
33
|
early_stopping_threshold: Optional[float] = None
|
28
34
|
|
35
|
+
class Config:
|
36
|
+
extra = "allow"
|
37
|
+
|
38
|
+
def to_dict(self) -> dict:
|
39
|
+
return json.loads(super().to_json(exclude_none=True))
|
40
|
+
|
41
|
+
@model_validator(mode="before")
|
42
|
+
@classmethod
|
43
|
+
def validate_restricted_fields(cls, data: dict):
|
44
|
+
# we may want to skip validation if loading data from config files instead of user entered parameters
|
45
|
+
validate = data.pop("_validate", True)
|
46
|
+
if not (validate and isinstance(data, dict)):
|
47
|
+
return data
|
48
|
+
restricted_params = [
|
49
|
+
param for param in data if param in FineTuningRestrictedParams.values()
|
50
|
+
]
|
51
|
+
if restricted_params:
|
52
|
+
raise AquaValueError(
|
53
|
+
f"Found restricted parameter name: {restricted_params}"
|
54
|
+
)
|
55
|
+
return data
|
29
56
|
|
30
|
-
@dataclass(repr=False)
|
31
|
-
class AquaFineTuningSummary(AquaJobSummary, DataClassSerializable):
|
32
|
-
parameters: AquaFineTuningParams = field(default_factory=AquaFineTuningParams)
|
33
57
|
|
58
|
+
class AquaFineTuningSummary(Serializable):
|
59
|
+
"""Represents a summary of Aqua Finetuning job."""
|
34
60
|
|
35
|
-
|
36
|
-
|
37
|
-
|
61
|
+
id: str
|
62
|
+
name: str
|
63
|
+
console_url: str
|
64
|
+
lifecycle_state: str
|
65
|
+
lifecycle_details: str
|
66
|
+
time_created: str
|
67
|
+
tags: dict
|
68
|
+
experiment: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
|
69
|
+
source: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
|
70
|
+
job: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
|
71
|
+
parameters: AquaFineTuningParams = Field(default_factory=AquaFineTuningParams)
|
38
72
|
|
39
|
-
|
73
|
+
class Config:
|
74
|
+
extra = "ignore"
|
75
|
+
|
76
|
+
def to_dict(self) -> dict:
|
77
|
+
return json.loads(super().to_json(exclude_none=True))
|
78
|
+
|
79
|
+
|
80
|
+
class CreateFineTuningDetails(Serializable):
|
81
|
+
"""Class to create aqua model fine-tuning instance.
|
82
|
+
|
83
|
+
Properties
|
40
84
|
------
|
41
85
|
ft_source_id: str
|
42
86
|
The fine tuning source id. Must be model ocid.
|
@@ -107,3 +151,6 @@ class CreateFineTuningDetails(DataClassSerializable):
|
|
107
151
|
force_overwrite: Optional[bool] = False
|
108
152
|
freeform_tags: Optional[dict] = None
|
109
153
|
defined_tags: Optional[dict] = None
|
154
|
+
|
155
|
+
class Config:
|
156
|
+
extra = "ignore"
|
ads/aqua/finetuning/finetuning.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
|
5
5
|
import json
|
6
6
|
import os
|
7
|
-
from dataclasses import MISSING, asdict, fields
|
8
7
|
from typing import Dict
|
9
8
|
|
10
9
|
from oci.data_science.models import (
|
@@ -12,12 +11,14 @@ from oci.data_science.models import (
|
|
12
11
|
UpdateModelDetails,
|
13
12
|
UpdateModelProvenanceDetails,
|
14
13
|
)
|
14
|
+
from pydantic import ValidationError
|
15
15
|
|
16
16
|
from ads.aqua import logger
|
17
17
|
from ads.aqua.app import AquaApp
|
18
18
|
from ads.aqua.common.enums import Resource, Tags
|
19
19
|
from ads.aqua.common.errors import AquaFileExistsError, AquaValueError
|
20
20
|
from ads.aqua.common.utils import (
|
21
|
+
build_pydantic_error_message,
|
21
22
|
get_container_image,
|
22
23
|
upload_local_to_os,
|
23
24
|
)
|
@@ -104,24 +105,12 @@ class AquaFineTuningApp(AquaApp):
|
|
104
105
|
if not create_fine_tuning_details:
|
105
106
|
try:
|
106
107
|
create_fine_tuning_details = CreateFineTuningDetails(**kwargs)
|
107
|
-
except
|
108
|
-
|
109
|
-
field.name for field in fields(CreateFineTuningDetails)
|
110
|
-
).rstrip()
|
108
|
+
except ValidationError as ex:
|
109
|
+
custom_errors = build_pydantic_error_message(ex)
|
111
110
|
raise AquaValueError(
|
112
|
-
"Invalid
|
113
|
-
f"{allowed_create_fine_tuning_details}."
|
111
|
+
f"Invalid parameters for creating a fine-tuned model. Error details: {custom_errors}."
|
114
112
|
) from ex
|
115
113
|
|
116
|
-
source = self.get_source(create_fine_tuning_details.ft_source_id)
|
117
|
-
|
118
|
-
# todo: revisit validation for fine tuned models
|
119
|
-
# if source.compartment_id != ODSC_MODEL_COMPARTMENT_OCID:
|
120
|
-
# raise AquaValueError(
|
121
|
-
# f"Fine tuning is only supported for Aqua service models in {ODSC_MODEL_COMPARTMENT_OCID}. "
|
122
|
-
# "Use a valid Aqua service model id instead."
|
123
|
-
# )
|
124
|
-
|
125
114
|
target_compartment = (
|
126
115
|
create_fine_tuning_details.compartment_id or COMPARTMENT_OCID
|
127
116
|
)
|
@@ -160,19 +149,9 @@ class AquaFineTuningApp(AquaApp):
|
|
160
149
|
f"Logging is required for fine tuning if replica is larger than {DEFAULT_FT_REPLICA}."
|
161
150
|
)
|
162
151
|
|
163
|
-
ft_parameters =
|
164
|
-
|
165
|
-
|
166
|
-
**create_fine_tuning_details.ft_parameters,
|
167
|
-
)
|
168
|
-
except Exception as ex:
|
169
|
-
allowed_fine_tuning_parameters = ", ".join(
|
170
|
-
field.name for field in fields(AquaFineTuningParams)
|
171
|
-
).rstrip()
|
172
|
-
raise AquaValueError(
|
173
|
-
"Invalid fine tuning parameters. Fine tuning parameters should "
|
174
|
-
f"be a dictionary with keys: {allowed_fine_tuning_parameters}."
|
175
|
-
) from ex
|
152
|
+
ft_parameters = self._get_finetuning_params(
|
153
|
+
create_fine_tuning_details.ft_parameters
|
154
|
+
)
|
176
155
|
|
177
156
|
experiment_model_version_set_id = create_fine_tuning_details.experiment_id
|
178
157
|
experiment_model_version_set_name = create_fine_tuning_details.experiment_name
|
@@ -229,6 +208,8 @@ class AquaFineTuningApp(AquaApp):
|
|
229
208
|
defined_tags=create_fine_tuning_details.defined_tags,
|
230
209
|
)
|
231
210
|
|
211
|
+
source = self.get_source(create_fine_tuning_details.ft_source_id)
|
212
|
+
|
232
213
|
ft_model_custom_metadata = ModelCustomMetadata()
|
233
214
|
ft_model_custom_metadata.add(
|
234
215
|
key=FineTuneCustomMetadata.FINE_TUNE_SOURCE,
|
@@ -481,11 +462,7 @@ class AquaFineTuningApp(AquaApp):
|
|
481
462
|
**model_freeform_tags,
|
482
463
|
**model_defined_tags,
|
483
464
|
},
|
484
|
-
parameters=
|
485
|
-
key: value
|
486
|
-
for key, value in asdict(ft_parameters).items()
|
487
|
-
if value is not None
|
488
|
-
},
|
465
|
+
parameters=ft_parameters,
|
489
466
|
)
|
490
467
|
|
491
468
|
def _build_fine_tuning_runtime(
|
@@ -548,7 +525,7 @@ class AquaFineTuningApp(AquaApp):
|
|
548
525
|
) -> str:
|
549
526
|
"""Builds the oci launch cmd for fine tuning container runtime."""
|
550
527
|
oci_launch_cmd = f"--training_data {dataset_path} --output_dir {report_path} --val_set_size {val_set_size} "
|
551
|
-
for key, value in
|
528
|
+
for key, value in parameters.to_dict().items():
|
552
529
|
if value is not None:
|
553
530
|
if key == "batch_size":
|
554
531
|
oci_launch_cmd += f"--micro_{key} {value} "
|
@@ -613,15 +590,36 @@ class AquaFineTuningApp(AquaApp):
|
|
613
590
|
default_params = {"params": {}}
|
614
591
|
finetuning_config = self.get_finetuning_config(model_id)
|
615
592
|
config_parameters = finetuning_config.get("configuration", UNKNOWN_DICT)
|
616
|
-
dataclass_fields =
|
593
|
+
dataclass_fields = self._get_finetuning_params(
|
594
|
+
config_parameters, validate=False
|
595
|
+
).to_dict()
|
617
596
|
for name, value in config_parameters.items():
|
618
|
-
if name == "micro_batch_size":
|
619
|
-
name = "batch_size"
|
620
597
|
if name in dataclass_fields:
|
598
|
+
if name == "micro_batch_size":
|
599
|
+
name = "batch_size"
|
621
600
|
default_params["params"][name] = value
|
622
601
|
|
623
602
|
return default_params
|
624
603
|
|
604
|
+
@staticmethod
|
605
|
+
def _get_finetuning_params(
|
606
|
+
params: Dict = None, validate: bool = True
|
607
|
+
) -> AquaFineTuningParams:
|
608
|
+
"""
|
609
|
+
Get and validate the fine-tuning params, and return an error message if validation fails. In order to skip
|
610
|
+
@model_validator decorator's validation, pass validate=False.
|
611
|
+
"""
|
612
|
+
try:
|
613
|
+
finetuning_params = AquaFineTuningParams(
|
614
|
+
**{**params, **{"_validate": validate}}
|
615
|
+
)
|
616
|
+
except ValidationError as ex:
|
617
|
+
custom_errors = build_pydantic_error_message(ex)
|
618
|
+
raise AquaValueError(
|
619
|
+
f"Invalid finetuning parameters. Error details: {custom_errors}."
|
620
|
+
) from ex
|
621
|
+
return finetuning_params
|
622
|
+
|
625
623
|
def validate_finetuning_params(self, params: Dict = None) -> Dict:
|
626
624
|
"""Validate if the fine-tuning parameters passed by the user can be overridden. Parameter values are not
|
627
625
|
validated, only param keys are validated.
|
@@ -635,19 +633,5 @@ class AquaFineTuningApp(AquaApp):
|
|
635
633
|
-------
|
636
634
|
Return a list of restricted params.
|
637
635
|
"""
|
638
|
-
|
639
|
-
AquaFineTuningParams(
|
640
|
-
**params,
|
641
|
-
)
|
642
|
-
except Exception as e:
|
643
|
-
logger.debug(str(e))
|
644
|
-
allowed_fine_tuning_parameters = ", ".join(
|
645
|
-
f"{field.name} (required)" if field.default is MISSING else field.name
|
646
|
-
for field in fields(AquaFineTuningParams)
|
647
|
-
).rstrip()
|
648
|
-
raise AquaValueError(
|
649
|
-
f"Invalid fine tuning parameters. Allowable parameters are: "
|
650
|
-
f"{allowed_fine_tuning_parameters}."
|
651
|
-
) from e
|
652
|
-
|
636
|
+
self._get_finetuning_params(params or {})
|
653
637
|
return {"valid": True}
|
ads/aqua/model/entities.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
|
5
5
|
"""
|
@@ -283,6 +283,7 @@ class ImportModelDetails(CLIBuilderMixin):
|
|
283
283
|
os_path: str
|
284
284
|
download_from_hf: Optional[bool] = True
|
285
285
|
local_dir: Optional[str] = None
|
286
|
+
cleanup_model_cache: Optional[bool] = True
|
286
287
|
inference_container: Optional[str] = None
|
287
288
|
finetuning_container: Optional[str] = None
|
288
289
|
compartment_id: Optional[str] = None
|