oracle-ads 2.12.9__py3-none-any.whl → 2.12.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. ads/aqua/__init__.py +4 -3
  2. ads/aqua/app.py +28 -16
  3. ads/aqua/client/__init__.py +3 -0
  4. ads/aqua/client/client.py +799 -0
  5. ads/aqua/common/enums.py +3 -0
  6. ads/aqua/common/utils.py +62 -2
  7. ads/aqua/data.py +2 -19
  8. ads/aqua/evaluation/evaluation.py +20 -12
  9. ads/aqua/extension/aqua_ws_msg_handler.py +14 -7
  10. ads/aqua/extension/base_handler.py +12 -9
  11. ads/aqua/extension/finetune_handler.py +8 -14
  12. ads/aqua/extension/model_handler.py +24 -2
  13. ads/aqua/finetuning/constants.py +5 -2
  14. ads/aqua/finetuning/entities.py +67 -17
  15. ads/aqua/finetuning/finetuning.py +69 -54
  16. ads/aqua/model/entities.py +3 -1
  17. ads/aqua/model/model.py +196 -98
  18. ads/aqua/modeldeployment/deployment.py +22 -10
  19. ads/cli.py +16 -8
  20. ads/common/auth.py +9 -9
  21. ads/llm/autogen/__init__.py +2 -0
  22. ads/llm/autogen/constants.py +15 -0
  23. ads/llm/autogen/reports/__init__.py +2 -0
  24. ads/llm/autogen/reports/base.py +67 -0
  25. ads/llm/autogen/reports/data.py +103 -0
  26. ads/llm/autogen/reports/session.py +526 -0
  27. ads/llm/autogen/reports/templates/chat_box.html +13 -0
  28. ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
  29. ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
  30. ads/llm/autogen/reports/utils.py +56 -0
  31. ads/llm/autogen/v02/__init__.py +4 -0
  32. ads/llm/autogen/{client_v02.py → v02/client.py} +23 -10
  33. ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
  34. ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
  35. ads/llm/autogen/v02/loggers/__init__.py +6 -0
  36. ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
  37. ads/llm/autogen/v02/loggers/session_logger.py +580 -0
  38. ads/llm/autogen/v02/loggers/utils.py +86 -0
  39. ads/llm/autogen/v02/runtime_logging.py +163 -0
  40. ads/llm/langchain/plugins/chat_models/oci_data_science.py +12 -11
  41. ads/model/__init__.py +11 -13
  42. ads/model/artifact.py +47 -8
  43. ads/model/extractor/embedding_onnx_extractor.py +80 -0
  44. ads/model/framework/embedding_onnx_model.py +438 -0
  45. ads/model/generic_model.py +26 -24
  46. ads/model/model_metadata.py +8 -7
  47. ads/opctl/config/merger.py +13 -14
  48. ads/opctl/operator/common/operator_config.py +4 -4
  49. ads/opctl/operator/lowcode/common/transformations.py +50 -8
  50. ads/opctl/operator/lowcode/common/utils.py +22 -6
  51. ads/opctl/operator/lowcode/forecast/__main__.py +10 -0
  52. ads/opctl/operator/lowcode/forecast/const.py +2 -0
  53. ads/opctl/operator/lowcode/forecast/model/arima.py +19 -13
  54. ads/opctl/operator/lowcode/forecast/model/automlx.py +129 -36
  55. ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
  56. ads/opctl/operator/lowcode/forecast/model/base_model.py +61 -14
  57. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +1 -1
  58. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +10 -3
  59. ads/opctl/operator/lowcode/forecast/model/prophet.py +25 -18
  60. ads/opctl/operator/lowcode/forecast/operator_config.py +31 -0
  61. ads/opctl/operator/lowcode/forecast/schema.yaml +76 -0
  62. ads/opctl/operator/lowcode/forecast/utils.py +4 -3
  63. ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
  64. ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +233 -0
  65. ads/opctl/operator/lowcode/forecast/whatifserve/score.py +238 -0
  66. ads/telemetry/base.py +18 -11
  67. ads/telemetry/client.py +33 -13
  68. ads/templates/schemas/openapi.json +1740 -0
  69. ads/templates/score_embedding_onnx.jinja2 +202 -0
  70. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/METADATA +9 -8
  71. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/RECORD +74 -48
  72. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/LICENSE.txt +0 -0
  73. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/WHEEL +0 -0
  74. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/entry_points.txt +0 -0
