oracle-ads 2.13.4__py3-none-any.whl → 2.13.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,19 +2,16 @@
2
2
  # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
 
5
-
6
5
  import json
7
- import traceback
8
- import uuid
9
6
  from dataclasses import asdict, is_dataclass
10
- from http.client import responses
11
7
  from typing import Any
12
8
 
13
9
  from notebook.base.handlers import APIHandler
14
10
  from tornado import httputil
15
- from tornado.web import Application, HTTPError
11
+ from tornado.web import Application
16
12
 
17
- from ads.aqua import logger
13
+ from ads.aqua.common.utils import is_pydantic_model
14
+ from ads.aqua.extension.utils import construct_error
18
15
  from ads.config import AQUA_TELEMETRY_BUCKET, AQUA_TELEMETRY_BUCKET_NS
19
16
  from ads.telemetry.client import TelemetryClient
20
17
 
@@ -40,7 +37,7 @@ class AquaAPIhandler(APIHandler):
40
37
  def prepare(self, *args, **kwargs):
41
38
  """The base class prepare is not required for Aqua"""
42
39
  pass
43
-
40
+
44
41
  @staticmethod
45
42
  def serialize(obj: Any):
46
43
  """Serialize the object.
@@ -52,6 +49,9 @@ class AquaAPIhandler(APIHandler):
52
49
  if is_dataclass(obj):
53
50
  return asdict(obj)
54
51
 
52
+ if is_pydantic_model(obj):
53
+ return obj.model_dump()
54
+
55
55
  return str(obj)
56
56
 
57
57
  def finish(self, payload=None): # pylint: disable=W0221
@@ -71,37 +71,11 @@ class AquaAPIhandler(APIHandler):
71
71
 
72
72
  def write_error(self, status_code, **kwargs):
73
73
  """AquaAPIhandler errors are JSON, not human pages."""
74
- self.set_header("Content-Type", "application/json")
75
- reason = kwargs.get("reason")
76
- self.set_status(status_code, reason=reason)
77
- service_payload = kwargs.get("service_payload", {})
78
- default_msg = responses.get(status_code, "Unknown HTTP Error")
79
- message = self.get_default_error_messages(
80
- service_payload, str(status_code), kwargs.get("message", default_msg)
81
- )
82
-
83
- reply = {
84
- "status": status_code,
85
- "message": message,
86
- "service_payload": service_payload,
87
- "reason": reason,
88
- "request_id": str(uuid.uuid4()),
89
- }
90
- exc_info = kwargs.get("exc_info")
91
- if exc_info:
92
- logger.error(
93
- f"Error Request ID: {reply['request_id']}\n"
94
- f"Error: {''.join(traceback.format_exception(*exc_info))}"
95
- )
96
- e = exc_info[1]
97
- if isinstance(e, HTTPError):
98
- reply["message"] = e.log_message or message
99
- reply["reason"] = e.reason if e.reason else reply["reason"]
100
74
 
101
- logger.error(
102
- f"Error Request ID: {reply['request_id']}\n"
103
- f"Error: {reply['message']} {reply['reason']}"
104
- )
75
+ reply_details = construct_error(status_code, **kwargs)
76
+
77
+ self.set_header("Content-Type", "application/json")
78
+ self.set_status(status_code, reason=reply_details.reason)
105
79
 
106
80
  # telemetry may not be present if there is an error while initializing
107
81
  if hasattr(self, "telemetry"):
@@ -109,40 +83,8 @@ class AquaAPIhandler(APIHandler):
109
83
  self.telemetry.record_event_async(
110
84
  category="aqua/error",
111
85
  action=str(status_code),
112
- value=reason,
86
+ value=reply_details.reason,
113
87
  **aqua_api_details,
114
88
  )
115
89
 
116
- self.finish(json.dumps(reply))
117
-
118
- @staticmethod
119
- def get_default_error_messages(
120
- service_payload: dict,
121
- status_code: str,
122
- default_msg: str = "Unknown HTTP Error.",
123
- ):
124
- """Method that maps the error messages based on the operation performed or the status codes encountered."""
125
-
126
- messages = {
127
- "400": "Something went wrong with your request.",
128
- "403": "We're having trouble processing your request with the information provided.",
129
- "404": "Authorization Failed: The resource you're looking for isn't accessible.",
130
- "408": "Server is taking too long to response, please try again.",
131
- "create": "Authorization Failed: Could not create resource.",
132
- "get": "Authorization Failed: The resource you're looking for isn't accessible.",
133
- }
134
-
135
- if service_payload and "operation_name" in service_payload:
136
- operation_name = service_payload["operation_name"]
137
- if operation_name:
138
- if operation_name.startswith("create"):
139
- return messages["create"] + f" Operation Name: {operation_name}."
140
- elif operation_name.startswith("list") or operation_name.startswith(
141
- "get"
142
- ):
143
- return messages["get"] + f" Operation Name: {operation_name}."
144
-
145
- if status_code in messages:
146
- return messages[status_code]
147
- else:
148
- return default_msg
90
+ self.finish(reply_details)
@@ -1,7 +1,8 @@
1
1
  #!/usr/bin/env python
2
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
 
5
+ from typing import List, Union
5
6
  from urllib.parse import urlparse
6
7
 
7
8
  from tornado.web import HTTPError
@@ -11,7 +12,7 @@ from ads.aqua.extension.base_handler import AquaAPIhandler
11
12
  from ads.aqua.extension.errors import Errors
12
13
  from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
13
14
  from ads.aqua.modeldeployment.entities import ModelParams
14
- from ads.config import COMPARTMENT_OCID, PROJECT_OCID
15
+ from ads.config import COMPARTMENT_OCID
15
16
 
16
17
 
17
18
  class AquaDeploymentHandler(AquaAPIhandler):
@@ -20,7 +21,7 @@ class AquaDeploymentHandler(AquaAPIhandler):
20
21
 
21
22
  Methods
22
23
  -------
23
- get(self, id="")
24
+ get(self, id: Union[str, List[str]])
24
25
  Retrieves a list of AQUA deployments or model info or logs by ID.
25
26
  post(self, *args, **kwargs)
26
27
  Creates a new AQUA deployment.
@@ -30,6 +31,8 @@ class AquaDeploymentHandler(AquaAPIhandler):
30
31
  Lists all the AQUA deployments.
31
32
  get_deployment_config(self, model_id)
32
33
  Gets the deployment config for Aqua model.
34
+ list_shapes(self)
35
+ Lists the valid model deployment shapes.
33
36
 
34
37
  Raises
35
38
  ------
@@ -37,16 +40,23 @@ class AquaDeploymentHandler(AquaAPIhandler):
37
40
  """
