oracle-ads 2.13.11__py3-none-any.whl → 2.13.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. ads/aqua/app.py +73 -15
  2. ads/aqua/cli.py +17 -0
  3. ads/aqua/client/client.py +38 -21
  4. ads/aqua/client/openai_client.py +20 -10
  5. ads/aqua/common/entities.py +78 -12
  6. ads/aqua/common/utils.py +35 -0
  7. ads/aqua/constants.py +2 -0
  8. ads/aqua/evaluation/evaluation.py +5 -4
  9. ads/aqua/extension/common_handler.py +47 -2
  10. ads/aqua/extension/model_handler.py +51 -9
  11. ads/aqua/model/constants.py +1 -0
  12. ads/aqua/model/enums.py +19 -1
  13. ads/aqua/model/model.py +119 -51
  14. ads/aqua/model/utils.py +1 -2
  15. ads/aqua/modeldeployment/config_loader.py +815 -0
  16. ads/aqua/modeldeployment/constants.py +4 -1
  17. ads/aqua/modeldeployment/deployment.py +178 -129
  18. ads/aqua/modeldeployment/entities.py +150 -178
  19. ads/aqua/modeldeployment/model_group_config.py +233 -0
  20. ads/aqua/modeldeployment/utils.py +0 -539
  21. ads/aqua/verify_policies/__init__.py +8 -0
  22. ads/aqua/verify_policies/constants.py +13 -0
  23. ads/aqua/verify_policies/entities.py +29 -0
  24. ads/aqua/verify_policies/messages.py +101 -0
  25. ads/aqua/verify_policies/utils.py +432 -0
  26. ads/aqua/verify_policies/verify.py +345 -0
  27. ads/aqua/version.json +3 -0
  28. ads/common/oci_logging.py +4 -7
  29. ads/common/work_request.py +39 -38
  30. ads/jobs/builders/infrastructure/dsc_job.py +121 -24
  31. ads/jobs/builders/infrastructure/dsc_job_runtime.py +71 -24
  32. ads/jobs/builders/runtimes/base.py +7 -5
  33. ads/jobs/builders/runtimes/pytorch_runtime.py +6 -8
  34. ads/jobs/templates/driver_pytorch.py +486 -172
  35. ads/jobs/templates/driver_utils.py +27 -11
  36. ads/model/deployment/model_deployment.py +51 -38
  37. ads/model/service/oci_datascience_model_deployment.py +6 -11
  38. ads/telemetry/client.py +4 -4
  39. {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/METADATA +2 -1
  40. {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/RECORD +43 -34
  41. {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/WHEEL +0 -0
  42. {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/entry_points.txt +0 -0
  43. {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/licenses/LICENSE.txt +0 -0
ads/aqua/app.py CHANGED
@@ -5,6 +5,7 @@
5
5
  import json
6
6
  import os
7
7
  import traceback
8
+ from concurrent.futures import ThreadPoolExecutor
8
9
  from dataclasses import fields
9
10
  from datetime import datetime, timedelta
10
11
  from itertools import chain
@@ -22,7 +23,7 @@ from ads import set_auth
22
23
  from ads.aqua import logger
23
24
  from ads.aqua.common.entities import ModelConfigResult
24
25
  from ads.aqua.common.enums import ConfigFolder, Tags
25
- from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
26
+ from ads.aqua.common.errors import AquaValueError
26
27
  from ads.aqua.common.utils import (
27
28
  _is_valid_mvs,
28
29
  get_artifact_path,
@@ -58,6 +59,8 @@ from ads.telemetry.client import TelemetryClient
58
59
  class AquaApp:
59
60
  """Base Aqua App to contain common components."""
60
61
 
62
+ MAX_WORKERS = 10 # Number of workers for asynchronous resource loading
63
+
61
64
  @telemetry(name="aqua")
62
65
  def __init__(self) -> None:
63
66
  if OCI_RESOURCE_PRINCIPAL_VERSION:
@@ -128,20 +131,69 @@ class AquaApp:
128
131
  update_model_provenance_details=update_model_provenance_details,
129
132
  )
130
133
 
131
- # TODO: refactor model evaluation implementation to use it.
132
134
  @staticmethod
133
135
  def get_source(source_id: str) -> Union[ModelDeployment, DataScienceModel]:
134
- if is_valid_ocid(source_id):
135
- if "datasciencemodeldeployment" in source_id:
136
- return ModelDeployment.from_id(source_id)
137
- elif "datasciencemodel" in source_id:
138
- return DataScienceModel.from_id(source_id)
136
+ """
137
+ Fetches a model or model deployment based on the provided OCID.
138
+
139
+ Parameters
140
+ ----------
141
+ source_id : str
142
+ OCID of the Data Science model or model deployment.
143
+
144
+ Returns
145
+ -------
146
+ Union[ModelDeployment, DataScienceModel]
147
+ The corresponding resource object.
139
148
 
149
+ Raises
150
+ ------
151
+ AquaValueError
152
+ If the OCID is invalid or unsupported.
153
+ """
154
+ logger.debug(f"Resolving source for ID: {source_id}")
155
+ if not is_valid_ocid(source_id):
156
+ logger.error(f"Invalid OCID format: {source_id}")
157
+ raise AquaValueError(
158
+ f"Invalid source ID: {source_id}. Please provide a valid model or model deployment OCID."
159
+ )
160
+
161
+ if "datasciencemodeldeployment" in source_id:
162
+ logger.debug(f"Identified as ModelDeployment OCID: {source_id}")
163
+ return ModelDeployment.from_id(source_id)
164
+
165
+ if "datasciencemodel" in source_id:
166
+ logger.debug(f"Identified as DataScienceModel OCID: {source_id}")
167
+ return DataScienceModel.from_id(source_id)
168
+
169
+ logger.error(f"Unrecognized OCID type: {source_id}")
140
170
  raise AquaValueError(
141
- f"Invalid source {source_id}. "
142
- "Specify either a model or model deployment id."
171
+ f"Unsupported source ID type: {source_id}. Must be a model or model deployment OCID."
143
172
  )
144
173
 
174
+ def get_multi_source(
175
+ self,
176
+ ids: List[str],
177
+ ) -> Dict[str, Union[ModelDeployment, DataScienceModel]]:
178
+ """
179
+ Retrieves multiple DataScience resources concurrently.
180
+
181
+ Parameters
182
+ ----------
183
+ ids : List[str]
184
+ A list of DataScience OCIDs.
185
+
186
+ Returns
187
+ -------
188
+ Dict[str, Union[ModelDeployment, DataScienceModel]]
189
+ A mapping from OCID to the corresponding resolved resource object.
190
+ """
191
+ logger.debug(f"Fetching {ids} sources in parallel.")
192
+ with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
193
+ results = list(executor.map(self.get_source, ids))
194
+
195
+ return dict(zip(ids, results))
196
+
145
197
  # TODO: refactor model evaluation implementation to use it.
146
198
  @staticmethod
147
199
  def create_model_version_set(
@@ -284,8 +336,11 @@ class AquaApp:
284
336
  logger.info(f"Artifact not found in model {model_id}.")
285
337
  return False
286
338
 
339
+ @cached(cache=TTLCache(maxsize=5, ttl=timedelta(minutes=1), timer=datetime.now))
287
340
  def get_config_from_metadata(
288
- self, model_id: str, metadata_key: str
341
+ self,
342
+ model_id: str,
343
+ metadata_key: str,
289
344
  ) -> ModelConfigResult:
290
345
  """Gets the config for the given Aqua model from model catalog metadata content.
291
346
 
@@ -300,8 +355,9 @@ class AquaApp:
300
355
  ModelConfigResult
301
356
  A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
302
357
  """
303
- config = {}
358
+ config: Dict[str, Any] = {}
304
359
  oci_model = self.ds_client.get_model(model_id).data
360
+
305
361
  try:
306
362
  config = self.ds_client.get_model_defined_metadatum_artifact_content(
307
363
  model_id, metadata_key
@@ -321,7 +377,7 @@ class AquaApp:
321
377
  )
322
378
  return ModelConfigResult(config=config, model_details=oci_model)
323
379
 
324
- @cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
380
+ @cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=5), timer=datetime.now))
325
381
  def get_config(
326
382
  self,
327
383
  model_id: str,
@@ -346,8 +402,10 @@ class AquaApp:
346
402
  ModelConfigResult
347
403
  A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
348
404
  """
349
- config_folder = config_folder or ConfigFolder.CONFIG
405
+ config: Dict[str, Any] = {}
350
406
  oci_model = self.ds_client.get_model(model_id).data
407
+
408
+ config_folder = config_folder or ConfigFolder.CONFIG
351
409
  oci_aqua = (
352
410
  (
353
411
  Tags.AQUA_TAG in oci_model.freeform_tags
@@ -357,9 +415,9 @@ class AquaApp:
357
415
  else False
358
416
  )
359
417
  if not oci_aqua:
360
- raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.")
418
+ logger.debug(f"Target model {oci_model.id} is not an Aqua model.")
419
+ return ModelConfigResult(config=config, model_details=oci_model)
361
420
 
362
- config: Dict[str, Any] = {}
363
421
  artifact_path = get_artifact_path(oci_model.custom_metadata_list)
364
422
  if not artifact_path:
365
423
  logger.debug(
ads/aqua/cli.py CHANGED
@@ -14,6 +14,7 @@ from ads.aqua.evaluation import AquaEvaluationApp
14
14
  from ads.aqua.finetuning import AquaFineTuningApp
15
15
  from ads.aqua.model import AquaModelApp
16
16
  from ads.aqua.modeldeployment import AquaDeploymentApp
17
+ from ads.aqua.verify_policies import AquaVerifyPoliciesApp
17
18
  from ads.common.utils import LOG_LEVELS
18
19
 
19
20
 
@@ -29,6 +30,7 @@ class AquaCommand:
29
30
  fine_tuning = AquaFineTuningApp
30
31
  deployment = AquaDeploymentApp
31
32
  evaluation = AquaEvaluationApp
33
+ verify_policies = AquaVerifyPoliciesApp
32
34
 
33
35
  def __init__(
34
36
  self,
@@ -94,3 +96,18 @@ class AquaCommand:
94
96
  "If you intend to chain a function call to the result, please separate the "
95
97
  "flag and the subsequent function call with separator `-`."
96
98
  )
99
+
100
+ @staticmethod
101
+ def install():
102
+ """Install ADS Aqua Extension from wheel file. Set enviroment variable `AQUA_EXTENSTION_PATH` to change the wheel file path.
103
+
104
+ Return
105
+ ------
106
+ int:
107
+ Installatation status.
108
+ """
109
+ import subprocess
110
+
111
+ wheel_file_path = os.environ.get("AQUA_EXTENSTION_PATH", "/ads/extension/adsjupyterlab_aqua_extension*.whl")
112
+ status = subprocess.run(f"pip install {wheel_file_path}",shell=True)
113
+ return status.check_returncode
ads/aqua/client/client.py CHANGED
@@ -61,13 +61,20 @@ class HttpxOCIAuth(httpx.Auth):
61
61
 
62
62
  def __init__(self, signer: Optional[oci.signer.Signer] = None):
63
63
  """
64
- Initialize the HttpxOCIAuth instance.
64
+ Initializes the authentication handler with the given or default OCI signer.
65
65
 
66
- Args:
67
- signer (oci.signer.Signer): The OCI signer to use for signing requests.
66
+ Parameters
67
+ ----------
68
+ signer : oci.signer.Signer, optional
69
+ The OCI signer instance to use. If None, a default signer will be retrieved.
68
70
  """
69
-
70
- self.signer = signer or authutil.default_signer().get("signer")
71
+ try:
72
+ self.signer = signer or authutil.default_signer().get("signer")
73
+ if not self.signer:
74
+ raise ValueError("OCI signer could not be initialized.")
75
+ except Exception as e:
76
+ logger.error("Failed to initialize OCI signer: %s", e)
77
+ raise
71
78
 
72
79
  def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]:
73
80
  """
@@ -80,21 +87,31 @@ class HttpxOCIAuth(httpx.Auth):
80
87
  httpx.Request: The signed HTTPX request.
81
88
  """
82
89
  # Create a requests.Request object from the HTTPX request
83
- req = requests.Request(
84
- method=request.method,
85
- url=str(request.url),
86
- headers=dict(request.headers),
87
- data=request.content,
88
- )
89
- prepared_request = req.prepare()
90
+ try:
91
+ req = requests.Request(
92
+ method=request.method,
93
+ url=str(request.url),
94
+ headers=dict(request.headers),
95
+ data=request.content,
96
+ )
97
+ prepared_request = req.prepare()
98
+ self.signer.do_request_sign(prepared_request)
99
+
100
+ # Replace headers on the original HTTPX request with signed headers
101
+ request.headers.update(prepared_request.headers)
102
+ logger.debug("Successfully signed request to %s", request.url)
90
103
 
91
- # Sign the request using the OCI Signer
92
- self.signer.do_request_sign(prepared_request)
104
+ # Fix for GET/DELETE requests that OCI Gateway expects with Content-Length
105
+ if (
106
+ request.method in ["GET", "DELETE"]
107
+ and "content-length" not in request.headers
108
+ ):
109
+ request.headers["content-length"] = "0"
93
110
 
94
- # Update the original HTTPX request with the signed headers
95
- request.headers.update(prepared_request.headers)
111
+ except Exception as e:
112
+ logger.error("Failed to sign request to %s: %s", request.url, e)
113
+ raise
96
114
 
97
- # Proceed with the request
98
115
  yield request
99
116
 
100
117
 
@@ -330,8 +347,8 @@ class BaseClient:
330
347
  "Content-Type": "application/json",
331
348
  "Accept": "text/event-stream" if stream else "application/json",
332
349
  }
333
- if stream:
334
- default_headers["enable-streaming"] = "true"
350
+ # if stream:
351
+ # default_headers["enable-streaming"] = "true"
335
352
  if headers:
336
353
  default_headers.update(headers)
337
354
 
@@ -495,7 +512,7 @@ class Client(BaseClient):
495
512
  prompt: str,
496
513
  payload: Optional[Dict[str, Any]] = None,
497
514
  headers: Optional[Dict[str, str]] = None,
498
- stream: bool = True,
515
+ stream: bool = False,
499
516
  ) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
500
517
  """
501
518
  Generate text completion for the given prompt.
@@ -521,7 +538,7 @@ class Client(BaseClient):
521
538
  messages: List[Dict[str, Any]],
522
539
  payload: Optional[Dict[str, Any]] = None,
523
540
  headers: Optional[Dict[str, str]] = None,
524
- stream: bool = True,
541
+ stream: bool = False,
525
542
  ) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
526
543
  """
527
544
  Perform a chat interaction with the model.
@@ -32,7 +32,7 @@ class ModelDeploymentBaseEndpoint(ExtendedEnum):
32
32
  """Supported base endpoints for model deployments."""
33
33
 
34
34
  PREDICT = "predict"
35
- PREDICT_WITH_RESPONSE_STREAM = "predictwithresponsestream"
35
+ PREDICT_WITH_RESPONSE_STREAM = "predictWithResponseStream"
36
36
 
37
37
 
38
38
  class AquaOpenAIMixin:
@@ -51,9 +51,9 @@ class AquaOpenAIMixin:
51
51
  Returns:
52
52
  str: The normalized OpenAI-compatible route path (e.g., '/v1/chat/completions').
53
53
  """
54
- normalized_path = original_path.lower().rstrip("/")
54
+ normalized_path = original_path.rstrip("/")
55
55
 
56
- match = re.search(r"/predict(withresponsestream)?", normalized_path)
56
+ match = re.search(r"/predict(WithResponseStream)?", normalized_path)
57
57
  if not match:
58
58
  logger.debug("Route header cannot be resolved from path: %s", original_path)
59
59
  return ""
@@ -71,7 +71,7 @@ class AquaOpenAIMixin:
71
71
  "Route suffix does not start with a version prefix (e.g., '/v1'). "
72
72
  "This may lead to compatibility issues with OpenAI-style endpoints. "
73
73
  "Consider updating the URL to include a version prefix, "
74
- "such as '/predict/v1' or '/predictwithresponsestream/v1'."
74
+ "such as '/predict/v1' or '/predictWithResponseStream/v1'."
75
75
  )
76
76
  # route_suffix = f"v1/{route_suffix}"
77
77
 
@@ -124,13 +124,13 @@ class AquaOpenAIMixin:
124
124
 
125
125
  def _patch_url(self) -> httpx.URL:
126
126
  """
127
- Strips any suffixes from the base URL to retain only the `/predict` or `/predictwithresponsestream` path.
127
+ Strips any suffixes from the base URL to retain only the `/predict` or `/predictWithResponseStream` path.
128
128
 
129
129
  Returns:
130
130
  httpx.URL: The normalized base URL with the correct model deployment path.
131
131
  """
132
- base_path = f"{self.base_url.path.lower().rstrip('/')}/"
133
- match = re.search(r"/predict(withresponsestream)?/", base_path)
132
+ base_path = f"{self.base_url.path.rstrip('/')}/"
133
+ match = re.search(r"/predict(WithResponseStream)?/", base_path)
134
134
  if match:
135
135
  trimmed = base_path[: match.end() - 1]
136
136
  return self.base_url.copy_with(path=trimmed)
@@ -144,7 +144,7 @@ class AquaOpenAIMixin:
144
144
 
145
145
  This includes:
146
146
  - Patching headers with streaming and routing info.
147
- - Normalizing the URL path to include only `/predict` or `/predictwithresponsestream`.
147
+ - Normalizing the URL path to include only `/predict` or `/predictWithResponseStream`.
148
148
 
149
149
  Args:
150
150
  request (httpx.Request): The outgoing HTTPX request.
@@ -176,6 +176,7 @@ class OpenAI(openai.OpenAI, AquaOpenAIMixin):
176
176
  http_client: Optional[httpx.Client] = None,
177
177
  http_client_kwargs: Optional[Dict[str, Any]] = None,
178
178
  _strict_response_validation: bool = False,
179
+ patch_headers: bool = False,
179
180
  **kwargs: Any,
180
181
  ) -> None:
181
182
  """
@@ -196,6 +197,7 @@ class OpenAI(openai.OpenAI, AquaOpenAIMixin):
196
197
  http_client (httpx.Client, optional): Custom HTTP client; if not provided, one will be auto-created.
197
198
  http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
198
199
  _strict_response_validation (bool, optional): Enable strict response validation.
200
+ patch_headers (bool, optional): If True, redirects the requests by modifying the headers.
199
201
  **kwargs: Additional keyword arguments passed to the parent __init__.
200
202
  """
201
203
  if http_client is None:
@@ -207,6 +209,8 @@ class OpenAI(openai.OpenAI, AquaOpenAIMixin):
207
209
  logger.debug("API key not provided; using default placeholder for OCI.")
208
210
  api_key = "OCI"
209
211
 
212
+ self.patch_headers = patch_headers
213
+
210
214
  super().__init__(
211
215
  api_key=api_key,
212
216
  organization=organization,
@@ -229,7 +233,8 @@ class OpenAI(openai.OpenAI, AquaOpenAIMixin):
229
233
  Args:
230
234
  request (httpx.Request): The outgoing HTTP request.
231
235
  """
232
- self._prepare_request_common(request)
236
+ if self.patch_headers:
237
+ self._prepare_request_common(request)
233
238
 
234
239
 
235
240
  class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
@@ -248,6 +253,7 @@ class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
248
253
  http_client: Optional[httpx.Client] = None,
249
254
  http_client_kwargs: Optional[Dict[str, Any]] = None,
250
255
  _strict_response_validation: bool = False,
256
+ patch_headers: bool = False,
251
257
  **kwargs: Any,
252
258
  ) -> None:
253
259
  """
@@ -269,6 +275,7 @@ class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
269
275
  http_client (httpx.AsyncClient, optional): Custom asynchronous HTTP client; if not provided, one will be auto-created.
270
276
  http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
271
277
  _strict_response_validation (bool, optional): Enable strict response validation.
278
+ patch_headers (bool, optional): If True, redirects the requests by modifying the headers.
272
279
  **kwargs: Additional keyword arguments passed to the parent __init__.
273
280
  """
274
281
  if http_client is None:
@@ -280,6 +287,8 @@ class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
280
287
  logger.debug("API key not provided; using default placeholder for OCI.")
281
288
  api_key = "OCI"
282
289
 
290
+ self.patch_headers = patch_headers
291
+
283
292
  super().__init__(
284
293
  api_key=api_key,
285
294
  organization=organization,
@@ -302,4 +311,5 @@ class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
302
311
  Args:
303
312
  request (httpx.Request): The outgoing HTTP request.
304
313
  """
305
- self._prepare_request_common(request)
314
+ if self.patch_headers:
315
+ self._prepare_request_common(request)
@@ -3,10 +3,10 @@
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
 
5
5
  import re
6
- from typing import Any, Dict, Optional
6
+ from typing import Any, Dict, List, Optional
7
7
 
8
8
  from oci.data_science.models import Model
9
- from pydantic import BaseModel, Field, model_validator
9
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
10
10
 
11
11
  from ads.aqua import logger
12
12
  from ads.aqua.config.utils.serializer import Serializable
@@ -80,24 +80,29 @@ class GPUShapesIndex(Serializable):
80
80
 
81
81
  class ComputeShapeSummary(Serializable):
82
82
  """
83
- Represents the specifications of a compute instance's shape.
83
+ Represents the specifications of a compute instance shape,
84
+ including CPU, memory, and optional GPU characteristics.
84
85
  """
85
86
 
86
87
  core_count: Optional[int] = Field(
87
- default=None, description="The number of CPU cores available."
88
+ default=None,
89
+ description="Total number of CPU cores available for the compute shape.",
88
90
  )
89
91
  memory_in_gbs: Optional[int] = Field(
90
- default=None, description="The amount of memory (in GB) available."
92
+ default=None,
93
+ description="Amount of memory (in GB) available for the compute shape.",
91
94
  )
92
95
  name: Optional[str] = Field(
93
- default=None, description="The name identifier of the compute shape."
96
+ default=None,
97
+ description="Full name of the compute shape, e.g., 'VM.GPU.A10.2'.",
94
98
  )
95
99
  shape_series: Optional[str] = Field(
96
- default=None, description="The series or category of the compute shape."
100
+ default=None,
101
+ description="Shape family or series, e.g., 'GPU', 'Standard', etc.",
97
102
  )
98
103
  gpu_specs: Optional[GPUSpecs] = Field(
99
104
  default=None,
100
- description="The GPU specifications associated with the compute shape.",
105
+ description="Optional GPU specifications associated with the shape.",
101
106
  )
102
107
 
103
108
  @model_validator(mode="after")
@@ -136,6 +141,47 @@ class ComputeShapeSummary(Serializable):
136
141
  return model
137
142
 
138
143
 
144
+ class LoraModuleSpec(BaseModel):
145
+ """
146
+ Descriptor for a LoRA (Low-Rank Adaptation) module used in fine-tuning base models.
147
+
148
+ This class is used to define a single fine-tuned module that can be loaded during
149
+ multi-model deployment alongside a base model.
150
+
151
+ Attributes
152
+ ----------
153
+ model_id : str
154
+ The OCID of the fine-tuned model registered in the OCI Model Catalog.
155
+ model_name : Optional[str]
156
+ The unique name used to route inference requests to this model variant.
157
+ model_path : Optional[str]
158
+ The relative path within the artifact pointing to the LoRA adapter weights.
159
+ """
160
+
161
+ model_config = ConfigDict(protected_namespaces=(), extra="allow")
162
+
163
+ model_id: str = Field(
164
+ ...,
165
+ description="OCID of the fine-tuned model (must be registered in the Model Catalog).",
166
+ )
167
+ model_name: Optional[str] = Field(
168
+ default=None,
169
+ description="Name assigned to the fine-tuned model for serving (used as inference route).",
170
+ )
171
+ model_path: Optional[str] = Field(
172
+ default=None,
173
+ description="Relative path to the LoRA weights inside the model artifact.",
174
+ )
175
+
176
+ @model_validator(mode="before")
177
+ @classmethod
178
+ def validate_lora_module(cls, data: dict) -> dict:
179
+ """Validates that required structure exists for a LoRA module."""
180
+ if "model_id" not in data or not data["model_id"]:
181
+ raise ValueError("Missing required field: 'model_id' for fine-tuned model.")
182
+ return data
183
+
184
+
139
185
  class AquaMultiModelRef(Serializable):
140
186
  """
141
187
  Lightweight model descriptor used for multi-model deployment.
@@ -157,7 +203,7 @@ class AquaMultiModelRef(Serializable):
157
203
  Optional environment variables to override during deployment.
158
204
  artifact_location : Optional[str]
159
205
  Artifact path of model in the multimodel group.
160
- fine_tune_weights_location : Optional[str]
206
+ fine_tune_weights : Optional[List[LoraModuleSpec]]
161
207
  For fine tuned models, the artifact path of the modified model weights
162
208
  """
163
209
 
@@ -166,17 +212,37 @@ class AquaMultiModelRef(Serializable):
166
212
  gpu_count: Optional[int] = Field(
167
213
  None, description="The gpu count allocation for the model."
168
214
  )
169
- model_task: Optional[str] = Field(None, description="The task that model operates on. Supported tasks are in MultiModelSupportedTaskType")
215
+ model_task: Optional[str] = Field(
216
+ None,
217
+ description="The task that model operates on. Supported tasks are in MultiModelSupportedTaskType",
218
+ )
170
219
  env_var: Optional[dict] = Field(
171
220
  default_factory=dict, description="The environment variables of the model."
172
221
  )
173
222
  artifact_location: Optional[str] = Field(
174
223
  None, description="Artifact path of model in the multimodel group."
175
224
  )
176
- fine_tune_weights_location: Optional[str] = Field(
177
- None, description="For fine tuned models, the artifact path of the modified model weights"
225
+ fine_tune_weights: Optional[List[LoraModuleSpec]] = Field(
226
+ None,
227
+ description="For fine tuned models, the artifact path of the modified model weights",
178
228
  )
179
229
 
230
+ def all_model_ids(self) -> List[str]:
231
+ """
232
+ Returns all associated model OCIDs, including the base model and any fine-tuned models.
233
+
234
+ Returns
235
+ -------
236
+ List[str]
237
+ A list of all model OCIDs associated with this multi-model reference.
238
+ """
239
+ ids = {self.model_id}
240
+ if self.fine_tune_weights:
241
+ ids.update(
242
+ module.model_id for module in self.fine_tune_weights if module.model_id
243
+ )
244
+ return list(ids)
245
+
180
246
  class Config:
181
247
  extra = "ignore"
182
248
  protected_namespaces = ()
ads/aqua/common/utils.py CHANGED
@@ -870,6 +870,41 @@ def get_combined_params(params1: str = None, params2: str = None) -> str:
870
870
  return " ".join(combined_params)
871
871
 
872
872
 
873
+ def find_restricted_params(
874
+ default_params: Union[str, List[str]],
875
+ user_params: Union[str, List[str]],
876
+ container_family: str,
877
+ ) -> List[str]:
878
+ """Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
879
+ The default parameters coming from the container index json file cannot be overridden.
880
+
881
+ Parameters
882
+ ----------
883
+ default_params:
884
+ Inference container parameter string with default values.
885
+ user_params:
886
+ Inference container parameter string with user provided values.
887
+ container_family: str
888
+ The image family of model deployment container runtime.
889
+
890
+ Returns
891
+ -------
892
+ A list with params keys common between params1 and params2.
893
+
894
+ """
895
+ restricted_params = []
896
+ if default_params and user_params:
897
+ default_params_dict = get_params_dict(default_params)
898
+ user_params_dict = get_params_dict(user_params)
899
+
900
+ restricted_params_set = get_restricted_params_by_container(container_family)
901
+ for key, _items in user_params_dict.items():
902
+ if key in default_params_dict or key in restricted_params_set:
903
+ restricted_params.append(key.lstrip("-"))
904
+
905
+ return restricted_params
906
+
907
+
873
908
  def build_params_string(params: dict) -> str:
874
909
  """Builds params string from params dict
875
910
 
ads/aqua/constants.py CHANGED
@@ -55,6 +55,8 @@ SERVICE_MANAGED_CONTAINER_URI_SCHEME = "dsmc://"
55
55
  SUPPORTED_FILE_FORMATS = ["jsonl"]
56
56
  MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location"
57
57
 
58
+ AQUA_CHAT_TEMPLATE_METADATA_KEY = "chat_template"
59
+
58
60
  CONSOLE_LINK_RESOURCE_TYPE_MAPPING = {
59
61
  "datasciencemodel": "models",
60
62
  "datasciencemodeldeployment": "model-deployments",
@@ -727,10 +727,11 @@ class AquaEvaluationApp(AquaApp):
727
727
  raise AquaRuntimeError(error_message) from ex
728
728
 
729
729
  # Build the list of valid model names from custom metadata.
730
- model_names = [
731
- AquaMultiModelRef(**metadata).model_name
732
- for metadata in multi_model_metadata
733
- ]
730
+ model_names = []
731
+ for metadata in multi_model_metadata:
732
+ model = AquaMultiModelRef(**metadata)
733
+ model_names.append(model.model_name)
734
+ model_names.extend(ft.model_name for ft in (model.fine_tune_weights or []) if ft.model_name)
734
735
 
735
736
  # Check if the provided model name is among the valid names.
736
737
  if user_model_name not in model_names: