oracle-ads 2.12.9__py3-none-any.whl → 2.12.10rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. ads/aqua/__init__.py +4 -4
  2. ads/aqua/common/enums.py +3 -0
  3. ads/aqua/common/utils.py +62 -2
  4. ads/aqua/data.py +2 -19
  5. ads/aqua/extension/finetune_handler.py +8 -14
  6. ads/aqua/extension/model_handler.py +19 -2
  7. ads/aqua/finetuning/constants.py +5 -2
  8. ads/aqua/finetuning/entities.py +64 -17
  9. ads/aqua/finetuning/finetuning.py +38 -54
  10. ads/aqua/model/entities.py +2 -1
  11. ads/aqua/model/model.py +61 -23
  12. ads/common/auth.py +9 -9
  13. ads/llm/autogen/__init__.py +2 -0
  14. ads/llm/autogen/constants.py +15 -0
  15. ads/llm/autogen/reports/__init__.py +2 -0
  16. ads/llm/autogen/reports/base.py +67 -0
  17. ads/llm/autogen/reports/data.py +103 -0
  18. ads/llm/autogen/reports/session.py +526 -0
  19. ads/llm/autogen/reports/templates/chat_box.html +13 -0
  20. ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
  21. ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
  22. ads/llm/autogen/reports/utils.py +56 -0
  23. ads/llm/autogen/v02/__init__.py +4 -0
  24. ads/llm/autogen/{client_v02.py → v02/client.py} +23 -10
  25. ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
  26. ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
  27. ads/llm/autogen/v02/loggers/__init__.py +6 -0
  28. ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
  29. ads/llm/autogen/v02/loggers/session_logger.py +580 -0
  30. ads/llm/autogen/v02/loggers/utils.py +86 -0
  31. ads/llm/autogen/v02/runtime_logging.py +163 -0
  32. ads/llm/langchain/plugins/chat_models/oci_data_science.py +12 -11
  33. ads/model/__init__.py +11 -13
  34. ads/model/artifact.py +47 -8
  35. ads/model/extractor/embedding_onnx_extractor.py +80 -0
  36. ads/model/framework/embedding_onnx_model.py +438 -0
  37. ads/model/generic_model.py +26 -24
  38. ads/model/model_metadata.py +8 -7
  39. ads/opctl/config/merger.py +13 -14
  40. ads/opctl/operator/common/operator_config.py +4 -4
  41. ads/opctl/operator/lowcode/common/transformations.py +12 -5
  42. ads/opctl/operator/lowcode/common/utils.py +11 -5
  43. ads/opctl/operator/lowcode/forecast/const.py +2 -0
  44. ads/opctl/operator/lowcode/forecast/model/arima.py +19 -13
  45. ads/opctl/operator/lowcode/forecast/model/automlx.py +129 -36
  46. ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
  47. ads/opctl/operator/lowcode/forecast/model/base_model.py +61 -14
  48. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +10 -3
  49. ads/opctl/operator/lowcode/forecast/model/prophet.py +25 -18
  50. ads/opctl/operator/lowcode/forecast/schema.yaml +13 -0
  51. ads/opctl/operator/lowcode/forecast/utils.py +4 -3
  52. ads/telemetry/base.py +18 -11
  53. ads/telemetry/client.py +33 -13
  54. ads/templates/schemas/openapi.json +1740 -0
  55. ads/templates/score_embedding_onnx.jinja2 +202 -0
  56. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/METADATA +7 -8
  57. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/RECORD +60 -39
  58. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/LICENSE.txt +0 -0
  59. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/WHEEL +0 -0
  60. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/entry_points.txt +0 -0
ads/aqua/__init__.py CHANGED
@@ -1,12 +1,12 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
6
5
 
7
6
  import os
7
+ from logging import getLogger
8
8
 
9
- from ads import logger, set_auth
9
+ from ads import set_auth
10
10
  from ads.aqua.common.utils import fetch_service_compartment
11
11
  from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION
12
12
 
@@ -19,6 +19,7 @@ def get_logger_level():
19
19
  return level
20
20
 
21
21
 
22
+ logger = getLogger(__name__)
22
23
  logger.setLevel(get_logger_level())
23
24
 
24
25
 
@@ -27,7 +28,6 @@ def set_log_level(log_level: str):
27
28
 
28
29
  log_level = log_level.upper()
29
30
  logger.setLevel(log_level.upper())
30
- logger.handlers[0].setLevel(log_level)
31
31
 
32
32
 
33
33
  if OCI_RESOURCE_PRINCIPAL_VERSION:
ads/aqua/common/enums.py CHANGED
@@ -52,6 +52,9 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
52
52
  AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
53
53
  AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
54
54
  AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
55
+
56
+
57
+ class CustomInferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
55
58
  AQUA_TEI_CONTAINER_FAMILY = "odsc-tei-serving"
56
59
 
57
60
 
ads/aqua/common/utils.py CHANGED
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
  """AQUA utils and constants."""
5
5
 
@@ -11,6 +11,7 @@ import os
11
11
  import random
12
12
  import re
13
13
  import shlex
14
+ import shutil
14
15
  import subprocess
15
16
  from datetime import datetime, timedelta
16
17
  from functools import wraps
@@ -21,6 +22,8 @@ from typing import List, Union
21
22
  import fsspec
22
23
  import oci
23
24
  from cachetools import TTLCache, cached
25
+ from huggingface_hub.constants import HF_HUB_CACHE
26
+ from huggingface_hub.file_download import repo_folder_name
24
27
  from huggingface_hub.hf_api import HfApi, ModelInfo