38
41
 
39
42
  @handle_exceptions
40
- def get(self, id=""):
43
+ def get(self, id: Union[str, List[str]] = None):
41
44
  """Handle GET request."""
42
45
  url_parse = urlparse(self.request.path)
43
46
  paths = url_parse.path.strip("/")
44
47
  if paths.startswith("aqua/deployments/config"):
45
- if not id:
48
+ if not id or not isinstance(id, str):
46
49
  raise HTTPError(
47
- 400, f"The request {self.request.path} requires model id."
50
+ 400,
51
+ f"Invalid request format for {self.request.path}. "
52
+ "Expected a single model ID or a comma-separated list of model IDs.",
48
53
  )
49
- return self.get_deployment_config(id)
54
+ id = id.replace(" ", "")
55
+ return self.get_deployment_config(
56
+ model_id=id.split(",") if "," in id else id
57
+ )
58
+ elif paths.startswith("aqua/deployments/shapes"):
59
+ return self.list_shapes()
50
60
  elif paths.startswith("aqua/deployments"):
51
61
  if not id:
52
62
  return self.list()
@@ -98,71 +108,7 @@ class AquaDeploymentHandler(AquaAPIhandler):
98
108
  if not input_data:
99
109
  raise HTTPError(400, Errors.NO_INPUT_DATA)