ads/aqua/common/enums.py CHANGED
@@ -52,6 +52,9 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
52
52
  AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
53
53
  AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
54
54
  AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
55
+
56
+
57
+ class CustomInferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
55
58
  AQUA_TEI_CONTAINER_FAMILY = "odsc-tei-serving"
56
59
 
57
60
 
ads/aqua/common/utils.py CHANGED
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
  """AQUA utils and constants."""
5
5
 
@@ -11,6 +11,7 @@ import os
11
11
  import random
12
12
  import re
13
13
  import shlex
14
+ import shutil
14
15
  import subprocess
15
16
  from datetime import datetime, timedelta
16
17
  from functools import wraps
@@ -21,6 +22,8 @@ from typing import List, Union
21
22
  import fsspec
22
23
  import oci
23
24
  from cachetools import TTLCache, cached
25
+ from huggingface_hub.constants import HF_HUB_CACHE
26
+ from huggingface_hub.file_download import repo_folder_name
24
27
  from huggingface_hub.hf_api import HfApi, ModelInfo
25
28
  from huggingface_hub.utils import (
26
29
  GatedRepoError,
@@ -30,6 +33,7 @@ from huggingface_hub.utils import (
30
33
  )
31
34
  from oci.data_science.models import JobRun, Model
32
35
  from oci.object_storage.models import ObjectSummary
36
+ from pydantic import ValidationError
33
37
 
34
38
  from ads.aqua.common.enums import (
35
39
  InferenceContainerParamType,
@@ -788,7 +792,9 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
788
792
  return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
789
793
 
790
794
 
791
- def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str:
795
+ def upload_folder(
796
+ os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None
797
+ ) -> str:
792
798
  """Upload the local folder to the object storage
793
799
 
794
800
  Args:
@@ -818,6 +824,48 @@ def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern
818
824
  return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path
819
825
 
820
826
 
827
+ def cleanup_local_hf_model_artifact(
828
+ model_name: str,
829
+ local_dir: str = None,
830
+ ):
831
+ """
832
+ Helper function that deletes local artifacts downloaded from Hugging Face to free up disk space.
833
+ Parameters
834
+ ----------
835
+ model_name (str): Name of the huggingface model
836
+ local_dir (str): Local directory where the object is downloaded
837
+
838
+ """
839
+ if local_dir and os.path.exists(local_dir):
840
+ model_dir = os.path.join(local_dir, model_name)
841
+ model_dir = (
842
+ os.path.dirname(model_dir)
843
+ if "/" in model_name or os.sep in model_name
844
+ else model_dir
845
+ )
846
+ shutil.rmtree(model_dir, ignore_errors=True)
847
+ if os.path.exists(model_dir):
848
+ logger.debug(
849
+ f"Could not delete local model artifact directory: {model_dir}"
850
+ )
851
+ else:
852
+ logger.debug(f"Deleted local model artifact directory: {model_dir}.")
853
+
854
+ hf_local_path = os.path.join(
855
+ HF_HUB_CACHE, repo_folder_name(repo_id=model_name, repo_type="model")
856
+ )
857
+ shutil.rmtree(hf_local_path, ignore_errors=True)
858
+
859
+ if os.path.exists(hf_local_path):
860
+ logger.debug(
861
+ f"Could not clear the local Hugging Face cache directory {hf_local_path} for the model {model_name}."
862
+ )
863
+ else:
864
+ logger.debug(
865
+ f"Cleared contents of local Hugging Face cache directory {hf_local_path} for the model {model_name}."
866
+ )
867
+
868
+
821
869
  def is_service_managed_container(container):
822
870
  return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
823
871
 
@@ -1159,3 +1207,15 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
1159
1207
 
1160
1208
  combined_cmd_var = cmd_var + overrides
1161
1209
  return combined_cmd_var
1210
+
1211
+
1212
+ def build_pydantic_error_message(ex: ValidationError):
1213
+ """Added to handle error messages from pydantic model validator.
1214
+ Combine both loc and msg for errors where loc (field) is present in error details, else only build error
1215
+ message using msg field."""
1216
+
1217
+ return {
1218
+ ".".join(map(str, e["loc"])): e["msg"]
1219
+ for e in ex.errors()
1220
+ if "loc" in e and e["loc"]
1221
+ } or "; ".join(e["msg"] for e in ex.errors())
ads/aqua/data.py CHANGED
@@ -1,9 +1,8 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
6
- from dataclasses import dataclass, field
5
+ from dataclasses import dataclass
7
6
 
8
7
  from ads.common.serializer import DataClassSerializable
9
8
 
@@ -13,19 +12,3 @@ class AquaResourceIdentifier(DataClassSerializable):
13
12
  id: str = ""
14
13
  name: str = ""
15
14
  url: str = ""
16
-
17
-
18
- @dataclass(repr=False)
19
- class AquaJobSummary(DataClassSerializable):
20
- """Represents an Aqua job summary."""
21
-
22
- id: str
23
- name: str
24
- console_url: str
25
- lifecycle_state: str
26
- lifecycle_details: str
27
- time_created: str
28
- tags: dict
29
- experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
30
- source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
31
- job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
  import base64
5
5
  import json
@@ -199,11 +199,11 @@ class AquaEvaluationApp(AquaApp):
199
199
  eval_inference_configuration = (
200
200
  container.spec.evaluation_configuration
201
201
  )
202
- except Exception:
202
+ except Exception as ex:
203
203
  logger.debug(
204
204
  f"Could not load inference config details for the evaluation source id: "
205
205
  f"{create_aqua_evaluation_details.evaluation_source_id}. Please check if the container"
206
- f" runtime has the correct SMC image information."
206
+ f" runtime has the correct SMC image information.\nError: {str(ex)}"
207
207
  )
208
208
  elif (
209
209
  DataScienceResource.MODEL
@@ -289,7 +289,7 @@ class AquaEvaluationApp(AquaApp):
289
289
  f"Invalid experiment name. Please provide an experiment with `{Tags.AQUA_EVALUATION}` in tags."
290
290
  )
291
291
  except Exception:
292
- logger.debug(
292
+ logger.info(
293
293
  f"Model version set {experiment_model_version_set_name} doesn't exist. "
294
294
  "Creating new model version set."
295
295
  )
@@ -711,21 +711,27 @@ class AquaEvaluationApp(AquaApp):
711
711
  try:
712
712
  log = utils.query_resource(log_id, return_all=False)
713
713
  log_name = log.display_name if log else ""
714
- except Exception:
714
+ except Exception as ex:
715
+ logger.debug(f"Failed to get associated log name. Error: {ex}")
715
716
  pass
716
717
 
717
718
  if loggroup_id:
718
719
  try:
719
720
  loggroup = utils.query_resource(loggroup_id, return_all=False)
720
721
  loggroup_name = loggroup.display_name if loggroup else ""
721
- except Exception:
722
+ except Exception as ex:
723
+ logger.debug(f"Failed to get associated loggroup name. Error: {ex}")
722
724
  pass
723
725
 
724
726
  try:
725
727
  introspection = json.loads(
726
728
  self._get_attribute_from_model_metadata(resource, "ArtifactTestResults")
727
729
  )
728
- except Exception:
730
+ except Exception as ex:
731
+ logger.debug(
732
+ f"There was an issue loading the model attribute as json object for evaluation {eval_id}. "
733
+ f"Setting introspection to empty.\n Error:{ex}"
734
+ )
729
735
  introspection = {}
730
736
 
731
737
  summary = AquaEvaluationDetail(
@@ -878,13 +884,13 @@ class AquaEvaluationApp(AquaApp):
878
884
  try:
879
885
  log_id = job_run_details.log_details.log_id
880
886
  except Exception as e:
881
- logger.debug(f"Failed to get associated log. {str(e)}")
887
+ logger.debug(f"Failed to get associated log.\nError: {str(e)}")
882
888
  log_id = ""
883
889
 
884
890
  try:
885
891
  loggroup_id = job_run_details.log_details.log_group_id
886
892
  except Exception as e:
887
- logger.debug(f"Failed to get associated log. {str(e)}")
893
+ logger.debug(f"Failed to get associated log.\nError: {str(e)}")
888
894
  loggroup_id = ""
889
895
 
890
896
  loggroup_url = get_log_links(region=self.region, log_group_id=loggroup_id)
@@ -958,7 +964,7 @@ class AquaEvaluationApp(AquaApp):
958
964
  )
959
965
  except Exception as e:
960
966
  logger.debug(
961
- "Failed to load `report.json` from evaluation artifact" f"{str(e)}"
967
+ f"Failed to load `report.json` from evaluation artifact.\nError: {str(e)}"
962
968
  )
963
969
  json_report = {}
964
970
 
@@ -1047,6 +1053,7 @@ class AquaEvaluationApp(AquaApp):
1047
1053
  return report
1048
1054
 
1049
1055
  with tempfile.TemporaryDirectory() as temp_dir:
1056
+ logger.info(f"Downloading evaluation artifact for {eval_id}.")
1050
1057
  DataScienceModel.from_id(eval_id).download_artifact(
1051
1058
  temp_dir,
1052
1059
  auth=self._auth,
@@ -1200,6 +1207,7 @@ class AquaEvaluationApp(AquaApp):
1200
1207
  def load_evaluation_config(self, container: Optional[str] = None) -> Dict:
1201
1208
  """Loads evaluation config."""
1202
1209
 
1210
+ logger.info("Loading evaluation container config.")
1203
1211
  # retrieve the evaluation config by container family name
1204
1212
  evaluation_config = get_evaluation_service_config(container)
1205
1213
 
@@ -1279,9 +1287,9 @@ class AquaEvaluationApp(AquaApp):
1279
1287
  raise AquaRuntimeError(
1280
1288
  f"Not supported source type: {resource_type}"
1281
1289
  )
1282
- except Exception:
1290
+ except Exception as ex:
1283
1291
  logger.debug(
1284
- f"Failed to retrieve source information for evaluation {evaluation.identifier}."
1292
+ f"Failed to retrieve source information for evaluation {evaluation.identifier}.\nError: {str(ex)}"
1285
1293
  )
1286
1294
  source_name = ""
1287
1295
 
@@ -1,10 +1,10 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
  import traceback
7
+ import uuid
8
8
  from abc import abstractmethod
9
9
  from http.client import responses
10
10
  from typing import List
@@ -34,7 +34,7 @@ class AquaWSMsgHandler:
34
34
  self.telemetry = TelemetryClient(
35
35
  bucket=AQUA_TELEMETRY_BUCKET, namespace=AQUA_TELEMETRY_BUCKET_NS
36
36
  )
37
- except:
37
+ except Exception:
38
38
  pass
39
39
 
40
40
  @staticmethod
@@ -66,16 +66,23 @@ class AquaWSMsgHandler:
66
66
  "message": message,
67
67
  "service_payload": service_payload,
68
68
  "reason": reason,
69
+ "request_id": str(uuid.uuid4()),
69
70
  }
70
71
  exc_info = kwargs.get("exc_info")
71
72
  if exc_info:
72
- logger.error("".join(traceback.format_exception(*exc_info)))
73
+ logger.error(
74
+ f"Error Request ID: {reply['request_id']}\n"
75
+ f"Error: {''.join(traceback.format_exception(*exc_info))}"
76
+ )
73
77
  e = exc_info[1]
74
78
  if isinstance(e, HTTPError):
75
79
  reply["message"] = e.log_message or message
76
80
  reply["reason"] = e.reason
77
- else:
78
- logger.warning(reply["message"])
81
+
82
+ logger.error(
83
+ f"Error Request ID: {reply['request_id']}\n"
84
+ f"Error: {reply['message']} {reply['reason']}"
85
+ )
79
86
  # telemetry may not be present if there is an error while initializing
80
87
  if hasattr(self, "telemetry"):
81
88
  aqua_api_details = kwargs.get("aqua_api_details", {})
@@ -83,7 +90,7 @@ class AquaWSMsgHandler:
83
90
  category="aqua/error",
84
91
  action=str(status_code),
85
92
  value=reason,
86
- **aqua_api_details
93
+ **aqua_api_details,
87
94
  )
88
95
  response = AquaWsError(
89
96
  status=status_code,
@@ -1,6 +1,5 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
6
5
 
@@ -35,7 +34,7 @@ class AquaAPIhandler(APIHandler):
35
34
  self.telemetry = TelemetryClient(
36
35
  bucket=AQUA_TELEMETRY_BUCKET, namespace=AQUA_TELEMETRY_BUCKET_NS
37
36
  )
38
- except:
37
+ except Exception:
39
38
  pass
40
39
 
41
40
  @staticmethod
@@ -82,19 +81,23 @@ class AquaAPIhandler(APIHandler):
82
81
  "message": message,
83
82
  "service_payload": service_payload,
84
83
  "reason": reason,
84
+ "request_id": str(uuid.uuid4()),
85
85
  }
86
86
  exc_info = kwargs.get("exc_info")
87
87
  if exc_info:
88
- logger.error("".join(traceback.format_exception(*exc_info)))
88
+ logger.error(
89
+ f"Error Request ID: {reply['request_id']}\n"
90
+ f"Error: {''.join(traceback.format_exception(*exc_info))}"
91
+ )
89
92
  e = exc_info[1]
90
93
  if isinstance(e, HTTPError):
91
94
  reply["message"] = e.log_message or message
92
95
  reply["reason"] = e.reason if e.reason else reply["reason"]
93
- reply["request_id"] = str(uuid.uuid4())
94
- else:
95
- reply["request_id"] = str(uuid.uuid4())
96
96
 
97
- logger.warning(reply["message"])
97
+ logger.error(
98
+ f"Error Request ID: {reply['request_id']}\n"
99
+ f"Error: {reply['message']} {reply['reason']}"
100
+ )
98
101
 
99
102
  # telemetry may not be present if there is an error while initializing
100
103
  if hasattr(self, "telemetry"):
@@ -103,7 +106,7 @@ class AquaAPIhandler(APIHandler):
103
106
  category="aqua/error",
104
107
  action=str(status_code),
105
108
  value=reason,
106
- **aqua_api_details
109
+ **aqua_api_details,
107
110
  )
108
111
 
109
112
  self.finish(json.dumps(reply))
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
 
5
5
 
@@ -10,9 +10,7 @@ from tornado.web import HTTPError
10
10
  from ads.aqua.common.decorator import handle_exceptions
11
11
  from ads.aqua.extension.base_handler import AquaAPIhandler
12
12
  from ads.aqua.extension.errors import Errors
13
- from ads.aqua.extension.utils import validate_function_parameters
14
13
  from ads.aqua.finetuning import AquaFineTuningApp
15
- from ads.aqua.finetuning.entities import CreateFineTuningDetails
16
14
 
17
15
 
18
16
  class AquaFineTuneHandler(AquaAPIhandler):
@@ -33,7 +31,7 @@ class AquaFineTuneHandler(AquaAPIhandler):
33
31
  raise HTTPError(400, f"The request {self.request.path} is invalid.")
34
32
 
35
33
  @handle_exceptions
36
- def post(self, *args, **kwargs):
34
+ def post(self, *args, **kwargs): # noqa: ARG002
37
35
  """Handles post request for the fine-tuning API
38
36
 
39
37
  Raises
@@ -43,17 +41,13 @@ class AquaFineTuneHandler(AquaAPIhandler):
43
41
  """
44
42
  try:
45
43
  input_data = self.get_json_body()
46
- except Exception:
47
- raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
44
+ except Exception as ex:
45
+ raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
48
46
 
49
47
  if not input_data:
50
48
  raise HTTPError(400, Errors.NO_INPUT_DATA)
51
49
 
52
- validate_function_parameters(
53
- data_class=CreateFineTuningDetails, input_data=input_data
54
- )
55
-
56
- self.finish(AquaFineTuningApp().create(CreateFineTuningDetails(**input_data)))
50
+ self.finish(AquaFineTuningApp().create(**input_data))
57
51
 
58
52
  def get_finetuning_config(self, model_id):
59
53
  """Gets the finetuning config for Aqua model."""
@@ -71,7 +65,7 @@ class AquaFineTuneParamsHandler(AquaAPIhandler):
71
65
  )
72
66
 
73
67
  @handle_exceptions
74
- def post(self, *args, **kwargs):
68
+ def post(self, *args, **kwargs): # noqa: ARG002
75
69
  """Handles post request for the finetuning param handler API.
76
70
 
77
71
  Raises
@@ -81,8 +75,8 @@ class AquaFineTuneParamsHandler(AquaAPIhandler):
81
75
  """
82
76
  try:
83
77
  input_data = self.get_json_body()
84
- except Exception:
85
- raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
78
+ except Exception as ex:
79
+ raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
86
80
 
87
81
  if not input_data:
88
82
  raise HTTPError(400, Errors.NO_INPUT_DATA)
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
 
5
5
  from typing import Optional
@@ -8,6 +8,9 @@ from urllib.parse import urlparse
8
8
  from tornado.web import HTTPError
9
9
 
10
10
  from ads.aqua.common.decorator import handle_exceptions
11
+ from ads.aqua.common.enums import (
12
+ CustomInferenceContainerTypeFamily,
13
+ )
11
14
  from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
12
15
  from ads.aqua.common.utils import (
13
16
  get_hf_model_info,
@@ -128,17 +131,27 @@ class AquaModelHandler(AquaAPIhandler):
128
131
  download_from_hf = (
129
132
  str(input_data.get("download_from_hf", "false")).lower() == "true"
130
133
  )
134
+ local_dir = input_data.get("local_dir")
135
+ cleanup_model_cache = (
136
+ str(input_data.get("cleanup_model_cache", "false")).lower() == "true"
137
+ )
131
138
  inference_container_uri = input_data.get("inference_container_uri")
132
139
  allow_patterns = input_data.get("allow_patterns")
133
140
  ignore_patterns = input_data.get("ignore_patterns")
134
141
  freeform_tags = input_data.get("freeform_tags")
135
142
  defined_tags = input_data.get("defined_tags")
143
+ ignore_model_artifact_check = (
144
+ str(input_data.get("ignore_model_artifact_check", "false")).lower()
145
+ == "true"
146
+ )
136
147
 
137
148
  return self.finish(
138
149
  AquaModelApp().register(
139
150
  model=model,
140
151
  os_path=os_path,
141
152
  download_from_hf=download_from_hf,
153
+ local_dir=local_dir,
154
+ cleanup_model_cache=cleanup_model_cache,
142
155
  inference_container=inference_container,
143
156
  finetuning_container=finetuning_container,
144
157
  compartment_id=compartment_id,
@@ -149,6 +162,7 @@ class AquaModelHandler(AquaAPIhandler):
149
162
  ignore_patterns=ignore_patterns,
150
163
  freeform_tags=freeform_tags,
151
164
  defined_tags=defined_tags,
165
+ ignore_model_artifact_check=ignore_model_artifact_check,
152
166
  )
153
167
  )
154
168
 
@@ -163,7 +177,9 @@ class AquaModelHandler(AquaAPIhandler):
163
177
  raise HTTPError(400, Errors.NO_INPUT_DATA)
164
178
 
165
179
  inference_container = input_data.get("inference_container")
180
+ inference_container_uri = input_data.get("inference_container_uri")
166
181
  inference_containers = AquaModelApp.list_valid_inference_containers()
182
+ inference_containers.extend(CustomInferenceContainerTypeFamily.values())
167
183
  if (
168
184
  inference_container is not None
169
185
  and inference_container not in inference_containers
@@ -176,7 +192,13 @@ class AquaModelHandler(AquaAPIhandler):
176
192
  task = input_data.get("task")
177
193
  app = AquaModelApp()
178
194
  self.finish(
179
- app.edit_registered_model(id, inference_container, enable_finetuning, task)
195
+ app.edit_registered_model(
196
+ id,
197
+ inference_container,
198
+ inference_container_uri,
199
+ enable_finetuning,
200
+ task,
201
+ )
180
202
  )
181
203
  app.clear_model_details_cache(model_id=id)
182
204
 
@@ -1,6 +1,5 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
6
5
  from ads.common.extended_enum import ExtendedEnumMeta
@@ -17,4 +16,8 @@ class FineTuneCustomMetadata(str, metaclass=ExtendedEnumMeta):
17
16
  SERVICE_MODEL_FINE_TUNE_CONTAINER = "finetune-container"
18
17
 
19
18
 
19
+ class FineTuningRestrictedParams(str, metaclass=ExtendedEnumMeta):
20
+ OPTIMIZER = "optimizer"
21
+
22
+
20
23
  ENV_AQUA_FINE_TUNING_CONTAINER = "AQUA_FINE_TUNING_CONTAINER"
@@ -1,18 +1,24 @@
1
1
  #!/usr/bin/env python
2
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
- from dataclasses import dataclass, field
5
- from typing import List, Optional
6
4
 
7
- from ads.aqua.data import AquaJobSummary
8
- from ads.common.serializer import DataClassSerializable
5
+ import json
6
+ from typing import List, Literal, Optional, Union
9
7
 
8
+ from pydantic import Field, model_validator
10
9
 
11
- @dataclass(repr=False)
12
- class AquaFineTuningParams(DataClassSerializable):
13
- epochs: int
10
+ from ads.aqua.common.errors import AquaValueError
11
+ from ads.aqua.config.utils.serializer import Serializable
12
+ from ads.aqua.data import AquaResourceIdentifier
13
+ from ads.aqua.finetuning.constants import FineTuningRestrictedParams
14
+
15
+
16
+ class AquaFineTuningParams(Serializable):
17
+ """Class for maintaining aqua fine-tuning model parameters"""
18
+
19
+ epochs: Optional[int] = None
14
20
  learning_rate: Optional[float] = None
15
- sample_packing: Optional[bool] = "auto"
21
+ sample_packing: Union[bool, None, Literal["auto"]] = "auto"
16
22
  batch_size: Optional[int] = (
17
23
  None # make it batch_size for user, but internally this is micro_batch_size
18
24
  )
@@ -22,21 +28,59 @@ class AquaFineTuningParams(DataClassSerializable):
22
28
  lora_alpha: Optional[int] = None
23
29
  lora_dropout: Optional[float] = None
24
30
  lora_target_linear: Optional[bool] = None
25
- lora_target_modules: Optional[List] = None
31
+ lora_target_modules: Optional[List[str]] = None
26
32
  early_stopping_patience: Optional[int] = None
27
33
  early_stopping_threshold: Optional[float] = None
28
34
 
35
+ class Config:
36
+ extra = "allow"
37
+
38
+ def to_dict(self) -> dict:
39
+ return json.loads(super().to_json(exclude_none=True))
40
+
41
+ @model_validator(mode="before")
42
+ @classmethod
43
+ def validate_restricted_fields(cls, data: dict):
44
+ # we may want to skip validation if loading data from config files instead of user entered parameters
45
+ validate = data.pop("_validate", True)
46
+ if not (validate and isinstance(data, dict)):
47
+ return data
48
+ restricted_params = [
49
+ param for param in data if param in FineTuningRestrictedParams.values()
50
+ ]
51
+ if restricted_params:
52
+ raise AquaValueError(
53
+ f"Found restricted parameter name: {restricted_params}"
54
+ )
55
+ return data
29
56
 
30
- @dataclass(repr=False)
31
- class AquaFineTuningSummary(AquaJobSummary, DataClassSerializable):
32
- parameters: AquaFineTuningParams = field(default_factory=AquaFineTuningParams)
33
57
 
58
+ class AquaFineTuningSummary(Serializable):
59
+ """Represents a summary of Aqua Finetuning job."""
34
60
 
35
- @dataclass(repr=False)
36
- class CreateFineTuningDetails(DataClassSerializable):
37
- """Dataclass to create aqua model fine tuning.
61
+ id: str
62
+ name: str
63
+ console_url: str
64
+ lifecycle_state: str
65
+ lifecycle_details: str
66
+ time_created: str
67
+ tags: dict
68
+ experiment: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
69
+ source: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
70
+ job: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
71
+ parameters: AquaFineTuningParams = Field(default_factory=AquaFineTuningParams)
38
72
 
39
- Fields
73
+ class Config:
74
+ extra = "ignore"
75
+
76
+ def to_dict(self) -> dict:
77
+ return json.loads(super().to_json(exclude_none=True))
78
+
79
+
80
+ class CreateFineTuningDetails(Serializable):
81
+ """Class to create aqua model fine-tuning instance.
82
+
83
+ Properties
40
84
  ------
41
85
  ft_source_id: str
42
86
  The fine tuning source id. Must be model ocid.
@@ -78,6 +122,8 @@ class CreateFineTuningDetails(DataClassSerializable):
78
122
  The log group id for fine tuning job infrastructure.
79
123
  log_id: (str, optional). Defaults to `None`.
80
124
  The log id for fine tuning job infrastructure.
125
+ watch_logs: (bool, optional). Defaults to `False`.
126
+ The flag to watch the job run logs when a fine-tuning job is created.
81
127
  force_overwrite: (bool, optional). Defaults to `False`.
82
128
  Whether to force overwrite the existing file in object storage.
83
129
  freeform_tags: (dict, optional)
@@ -104,6 +150,10 @@ class CreateFineTuningDetails(DataClassSerializable):
104
150
  subnet_id: Optional[str] = None
105
151
  log_id: Optional[str] = None
106
152
  log_group_id: Optional[str] = None
153
+ watch_logs: Optional[bool] = False
107
154
  force_overwrite: Optional[bool] = False
108
155
  freeform_tags: Optional[dict] = None
109
156
  defined_tags: Optional[dict] = None
157
+
158
+ class Config:
159
+ extra = "ignore"