25
28
  from huggingface_hub.utils import (
26
29
  GatedRepoError,
@@ -30,6 +33,7 @@ from huggingface_hub.utils import (
30
33
  )
31
34
  from oci.data_science.models import JobRun, Model
32
35
  from oci.object_storage.models import ObjectSummary
36
+ from pydantic import ValidationError
33
37
 
34
38
  from ads.aqua.common.enums import (
35
39
  InferenceContainerParamType,
@@ -788,7 +792,9 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
788
792
  return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
789
793
 
790
794
 
791
- def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str:
795
+ def upload_folder(
796
+ os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None
797
+ ) -> str:
792
798
  """Upload the local folder to the object storage
793
799
 
794
800
  Args:
@@ -818,6 +824,48 @@ def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern
818
824
  return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path
819
825
 
820
826
 
827
+ def cleanup_local_hf_model_artifact(
828
+ model_name: str,
829
+ local_dir: str = None,
830
+ ):
831
+ """
832
+ Helper function that deletes local artifacts downloaded from Hugging Face to free up disk space.
833
+ Parameters
834
+ ----------
835
+ model_name (str): Name of the huggingface model
836
+ local_dir (str): Local directory where the object is downloaded
837
+
838
+ """
839
+ if local_dir and os.path.exists(local_dir):
840
+ model_dir = os.path.join(local_dir, model_name)
841
+ model_dir = (
842
+ os.path.dirname(model_dir)
843
+ if "/" in model_name or os.sep in model_name
844
+ else model_dir
845
+ )
846
+ shutil.rmtree(model_dir, ignore_errors=True)
847
+ if os.path.exists(model_dir):
848
+ logger.debug(
849
+ f"Could not delete local model artifact directory: {model_dir}"
850
+ )
851
+ else:
852
+ logger.debug(f"Deleted local model artifact directory: {model_dir}.")
853
+
854
+ hf_local_path = os.path.join(
855
+ HF_HUB_CACHE, repo_folder_name(repo_id=model_name, repo_type="model")
856
+ )
857
+ shutil.rmtree(hf_local_path, ignore_errors=True)
858
+
859
+ if os.path.exists(hf_local_path):
860
+ logger.debug(
861
+ f"Could not clear the local Hugging Face cache directory {hf_local_path} for the model {model_name}."
862
+ )
863
+ else:
864
+ logger.debug(
865
+ f"Cleared contents of local Hugging Face cache directory {hf_local_path} for the model {model_name}."
866
+ )
867
+
868
+
821
869
  def is_service_managed_container(container):
822
870
  return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
823
871
 
@@ -1159,3 +1207,15 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
1159
1207
 
1160
1208
  combined_cmd_var = cmd_var + overrides
1161
1209
  return combined_cmd_var
1210
+
1211
+
1212
+ def build_pydantic_error_message(ex: ValidationError):
1213
+ """Added to handle error messages from pydantic model validator.
1214
+ Combine both loc and msg for errors where loc (field) is present in error details, else only build error
1215
+ message using msg field."""
1216
+
1217
+ return {
1218
+ ".".join(map(str, e["loc"])): e["msg"]
1219
+ for e in ex.errors()
1220
+ if "loc" in e and e["loc"]
1221
+ } or "; ".join(e["msg"] for e in ex.errors())
ads/aqua/data.py CHANGED
@@ -1,9 +1,8 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
6
- from dataclasses import dataclass, field
5
+ from dataclasses import dataclass
7
6
 
8
7
  from ads.common.serializer import DataClassSerializable
9
8
 
@@ -13,19 +12,3 @@ class AquaResourceIdentifier(DataClassSerializable):
13
12
  id: str = ""
14
13
  name: str = ""
15
14
  url: str = ""
16
-
17
-
18
- @dataclass(repr=False)
19
- class AquaJobSummary(DataClassSerializable):
20
- """Represents an Aqua job summary."""
21
-
22
- id: str
23
- name: str
24
- console_url: str
25
- lifecycle_state: str
26
- lifecycle_details: str
27
- time_created: str
28
- tags: dict
29
- experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
30
- source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
31
- job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
ads/aqua/extension/finetune_handler.py CHANGED
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
 
5
5
 
@@ -10,9 +10,7 @@ from tornado.web import HTTPError
10
10
  from ads.aqua.common.decorator import handle_exceptions
11
11
  from ads.aqua.extension.base_handler import AquaAPIhandler
12
12
  from ads.aqua.extension.errors import Errors
13
- from ads.aqua.extension.utils import validate_function_parameters
14
13
  from ads.aqua.finetuning import AquaFineTuningApp
15
- from ads.aqua.finetuning.entities import CreateFineTuningDetails
16
14
 
17
15
 
18
16
  class AquaFineTuneHandler(AquaAPIhandler):
@@ -33,7 +31,7 @@ class AquaFineTuneHandler(AquaAPIhandler):
33
31
  raise HTTPError(400, f"The request {self.request.path} is invalid.")
34
32
 
35
33
  @handle_exceptions
36
- def post(self, *args, **kwargs):
34
+ def post(self, *args, **kwargs): # noqa: ARG002
37
35
  """Handles post request for the fine-tuning API
38
36
 
39
37
  Raises
@@ -43,17 +41,13 @@ class AquaFineTuneHandler(AquaAPIhandler):
43
41
  """
44
42
  try:
45
43
  input_data = self.get_json_body()
46
- except Exception:
47
- raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
44
+ except Exception as ex:
45
+ raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
48
46
 
49
47
  if not input_data:
50
48
  raise HTTPError(400, Errors.NO_INPUT_DATA)
51
49
 
52
- validate_function_parameters(
53
- data_class=CreateFineTuningDetails, input_data=input_data
54
- )
55
-
56
- self.finish(AquaFineTuningApp().create(CreateFineTuningDetails(**input_data)))
50
+ self.finish(AquaFineTuningApp().create(**input_data))
57
51
 
58
52
  def get_finetuning_config(self, model_id):
59
53
  """Gets the finetuning config for Aqua model."""
@@ -71,7 +65,7 @@ class AquaFineTuneParamsHandler(AquaAPIhandler):
71
65
  )
72
66
 
73
67
  @handle_exceptions
74
- def post(self, *args, **kwargs):
68
+ def post(self, *args, **kwargs): # noqa: ARG002
75
69
  """Handles post request for the finetuning param handler API.
76
70
 
77
71
  Raises
@@ -81,8 +75,8 @@ class AquaFineTuneParamsHandler(AquaAPIhandler):
81
75
  """
82
76
  try:
83
77
  input_data = self.get_json_body()
84
- except Exception:
85
- raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
78
+ except Exception as ex:
79
+ raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
86
80
 
87
81
  if not input_data:
88
82
  raise HTTPError(400, Errors.NO_INPUT_DATA)
ads/aqua/extension/model_handler.py CHANGED
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
 
5
5
  from typing import Optional
@@ -8,6 +8,9 @@ from urllib.parse import urlparse
8
8
  from tornado.web import HTTPError
9
9
 
10
10
  from ads.aqua.common.decorator import handle_exceptions
11
+ from ads.aqua.common.enums import (
12
+ CustomInferenceContainerTypeFamily,
13
+ )
11
14
  from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
12
15
  from ads.aqua.common.utils import (
13
16
  get_hf_model_info,
@@ -128,6 +131,10 @@ class AquaModelHandler(AquaAPIhandler):
128
131
  download_from_hf = (
129
132
  str(input_data.get("download_from_hf", "false")).lower() == "true"
130
133
  )
134
+ local_dir = input_data.get("local_dir")
135
+ cleanup_model_cache = (
136
+ str(input_data.get("cleanup_model_cache", "true")).lower() == "true"
137
+ )
131
138
  inference_container_uri = input_data.get("inference_container_uri")
132
139
  allow_patterns = input_data.get("allow_patterns")
133
140
  ignore_patterns = input_data.get("ignore_patterns")
@@ -139,6 +146,8 @@ class AquaModelHandler(AquaAPIhandler):
139
146
  model=model,
140
147
  os_path=os_path,
141
148
  download_from_hf=download_from_hf,
149
+ local_dir=local_dir,
150
+ cleanup_model_cache=cleanup_model_cache,
142
151
  inference_container=inference_container,
143
152
  finetuning_container=finetuning_container,
144
153
  compartment_id=compartment_id,
@@ -163,7 +172,9 @@ class AquaModelHandler(AquaAPIhandler):
163
172
  raise HTTPError(400, Errors.NO_INPUT_DATA)
164
173
 
165
174
  inference_container = input_data.get("inference_container")
175
+ inference_container_uri = input_data.get("inference_container_uri")
166
176
  inference_containers = AquaModelApp.list_valid_inference_containers()
177
+ inference_containers.extend(CustomInferenceContainerTypeFamily.values())
167
178
  if (
168
179
  inference_container is not None
169
180
  and inference_container not in inference_containers
@@ -176,7 +187,13 @@ class AquaModelHandler(AquaAPIhandler):
176
187
  task = input_data.get("task")
177
188
  app = AquaModelApp()
178
189
  self.finish(
179
- app.edit_registered_model(id, inference_container, enable_finetuning, task)
190
+ app.edit_registered_model(
191
+ id,
192
+ inference_container,
193
+ inference_container_uri,
194
+ enable_finetuning,
195
+ task,
196
+ )
180
197
  )
181
198
  app.clear_model_details_cache(model_id=id)
182
199
 
ads/aqua/finetuning/constants.py CHANGED
@@ -1,6 +1,5 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
6
5
  from ads.common.extended_enum import ExtendedEnumMeta
@@ -17,4 +16,8 @@ class FineTuneCustomMetadata(str, metaclass=ExtendedEnumMeta):
17
16
  SERVICE_MODEL_FINE_TUNE_CONTAINER = "finetune-container"
18
17
 
19
18
 
19
+ class FineTuningRestrictedParams(str, metaclass=ExtendedEnumMeta):
20
+ OPTIMIZER = "optimizer"
21
+
22
+
20
23
  ENV_AQUA_FINE_TUNING_CONTAINER = "AQUA_FINE_TUNING_CONTAINER"
ads/aqua/finetuning/entities.py CHANGED
@@ -1,18 +1,24 @@
1
1
  #!/usr/bin/env python
2
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
- from dataclasses import dataclass, field
5
- from typing import List, Optional
6
4
 
7
- from ads.aqua.data import AquaJobSummary
8
- from ads.common.serializer import DataClassSerializable
5
+ import json
6
+ from typing import List, Literal, Optional, Union
9
7
 
8
+ from pydantic import Field, model_validator
10
9
 
11
- @dataclass(repr=False)
12
- class AquaFineTuningParams(DataClassSerializable):
13
- epochs: int
10
+ from ads.aqua.common.errors import AquaValueError
11
+ from ads.aqua.config.utils.serializer import Serializable
12
+ from ads.aqua.data import AquaResourceIdentifier
13
+ from ads.aqua.finetuning.constants import FineTuningRestrictedParams
14
+
15
+
16
+ class AquaFineTuningParams(Serializable):
17
+ """Class for maintaining aqua fine-tuning model parameters"""
18
+
19
+ epochs: Optional[int] = None
14
20
  learning_rate: Optional[float] = None
15
- sample_packing: Optional[bool] = "auto"
21
+ sample_packing: Union[bool, None, Literal["auto"]] = "auto"
16
22
  batch_size: Optional[int] = (
17
23
  None # make it batch_size for user, but internally this is micro_batch_size
18
24
  )
@@ -22,21 +28,59 @@ class AquaFineTuningParams(DataClassSerializable):
22
28
  lora_alpha: Optional[int] = None
23
29
  lora_dropout: Optional[float] = None
24
30
  lora_target_linear: Optional[bool] = None
25
- lora_target_modules: Optional[List] = None
31
+ lora_target_modules: Optional[List[str]] = None
26
32
  early_stopping_patience: Optional[int] = None
27
33
  early_stopping_threshold: Optional[float] = None
28
34
 
35
+ class Config:
36
+ extra = "allow"
37
+
38
+ def to_dict(self) -> dict:
39
+ return json.loads(super().to_json(exclude_none=True))
40
+
41
+ @model_validator(mode="before")
42
+ @classmethod
43
+ def validate_restricted_fields(cls, data: dict):
44
+ # we may want to skip validation if loading data from config files instead of user entered parameters
45
+ validate = data.pop("_validate", True)
46
+ if not (validate and isinstance(data, dict)):
47
+ return data
48
+ restricted_params = [
49
+ param for param in data if param in FineTuningRestrictedParams.values()
50
+ ]
51
+ if restricted_params:
52
+ raise AquaValueError(
53
+ f"Found restricted parameter name: {restricted_params}"
54
+ )
55
+ return data
29
56
 
30
- @dataclass(repr=False)
31
- class AquaFineTuningSummary(AquaJobSummary, DataClassSerializable):
32
- parameters: AquaFineTuningParams = field(default_factory=AquaFineTuningParams)
33
57
 
58
+ class AquaFineTuningSummary(Serializable):
59
+ """Represents a summary of Aqua Finetuning job."""
34
60
 
35
- @dataclass(repr=False)
36
- class CreateFineTuningDetails(DataClassSerializable):
37
- """Dataclass to create aqua model fine tuning.
61
+ id: str
62
+ name: str
63
+ console_url: str
64
+ lifecycle_state: str
65
+ lifecycle_details: str
66
+ time_created: str
67
+ tags: dict
68
+ experiment: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
69
+ source: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
70
+ job: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
71
+ parameters: AquaFineTuningParams = Field(default_factory=AquaFineTuningParams)
38
72
 
39
- Fields
73
+ class Config:
74
+ extra = "ignore"
75
+
76
+ def to_dict(self) -> dict:
77
+ return json.loads(super().to_json(exclude_none=True))
78
+
79
+
80
+ class CreateFineTuningDetails(Serializable):
81
+ """Class to create aqua model fine-tuning instance.
82
+
83
+ Properties
40
84
  ------
41
85
  ft_source_id: str
42
86
  The fine tuning source id. Must be model ocid.
@@ -107,3 +151,6 @@ class CreateFineTuningDetails(DataClassSerializable):
107
151
  force_overwrite: Optional[bool] = False
108
152
  freeform_tags: Optional[dict] = None
109
153
  defined_tags: Optional[dict] = None
154
+
155
+ class Config:
156
+ extra = "ignore"
ads/aqua/finetuning/finetuning.py CHANGED
@@ -1,10 +1,9 @@
1
1
  #!/usr/bin/env python
2
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
 
5
5
  import json
6
6
  import os
7
- from dataclasses import MISSING, asdict, fields
8
7
  from typing import Dict
9
8
 
10
9
  from oci.data_science.models import (
@@ -12,12 +11,14 @@ from oci.data_science.models import (
12
11
  UpdateModelDetails,
13
12
  UpdateModelProvenanceDetails,
14
13
  )
14
+ from pydantic import ValidationError
15
15
 
16
16
  from ads.aqua import logger
17
17
  from ads.aqua.app import AquaApp
18
18
  from ads.aqua.common.enums import Resource, Tags
19
19
  from ads.aqua.common.errors import AquaFileExistsError, AquaValueError
20
20
  from ads.aqua.common.utils import (
21
+ build_pydantic_error_message,
21
22
  get_container_image,
22
23
  upload_local_to_os,
23
24
  )
@@ -104,24 +105,12 @@ class AquaFineTuningApp(AquaApp):
104
105
  if not create_fine_tuning_details:
105
106
  try:
106
107
  create_fine_tuning_details = CreateFineTuningDetails(**kwargs)
107
- except Exception as ex:
108
- allowed_create_fine_tuning_details = ", ".join(
109
- field.name for field in fields(CreateFineTuningDetails)
110
- ).rstrip()
108
+ except ValidationError as ex:
109
+ custom_errors = build_pydantic_error_message(ex)
111
110
  raise AquaValueError(
112
- "Invalid create fine tuning parameters. Allowable parameters are: "
113
- f"{allowed_create_fine_tuning_details}."
111
+ f"Invalid parameters for creating a fine-tuned model. Error details: {custom_errors}."
114
112
  ) from ex
115
113
 
116
- source = self.get_source(create_fine_tuning_details.ft_source_id)
117
-
118
- # todo: revisit validation for fine tuned models
119
- # if source.compartment_id != ODSC_MODEL_COMPARTMENT_OCID:
120
- # raise AquaValueError(
121
- # f"Fine tuning is only supported for Aqua service models in {ODSC_MODEL_COMPARTMENT_OCID}. "
122
- # "Use a valid Aqua service model id instead."
123
- # )
124
-
125
114
  target_compartment = (
126
115
  create_fine_tuning_details.compartment_id or COMPARTMENT_OCID
127
116
  )
@@ -160,19 +149,9 @@ class AquaFineTuningApp(AquaApp):
160
149
  f"Logging is required for fine tuning if replica is larger than {DEFAULT_FT_REPLICA}."
161
150
  )
162
151
 
163
- ft_parameters = None
164
- try:
165
- ft_parameters = AquaFineTuningParams(
166
- **create_fine_tuning_details.ft_parameters,
167
- )
168
- except Exception as ex:
169
- allowed_fine_tuning_parameters = ", ".join(
170
- field.name for field in fields(AquaFineTuningParams)
171
- ).rstrip()
172
- raise AquaValueError(
173
- "Invalid fine tuning parameters. Fine tuning parameters should "
174
- f"be a dictionary with keys: {allowed_fine_tuning_parameters}."
175
- ) from ex
152
+ ft_parameters = self._get_finetuning_params(
153
+ create_fine_tuning_details.ft_parameters
154
+ )
176
155
 
177
156
  experiment_model_version_set_id = create_fine_tuning_details.experiment_id
178
157
  experiment_model_version_set_name = create_fine_tuning_details.experiment_name
@@ -229,6 +208,8 @@ class AquaFineTuningApp(AquaApp):
229
208
  defined_tags=create_fine_tuning_details.defined_tags,
230
209
  )
231
210
 
211
+ source = self.get_source(create_fine_tuning_details.ft_source_id)
212
+
232
213
  ft_model_custom_metadata = ModelCustomMetadata()
233
214
  ft_model_custom_metadata.add(
234
215
  key=FineTuneCustomMetadata.FINE_TUNE_SOURCE,
@@ -481,11 +462,7 @@ class AquaFineTuningApp(AquaApp):
481
462
  **model_freeform_tags,
482
463
  **model_defined_tags,
483
464
  },
484
- parameters={
485
- key: value
486
- for key, value in asdict(ft_parameters).items()
487
- if value is not None
488
- },
465
+ parameters=ft_parameters,
489
466
  )
490
467
 
491
468
  def _build_fine_tuning_runtime(
@@ -548,7 +525,7 @@ class AquaFineTuningApp(AquaApp):
548
525
  ) -> str:
549
526
  """Builds the oci launch cmd for fine tuning container runtime."""
550
527
  oci_launch_cmd = f"--training_data {dataset_path} --output_dir {report_path} --val_set_size {val_set_size} "
551
- for key, value in asdict(parameters).items():
528
+ for key, value in parameters.to_dict().items():
552
529
  if value is not None:
553
530
  if key == "batch_size":
554
531
  oci_launch_cmd += f"--micro_{key} {value} "
@@ -613,15 +590,36 @@ class AquaFineTuningApp(AquaApp):
613
590
  default_params = {"params": {}}
614
591
  finetuning_config = self.get_finetuning_config(model_id)
615
592
  config_parameters = finetuning_config.get("configuration", UNKNOWN_DICT)
616
- dataclass_fields = {field.name for field in fields(AquaFineTuningParams)}
593
+ dataclass_fields = self._get_finetuning_params(
594
+ config_parameters, validate=False
595
+ ).to_dict()
617
596
  for name, value in config_parameters.items():
618
- if name == "micro_batch_size":
619
- name = "batch_size"
620
597
  if name in dataclass_fields:
598
+ if name == "micro_batch_size":
599
+ name = "batch_size"
621
600
  default_params["params"][name] = value
622
601
 
623
602
  return default_params
624
603
 
604
+ @staticmethod
605
+ def _get_finetuning_params(
606
+ params: Dict = None, validate: bool = True
607
+ ) -> AquaFineTuningParams:
608
+ """
609
+ Get and validate the fine-tuning params, and return an error message if validation fails. In order to skip
610
+ @model_validator decorator's validation, pass validate=False.
611
+ """
612
+ try:
613
+ finetuning_params = AquaFineTuningParams(
614
+ **{**params, **{"_validate": validate}}
615
+ )
616
+ except ValidationError as ex:
617
+ custom_errors = build_pydantic_error_message(ex)
618
+ raise AquaValueError(
619
+ f"Invalid finetuning parameters. Error details: {custom_errors}."
620
+ ) from ex
621
+ return finetuning_params
622
+
625
623
  def validate_finetuning_params(self, params: Dict = None) -> Dict:
626
624
  """Validate if the fine-tuning parameters passed by the user can be overridden. Parameter values are not
627
625
  validated, only param keys are validated.
@@ -635,19 +633,5 @@ class AquaFineTuningApp(AquaApp):
635
633
  -------
636
634
  Return a list of restricted params.
637
635
  """
638
- try:
639
- AquaFineTuningParams(
640
- **params,
641
- )
642
- except Exception as e:
643
- logger.debug(str(e))
644
- allowed_fine_tuning_parameters = ", ".join(
645
- f"{field.name} (required)" if field.default is MISSING else field.name
646
- for field in fields(AquaFineTuningParams)
647
- ).rstrip()
648
- raise AquaValueError(
649
- f"Invalid fine tuning parameters. Allowable parameters are: "
650
- f"{allowed_fine_tuning_parameters}."
651
- ) from e
652
-
636
+ self._get_finetuning_params(params or {})
653
637
  return {"valid": True}
ads/aqua/model/entities.py CHANGED
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
 
5
5
  """
@@ -283,6 +283,7 @@ class ImportModelDetails(CLIBuilderMixin):
283
283
  os_path: str
284
284
  download_from_hf: Optional[bool] = True
285
285
  local_dir: Optional[str] = None
286
+ cleanup_model_cache: Optional[bool] = True
286
287
  inference_container: Optional[str] = None
287
288
  finetuning_container: Optional[str] = None
288
289
  compartment_id: Optional[str] = None