100
110
 
101
- # required input parameters
102
- display_name = input_data.get("display_name")
103
- if not display_name:
104
- raise HTTPError(
105
- 400, Errors.MISSING_REQUIRED_PARAMETER.format("display_name")
106
- )
107
- instance_shape = input_data.get("instance_shape")
108
- if not instance_shape:
109
- raise HTTPError(
110
- 400, Errors.MISSING_REQUIRED_PARAMETER.format("instance_shape")
111
- )
112
- model_id = input_data.get("model_id")
113
- if not model_id:
114
- raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model_id"))
115
-
116
- compartment_id = input_data.get("compartment_id", COMPARTMENT_OCID)
117
- project_id = input_data.get("project_id", PROJECT_OCID)
118
- log_group_id = input_data.get("log_group_id")
119
- access_log_id = input_data.get("access_log_id")
120
- predict_log_id = input_data.get("predict_log_id")
121
- description = input_data.get("description")
122
- instance_count = input_data.get("instance_count")
123
- bandwidth_mbps = input_data.get("bandwidth_mbps")
124
- web_concurrency = input_data.get("web_concurrency")
125
- server_port = input_data.get("server_port")
126
- health_check_port = input_data.get("health_check_port")
127
- env_var = input_data.get("env_var")
128
- container_family = input_data.get("container_family")
129
- ocpus = input_data.get("ocpus")
130
- memory_in_gbs = input_data.get("memory_in_gbs")
131
- model_file = input_data.get("model_file")
132
- private_endpoint_id = input_data.get("private_endpoint_id")
133
- container_image_uri = input_data.get("container_image_uri")
134
- cmd_var = input_data.get("cmd_var")
135
- freeform_tags = input_data.get("freeform_tags")
136
- defined_tags = input_data.get("defined_tags")
137
-
138
- self.finish(
139
- AquaDeploymentApp().create(
140
- compartment_id=compartment_id,
141
- project_id=project_id,
142
- model_id=model_id,
143
- display_name=display_name,
144
- description=description,
145
- instance_count=instance_count,
146
- instance_shape=instance_shape,
147
- log_group_id=log_group_id,
148
- access_log_id=access_log_id,
149
- predict_log_id=predict_log_id,
150
- bandwidth_mbps=bandwidth_mbps,
151
- web_concurrency=web_concurrency,
152
- server_port=server_port,
153
- health_check_port=health_check_port,
154
- env_var=env_var,
155
- container_family=container_family,
156
- ocpus=ocpus,
157
- memory_in_gbs=memory_in_gbs,
158
- model_file=model_file,
159
- private_endpoint_id=private_endpoint_id,
160
- container_image_uri=container_image_uri,
161
- cmd_var=cmd_var,
162
- freeform_tags=freeform_tags,
163
- defined_tags=defined_tags,
164
- )
165
- )
111
+ self.finish(AquaDeploymentApp().create(**input_data))
166
112
 
167
113
  def read(self, id):
168
114
  """Read the information of an Aqua model deployment."""
@@ -181,9 +127,52 @@ class AquaDeploymentHandler(AquaAPIhandler):
181
127
  )
182
128
  )
183
129
 
184
- def get_deployment_config(self, model_id):
185
- """Gets the deployment config for Aqua model."""
186
- return self.finish(AquaDeploymentApp().get_deployment_config(model_id=model_id))
130
def get_deployment_config(self, model_id: Union[str, List[str]]):
    """
    Send the deployment configuration for one or several Aqua models.

    Parameters
    ----------
    model_id : Union[str, List[str]]
        Either one model ID, or a list of model IDs for a multi-model
        deployment.

    Returns
    -------
    The value returned by ``self.finish``; the configuration itself is
    written to the response.
    """
    deployment_app = AquaDeploymentApp()

    # Read up front in both cases, although only the multi-model lookup uses it.
    compartment = self.get_argument("compartment_id", default=COMPARTMENT_OCID)

    if not isinstance(model_id, list):
        # Single model deployment.
        return self.finish(deployment_app.get_deployment_config(model_id=model_id))

    # Multiple model deployment.
    primary = self.get_argument("primary_model_id", default=None)
    multi_config = deployment_app.get_multimodel_deployment_config(
        model_ids=model_id,
        primary_model_id=primary,
        compartment_id=compartment,
    )
    return self.finish(multi_config)
161
+
162
def list_shapes(self):
    """
    Send the list of valid model deployment shapes.

    The compartment is taken from the ``compartment_id`` request argument and
    falls back to ``COMPARTMENT_OCID`` when the argument is absent.

    Returns
    -------
    The value returned by ``self.finish``; the shapes themselves are written
    to the response.
    """
    compartment = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
    shapes = AquaDeploymentApp().list_shapes(compartment_id=compartment)
    return self.finish(shapes)
187
176
 
188
177
 
189
178
  class AquaDeploymentInferenceHandler(AquaAPIhandler):
@@ -259,9 +248,10 @@ class AquaDeploymentParamsHandler(AquaAPIhandler):
259
248
  def get(self, model_id):
260
249
  """Handle GET request."""
261
250
  instance_shape = self.get_argument("instance_shape")
251
+ gpu_count = self.get_argument("gpu_count", default=None)
262
252
  return self.finish(
263
253
  AquaDeploymentApp().get_deployment_default_params(
264
- model_id=model_id, instance_shape=instance_shape
254
+ model_id=model_id, instance_shape=instance_shape, gpu_count=gpu_count
265
255
  )
266
256
  )
267
257
 
@@ -300,6 +290,7 @@ class AquaDeploymentParamsHandler(AquaAPIhandler):
300
290
  __handlers__ = [
301
291
  ("deployments/?([^/]*)/params", AquaDeploymentParamsHandler),
302
292
  ("deployments/config/?([^/]*)", AquaDeploymentHandler),
293
+ ("deployments/shapes/?([^/]*)", AquaDeploymentHandler),
303
294
  ("deployments/?([^/]*)", AquaDeploymentHandler),
304
295
  ("deployments/?([^/]*)/activate", AquaDeploymentHandler),
305
296
  ("deployments/?([^/]*)/deactivate", AquaDeploymentHandler),
@@ -1,7 +1,16 @@
1
1
  #!/usr/bin/env python
2
2
  # Copyright (c) 2024 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
+ import uuid
5
+ from typing import Any, Dict, List, Optional
4
6
 
7
+ from pydantic import Field
8
+
9
+ from ads.aqua.config.utils.serializer import Serializable
10
+
11
+ from ads.aqua.constants import (
12
+ AQUA_TROUBLESHOOTING_LINK
13
+ )
5
14
 
6
15
  class Errors(str):
7
16
  INVALID_INPUT_DATA_FORMAT = "Invalid format of input data."
@@ -9,3 +18,13 @@ class Errors(str):
9
18
  MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'"
10
19
  MISSING_ONEOF_REQUIRED_PARAMETER = "Either '{}' or '{}' is required."
11
20
  INVALID_VALUE_OF_PARAMETER = "Invalid value of parameter: '{}'"
21
+
22
class ReplyDetails(Serializable):
    """Structured reply to be returned to the client when a request fails."""

    # HTTP status code of the error (required).
    status: int
    troubleshooting_tips: str = Field(
        f"For general tips on troubleshooting: {AQUA_TROUBLESHOOTING_LINK}",
        description="GitHub Link for troubleshooting documentation",
    )
    message: str = Field(
        "Unknown HTTP Error.",
        description="Error message returned to the client.",
    )
    service_payload: Optional[Dict[str, Any]] = Field(default_factory=dict)
    reason: str = Field("Unknown error", description="Reason for Error")
    # default_factory generates a fresh UUID per instance; a plain default of
    # str(uuid.uuid4()) would be evaluated once when the class is defined and
    # then shared by every reply that does not pass request_id explicitly.
    request_id: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Unique ID for tracking the error.",
    )
@@ -1,16 +1,26 @@
1
1
  #!/usr/bin/env python
2
2
  # Copyright (c) 2024 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
+
5
+ import re
6
+ import traceback
7
+ import uuid
4
8
  from dataclasses import fields
5
9
  from datetime import datetime, timedelta
10
+ from http.client import responses
6
11
  from typing import Dict, Optional
7
12
 
8
13
  from cachetools import TTLCache, cached
9
14
  from tornado.web import HTTPError
10
15
 
11
- from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
16
+ from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
12
17
  from ads.aqua.common.utils import fetch_service_compartment
13
- from ads.aqua.extension.errors import Errors
18
+ from ads.aqua.constants import (
19
+ AQUA_TROUBLESHOOTING_LINK,
20
+ OCI_OPERATION_FAILURES,
21
+ STATUS_CODE_MESSAGES,
22
+ )
23
+ from ads.aqua.extension.errors import Errors, ReplyDetails
14
24
 
15
25
 
16
26
  def validate_function_parameters(data_class, input_data: Dict):
@@ -32,3 +42,105 @@ def ui_compatability_check():
32
42
  fetched from the configuration. The cached result is returned when multiple calls are made in quick succession
33
43
  from the UI to avoid multiple config file loads."""
34
44
  return ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment()
45
+
46
+
47
def get_default_error_messages(
    service_payload: dict,
    status_code: str,
    default_msg: str = "Unknown HTTP Error.",
) -> str:
    """Pick the error message for a failed request.

    When the payload names an operation and ``status_code`` has an entry in
    ``STATUS_CODE_MESSAGES``, the result combines that entry, the payload's
    ``message`` value and the operation name. Otherwise the result is the
    ``STATUS_CODE_MESSAGES`` entry for ``status_code``, or ``default_msg`` when
    there is none.
    """
    names_operation = bool(service_payload) and "operation_name" in service_payload
    operation = service_payload.get("operation_name") if names_operation else None

    if operation and status_code in STATUS_CODE_MESSAGES:
        return f"{STATUS_CODE_MESSAGES[status_code]}\n{service_payload.get('message')}\nOperation Name: {operation}."

    return STATUS_CODE_MESSAGES.get(status_code, default_msg)
61
+
62
+
63
def get_documentation_link(key: str) -> str:
    """Build the troubleshooting-page link whose anchor corresponds to ``key``.

    Underscores in ``key`` are turned into hyphens and the result is appended
    to ``AQUA_TROUBLESHOOTING_LINK`` as a ``#`` fragment.
    """
    anchor = key.replace("_", "-")
    return f"{AQUA_TROUBLESHOOTING_LINK}#{anchor}"
67
+
68
+
69
def get_troubleshooting_tips(service_payload: dict, status_code: str) -> str:
    """Return a troubleshooting tip for a failed request.

    Parameters
    ----------
    service_payload : dict
        Error context; only its ``operation_name`` entry is read. ``None`` or
        an empty dict is treated as "no operation".
    status_code : str
        HTTP status of the error. Both integer and string forms (``404`` and
        ``"404"``) are accepted.

    Returns
    -------
    str
        For a 400/404 status whose operation is a key of
        ``OCI_OPERATION_FAILURES``: that entry followed by the documentation
        link for the operation. In every other case: the general
        troubleshooting tip.
    """
    tip = f"For general tips on troubleshooting: {AQUA_TROUBLESHOOTING_LINK}"

    # Normalise the status so that string and integer statuses behave alike;
    # comparing a string directly against (404, 400) would never match.
    try:
        code = int(status_code)
    except (TypeError, ValueError):
        return tip

    if code in (404, 400):
        # Guard against a missing payload instead of failing on None.get(...).
        failed_operation = (service_payload or {}).get("operation_name")

        if failed_operation in OCI_OPERATION_FAILURES:
            link = get_documentation_link(failed_operation)
            tip = OCI_OPERATION_FAILURES[failed_operation] + link

    return tip
83
+
84
+
85
def construct_error(status_code: int, **kwargs) -> ReplyDetails:
    """
    Formats an error response based on the provided status code and optional details.

    Args:
        status_code (int): The HTTP status code of the error.
        **kwargs: Additional optional parameters:
            - reason (str, optional): A brief reason for the error. Falls back to
              "Unknown Error" when missing, None or empty.
            - service_payload (dict, optional): Contextual error data from OCI SDK methods.
              Falls back to an empty dict when missing or None.
            - message (str, optional): A custom error message. Falls back to the standard
              phrase for ``status_code`` when missing, None or empty.
            - exc_info (tuple, optional): Exception information (e.g., from `sys.exc_info()`),
              used for logging.

    Returns:
        ReplyDetails: the formatted error response, populated with:
            - status (int): The HTTP status code.
            - troubleshooting_tips (str): result of ``get_troubleshooting_tips``.
            - message (str): error message.
            - service_payload (dict): the payload described above.
            - reason (str): The reason for the error.
            - request_id (str): A freshly generated UUID for tracking the error.

    Logs:
        - Logs the error details with the request ID.
        - If `exc_info` is provided, also logs the formatted traceback; when the exception
          is an `HTTPError`, the reply's message and reason are taken from it.
    """
    # `or` is used instead of a .get() default so that keys passed explicitly
    # as None also fall back; ReplyDetails declares `reason` and `message` as
    # str, and building the reply must not fail while reporting an error.
    reason = kwargs.get("reason") or "Unknown Error"
    service_payload = kwargs.get("service_payload") or {}
    default_msg = responses.get(status_code, "Unknown HTTP Error")
    message = get_default_error_messages(
        service_payload, str(status_code), kwargs.get("message") or default_msg
    )

    tips = get_troubleshooting_tips(service_payload, status_code)

    reply = ReplyDetails(
        status=status_code,
        troubleshooting_tips=tips,
        message=message,
        service_payload=service_payload,
        reason=reason,
        request_id=str(uuid.uuid4()),
    )

    exc_info = kwargs.get("exc_info")
    if exc_info:
        logger.error(
            f"Error Request ID: {reply.request_id}\n"
            f"Error: {''.join(traceback.format_exception(*exc_info))}"
        )
        e = exc_info[1]
        if isinstance(e, HTTPError):
            # Prefer the details carried by the HTTPError over the generic ones.
            reply.message = e.log_message or message
            reply.reason = e.reason if e.reason else reply.reason

    logger.error(
        f"Error Request ID: {reply.request_id}\n"
        f"Error: {reply.message} {reply.reason}"
    )
    return reply
@@ -87,13 +87,62 @@ class AquaFineTuningApp(AquaApp):
87
87
  def create(
88
88
  self, create_fine_tuning_details: CreateFineTuningDetails = None, **kwargs
89
89
  ) -> "AquaFineTuningSummary":
90
- """Creates Aqua fine tuning for model.
90
+ """Creates Aqua fine tuning for model.\n
91
+ For detailed information about CLI flags see: https://github.com/oracle-samples/oci-data-science-ai-samples/blob/f271ca63d12e3c256718f23a14d93da4b4fc086b/ai-quick-actions/cli-tips.md#create-fine-tuned-model
91
92
 
92
93
  Parameters
93
94
  ----------
94
95
  create_fine_tuning_details: CreateFineTuningDetails
95
96
  The CreateFineTuningDetails data class which contains all
96
97
  required and optional fields to create the aqua fine tuning.
98
+ kwargs:
99
+ ft_source_id: str The fine tuning source id. Must be model OCID.
100
+ ft_name: str
101
+ The name for fine tuning.
102
+ dataset_path: str
103
+ The dataset path for fine tuning. Could be either a local path from notebook session
104
+ or an object storage path.
105
+ report_path: str
106
+ The report path for fine tuning. Must be an object storage path.
107
+ ft_parameters: dict
108
+ The parameters for fine tuning.
109
+ shape_name: str
110
+ The shape name for fine tuning job infrastructure.
111
+ replica: int
112
+ The replica for fine tuning job runtime.
113
+ validation_set_size: float
114
+ The validation set size for fine tuning job. Must be a float in between [0,1).
115
+ ft_description: (str, optional). Defaults to `None`.
116
+ The description for fine tuning.
117
+ compartment_id: (str, optional). Defaults to `None`.
118
+ The compartment id for fine tuning.
119
+ project_id: (str, optional). Defaults to `None`.
120
+ The project id for fine tuning.
121
+ experiment_id: (str, optional). Defaults to `None`.
122
+ The fine tuning model version set id. If provided,
123
+ fine tuning model will be associated with it.
124
+ experiment_name: (str, optional). Defaults to `None`.
125
+ The fine tuning model version set name. If provided,
126
+ the fine tuning version set with the same name will be used if exists,
127
+ otherwise a new model version set will be created with the name.
128
+ experiment_description: (str, optional). Defaults to `None`.
129
+ The description for fine tuning model version set.
130
+ block_storage_size: (int, optional). Defaults to 256.
131
+ The storage for fine tuning job infrastructure.
132
+ subnet_id: (str, optional). Defaults to `None`.
133
+ The custom egress for fine tuning job.
134
+ log_group_id: (str, optional). Defaults to `None`.
135
+ The log group id for fine tuning job infrastructure.
136
+ log_id: (str, optional). Defaults to `None`.
137
+ The log id for fine tuning job infrastructure.
138
+ watch_logs: (bool, optional). Defaults to `False`.
139
+ The flag to watch the job run logs when a fine-tuning job is created.
140
+ force_overwrite: (bool, optional). Defaults to `False`.
141
+ Whether to force overwrite the existing file in object storage.
142
+ freeform_tags: (dict, optional)
143
+ Freeform tags for the fine-tuning model
144
+ defined_tags: (dict, optional)
145
+ Defined tags for the fine-tuning model
97
146
  kwargs:
98
147
  The kwargs for creating CreateFineTuningDetails instance if
99
148
  no create_fine_tuning_details provided.
@@ -18,6 +18,8 @@ class ModelCustomMetadataFields(ExtendedEnum):
18
18
  EVALUATION_CONTAINER = "evaluation-container"
19
19
  FINETUNE_CONTAINER = "finetune-container"
20
20
  DEPLOYMENT_CONTAINER_URI = "deployment-container-uri"
21
+ MULTIMODEL_GROUP_COUNT = "model_group_count"
22
+ MULTIMODEL_METADATA = "multi_model_metadata"
21
23
 
22
24
 
23
25
  class ModelTask(ExtendedEnum):
@@ -34,6 +36,7 @@ class FineTuningMetricCategories(ExtendedEnum):
34
36
  class ModelType(ExtendedEnum):
35
37
  FT = "FT" # Fine Tuned Model
36
38
  BASE = "BASE" # Base model
39
+ MULTIMODEL = "MULTIMODEL"
37
40
 
38
41
 
39
42
  # TODO: merge metadata key used in create FT
ads/aqua/model/enums.py CHANGED
@@ -23,3 +23,8 @@ class FineTuningCustomMetadata(ExtendedEnum):
23
23
  VALIDATION_METRICS_FINAL = "val_metrics_final"
24
24
  TRAINING_METRICS_EPOCH = "train_metrics_epoch"
25
25
  VALIDATION_METRICS_EPOCH = "val_metrics_epoch"
26
+
27
+
28
class MultiModelSupportedTaskType(ExtendedEnum):
    # Two spellings of the text generation task value: hyphenated and underscored.
    TEXT_GENERATION = "text-generation"
    TEXT_GENERATION_ALT = "text_generation"