oracle-ads 2.13.11__py3-none-any.whl → 2.13.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/app.py +73 -15
- ads/aqua/cli.py +17 -0
- ads/aqua/client/client.py +38 -21
- ads/aqua/client/openai_client.py +20 -10
- ads/aqua/common/entities.py +78 -12
- ads/aqua/common/utils.py +35 -0
- ads/aqua/constants.py +2 -0
- ads/aqua/evaluation/evaluation.py +5 -4
- ads/aqua/extension/common_handler.py +47 -2
- ads/aqua/extension/model_handler.py +51 -9
- ads/aqua/model/constants.py +1 -0
- ads/aqua/model/enums.py +19 -1
- ads/aqua/model/model.py +119 -51
- ads/aqua/model/utils.py +1 -2
- ads/aqua/modeldeployment/config_loader.py +815 -0
- ads/aqua/modeldeployment/constants.py +4 -1
- ads/aqua/modeldeployment/deployment.py +178 -129
- ads/aqua/modeldeployment/entities.py +150 -178
- ads/aqua/modeldeployment/model_group_config.py +233 -0
- ads/aqua/modeldeployment/utils.py +0 -539
- ads/aqua/verify_policies/__init__.py +8 -0
- ads/aqua/verify_policies/constants.py +13 -0
- ads/aqua/verify_policies/entities.py +29 -0
- ads/aqua/verify_policies/messages.py +101 -0
- ads/aqua/verify_policies/utils.py +432 -0
- ads/aqua/verify_policies/verify.py +345 -0
- ads/aqua/version.json +3 -0
- ads/common/oci_logging.py +4 -7
- ads/common/work_request.py +39 -38
- ads/jobs/builders/infrastructure/dsc_job.py +121 -24
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +71 -24
- ads/jobs/builders/runtimes/base.py +7 -5
- ads/jobs/builders/runtimes/pytorch_runtime.py +6 -8
- ads/jobs/templates/driver_pytorch.py +486 -172
- ads/jobs/templates/driver_utils.py +27 -11
- ads/model/deployment/model_deployment.py +51 -38
- ads/model/service/oci_datascience_model_deployment.py +6 -11
- ads/telemetry/client.py +4 -4
- {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/METADATA +2 -1
- {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/RECORD +43 -34
- {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/WHEEL +0 -0
- {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/entry_points.txt +0 -0
- {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/licenses/LICENSE.txt +0 -0
ads/aqua/app.py
CHANGED
@@ -5,6 +5,7 @@
|
|
5
5
|
import json
|
6
6
|
import os
|
7
7
|
import traceback
|
8
|
+
from concurrent.futures import ThreadPoolExecutor
|
8
9
|
from dataclasses import fields
|
9
10
|
from datetime import datetime, timedelta
|
10
11
|
from itertools import chain
|
@@ -22,7 +23,7 @@ from ads import set_auth
|
|
22
23
|
from ads.aqua import logger
|
23
24
|
from ads.aqua.common.entities import ModelConfigResult
|
24
25
|
from ads.aqua.common.enums import ConfigFolder, Tags
|
25
|
-
from ads.aqua.common.errors import
|
26
|
+
from ads.aqua.common.errors import AquaValueError
|
26
27
|
from ads.aqua.common.utils import (
|
27
28
|
_is_valid_mvs,
|
28
29
|
get_artifact_path,
|
@@ -58,6 +59,8 @@ from ads.telemetry.client import TelemetryClient
|
|
58
59
|
class AquaApp:
|
59
60
|
"""Base Aqua App to contain common components."""
|
60
61
|
|
62
|
+
MAX_WORKERS = 10 # Number of workers for asynchronous resource loading
|
63
|
+
|
61
64
|
@telemetry(name="aqua")
|
62
65
|
def __init__(self) -> None:
|
63
66
|
if OCI_RESOURCE_PRINCIPAL_VERSION:
|
@@ -128,20 +131,69 @@ class AquaApp:
|
|
128
131
|
update_model_provenance_details=update_model_provenance_details,
|
129
132
|
)
|
130
133
|
|
131
|
-
# TODO: refactor model evaluation implementation to use it.
|
132
134
|
@staticmethod
def get_source(source_id: str) -> Union[ModelDeployment, DataScienceModel]:
    """
    Resolve an OCID to the matching Data Science resource object.

    Parameters
    ----------
    source_id : str
        OCID of the Data Science model or model deployment.

    Returns
    -------
    Union[ModelDeployment, DataScienceModel]
        The resource loaded through ``from_id`` of the matching class.

    Raises
    ------
    AquaValueError
        If the OCID fails validation, or is neither a model nor a
        model deployment OCID.
    """
    logger.debug(f"Resolving source for ID: {source_id}")

    ocid_ok = is_valid_ocid(source_id)
    if not ocid_ok:
        logger.error(f"Invalid OCID format: {source_id}")
        raise AquaValueError(
            f"Invalid source ID: {source_id}. Please provide a valid model or model deployment OCID."
        )

    # Order matters: "datasciencemodel" is a prefix of
    # "datasciencemodeldeployment", so the deployment marker is tested first.
    resolvers = (
        ("datasciencemodeldeployment", "ModelDeployment", ModelDeployment),
        ("datasciencemodel", "DataScienceModel", DataScienceModel),
    )
    for marker, label, resource_cls in resolvers:
        if marker in source_id:
            logger.debug(f"Identified as {label} OCID: {source_id}")
            return resource_cls.from_id(source_id)

    logger.error(f"Unrecognized OCID type: {source_id}")
    raise AquaValueError(
        f"Unsupported source ID type: {source_id}. Must be a model or model deployment OCID."
    )
|
144
173
|
|
174
|
+
def get_multi_source(
    self,
    ids: List[str],
) -> Dict[str, Union[ModelDeployment, DataScienceModel]]:
    """
    Retrieves multiple DataScience resources concurrently.

    Each distinct OCID is resolved once through ``self.get_source`` using a
    thread pool of at most ``self.MAX_WORKERS`` workers.

    Parameters
    ----------
    ids : List[str]
        A list of DataScience OCIDs. Duplicates are resolved only once.

    Returns
    -------
    Dict[str, Union[ModelDeployment, DataScienceModel]]
        A mapping from OCID to the corresponding resolved resource object,
        in order of first occurrence. Empty when ``ids`` is empty.

    Raises
    ------
    Any exception raised by ``self.get_source`` for one of the OCIDs
    propagates to the caller.
    """
    # Materialize once and drop duplicates (first occurrence wins, order kept).
    # The result is keyed by OCID, so fetching the same OCID twice is wasted
    # work, and a one-shot iterable would be exhausted before the final zip.
    unique_ids = list(dict.fromkeys(ids))
    if not unique_ids:
        # Nothing to fetch: do not spin up a thread pool.
        return {}

    logger.debug(f"Fetching {unique_ids} sources in parallel.")
    with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
        # executor.map yields results in input order, so zip pairs them correctly.
        results = list(executor.map(self.get_source, unique_ids))

    return dict(zip(unique_ids, results))
|
196
|
+
|
145
197
|
# TODO: refactor model evaluation implementation to use it.
|
146
198
|
@staticmethod
|
147
199
|
def create_model_version_set(
|
@@ -284,8 +336,11 @@ class AquaApp:
|
|
284
336
|
logger.info(f"Artifact not found in model {model_id}.")
|
285
337
|
return False
|
286
338
|
|
339
|
+
@cached(cache=TTLCache(maxsize=5, ttl=timedelta(minutes=1), timer=datetime.now))
|
287
340
|
def get_config_from_metadata(
|
288
|
-
self,
|
341
|
+
self,
|
342
|
+
model_id: str,
|
343
|
+
metadata_key: str,
|
289
344
|
) -> ModelConfigResult:
|
290
345
|
"""Gets the config for the given Aqua model from model catalog metadata content.
|
291
346
|
|
@@ -300,8 +355,9 @@ class AquaApp:
|
|
300
355
|
ModelConfigResult
|
301
356
|
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
|
302
357
|
"""
|
303
|
-
config = {}
|
358
|
+
config: Dict[str, Any] = {}
|
304
359
|
oci_model = self.ds_client.get_model(model_id).data
|
360
|
+
|
305
361
|
try:
|
306
362
|
config = self.ds_client.get_model_defined_metadatum_artifact_content(
|
307
363
|
model_id, metadata_key
|
@@ -321,7 +377,7 @@ class AquaApp:
|
|
321
377
|
)
|
322
378
|
return ModelConfigResult(config=config, model_details=oci_model)
|
323
379
|
|
324
|
-
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=
|
380
|
+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=5), timer=datetime.now))
|
325
381
|
def get_config(
|
326
382
|
self,
|
327
383
|
model_id: str,
|
@@ -346,8 +402,10 @@ class AquaApp:
|
|
346
402
|
ModelConfigResult
|
347
403
|
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
|
348
404
|
"""
|
349
|
-
|
405
|
+
config: Dict[str, Any] = {}
|
350
406
|
oci_model = self.ds_client.get_model(model_id).data
|
407
|
+
|
408
|
+
config_folder = config_folder or ConfigFolder.CONFIG
|
351
409
|
oci_aqua = (
|
352
410
|
(
|
353
411
|
Tags.AQUA_TAG in oci_model.freeform_tags
|
@@ -357,9 +415,9 @@ class AquaApp:
|
|
357
415
|
else False
|
358
416
|
)
|
359
417
|
if not oci_aqua:
|
360
|
-
|
418
|
+
logger.debug(f"Target model {oci_model.id} is not an Aqua model.")
|
419
|
+
return ModelConfigResult(config=config, model_details=oci_model)
|
361
420
|
|
362
|
-
config: Dict[str, Any] = {}
|
363
421
|
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
|
364
422
|
if not artifact_path:
|
365
423
|
logger.debug(
|
ads/aqua/cli.py
CHANGED
@@ -14,6 +14,7 @@ from ads.aqua.evaluation import AquaEvaluationApp
|
|
14
14
|
from ads.aqua.finetuning import AquaFineTuningApp
|
15
15
|
from ads.aqua.model import AquaModelApp
|
16
16
|
from ads.aqua.modeldeployment import AquaDeploymentApp
|
17
|
+
from ads.aqua.verify_policies import AquaVerifyPoliciesApp
|
17
18
|
from ads.common.utils import LOG_LEVELS
|
18
19
|
|
19
20
|
|
@@ -29,6 +30,7 @@ class AquaCommand:
|
|
29
30
|
fine_tuning = AquaFineTuningApp
|
30
31
|
deployment = AquaDeploymentApp
|
31
32
|
evaluation = AquaEvaluationApp
|
33
|
+
verify_policies = AquaVerifyPoliciesApp
|
32
34
|
|
33
35
|
def __init__(
|
34
36
|
self,
|
@@ -94,3 +96,18 @@ class AquaCommand:
|
|
94
96
|
"If you intend to chain a function call to the result, please separate the "
|
95
97
|
"flag and the subsequent function call with separator `-`."
|
96
98
|
)
|
99
|
+
|
100
|
+
@staticmethod
def install():
    """Install the ADS Aqua extension from a wheel file.

    The wheel location is read from the environment variable
    `AQUA_EXTENSTION_PATH` (spelled exactly as the code reads it). When the
    variable is not set, `/ads/extension/adsjupyterlab_aqua_extension*.whl`
    is used.

    Return
    ------
    int:
        Installation status: the exit code of the `pip install` command.
    """
    import subprocess

    wheel_file_path = os.environ.get(
        "AQUA_EXTENSTION_PATH", "/ads/extension/adsjupyterlab_aqua_extension*.whl"
    )
    # shell=True is kept so the shell expands the `*` in the default wheel path.
    # NOTE(review): the value is interpolated into the shell command unquoted;
    # presumably the variable is operator-controlled - confirm.
    status = subprocess.run(f"pip install {wheel_file_path}", shell=True)
    # `check_returncode` is a method (it raises on non-zero exit); returning it
    # uncalled handed back a bound method instead of the documented status.
    return status.returncode
|
ads/aqua/client/client.py
CHANGED
@@ -61,13 +61,20 @@ class HttpxOCIAuth(httpx.Auth):
|
|
61
61
|
|
62
62
|
def __init__(self, signer: Optional[oci.signer.Signer] = None):
    """
    Set up the authentication handler with an OCI signer.

    Parameters
    ----------
    signer : oci.signer.Signer, optional
        The OCI signer instance to use. When not given, the ``"signer"``
        entry of ``authutil.default_signer()`` is used instead.

    Raises
    ------
    ValueError
        If no signer is available after the fallback lookup.
    Exception
        Any error raised while resolving the signer is logged and re-raised.
    """
    try:
        if signer:
            self.signer = signer
        else:
            self.signer = authutil.default_signer().get("signer")

        if not self.signer:
            raise ValueError("OCI signer could not be initialized.")
    except Exception as exc:
        # Log for diagnostics, then let the original exception propagate.
        logger.error("Failed to initialize OCI signer: %s", exc)
        raise
|
71
78
|
|
72
79
|
def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]:
|
73
80
|
"""
|
@@ -80,21 +87,31 @@ class HttpxOCIAuth(httpx.Auth):
|
|
80
87
|
httpx.Request: The signed HTTPX request.
|
81
88
|
"""
|
82
89
|
# Create a requests.Request object from the HTTPX request
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
+
try:
|
91
|
+
req = requests.Request(
|
92
|
+
method=request.method,
|
93
|
+
url=str(request.url),
|
94
|
+
headers=dict(request.headers),
|
95
|
+
data=request.content,
|
96
|
+
)
|
97
|
+
prepared_request = req.prepare()
|
98
|
+
self.signer.do_request_sign(prepared_request)
|
99
|
+
|
100
|
+
# Replace headers on the original HTTPX request with signed headers
|
101
|
+
request.headers.update(prepared_request.headers)
|
102
|
+
logger.debug("Successfully signed request to %s", request.url)
|
90
103
|
|
91
|
-
|
92
|
-
|
104
|
+
# Fix for GET/DELETE requests that OCI Gateway expects with Content-Length
|
105
|
+
if (
|
106
|
+
request.method in ["GET", "DELETE"]
|
107
|
+
and "content-length" not in request.headers
|
108
|
+
):
|
109
|
+
request.headers["content-length"] = "0"
|
93
110
|
|
94
|
-
|
95
|
-
|
111
|
+
except Exception as e:
|
112
|
+
logger.error("Failed to sign request to %s: %s", request.url, e)
|
113
|
+
raise
|
96
114
|
|
97
|
-
# Proceed with the request
|
98
115
|
yield request
|
99
116
|
|
100
117
|
|
@@ -330,8 +347,8 @@ class BaseClient:
|
|
330
347
|
"Content-Type": "application/json",
|
331
348
|
"Accept": "text/event-stream" if stream else "application/json",
|
332
349
|
}
|
333
|
-
if stream:
|
334
|
-
|
350
|
+
# if stream:
|
351
|
+
# default_headers["enable-streaming"] = "true"
|
335
352
|
if headers:
|
336
353
|
default_headers.update(headers)
|
337
354
|
|
@@ -495,7 +512,7 @@ class Client(BaseClient):
|
|
495
512
|
prompt: str,
|
496
513
|
payload: Optional[Dict[str, Any]] = None,
|
497
514
|
headers: Optional[Dict[str, str]] = None,
|
498
|
-
stream: bool =
|
515
|
+
stream: bool = False,
|
499
516
|
) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
|
500
517
|
"""
|
501
518
|
Generate text completion for the given prompt.
|
@@ -521,7 +538,7 @@ class Client(BaseClient):
|
|
521
538
|
messages: List[Dict[str, Any]],
|
522
539
|
payload: Optional[Dict[str, Any]] = None,
|
523
540
|
headers: Optional[Dict[str, str]] = None,
|
524
|
-
stream: bool =
|
541
|
+
stream: bool = False,
|
525
542
|
) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
|
526
543
|
"""
|
527
544
|
Perform a chat interaction with the model.
|
ads/aqua/client/openai_client.py
CHANGED
@@ -32,7 +32,7 @@ class ModelDeploymentBaseEndpoint(ExtendedEnum):
|
|
32
32
|
"""Supported base endpoints for model deployments."""
|
33
33
|
|
34
34
|
PREDICT = "predict"
|
35
|
-
PREDICT_WITH_RESPONSE_STREAM = "
|
35
|
+
PREDICT_WITH_RESPONSE_STREAM = "predictWithResponseStream"
|
36
36
|
|
37
37
|
|
38
38
|
class AquaOpenAIMixin:
|
@@ -51,9 +51,9 @@ class AquaOpenAIMixin:
|
|
51
51
|
Returns:
|
52
52
|
str: The normalized OpenAI-compatible route path (e.g., '/v1/chat/completions').
|
53
53
|
"""
|
54
|
-
normalized_path = original_path.
|
54
|
+
normalized_path = original_path.rstrip("/")
|
55
55
|
|
56
|
-
match = re.search(r"/predict(
|
56
|
+
match = re.search(r"/predict(WithResponseStream)?", normalized_path)
|
57
57
|
if not match:
|
58
58
|
logger.debug("Route header cannot be resolved from path: %s", original_path)
|
59
59
|
return ""
|
@@ -71,7 +71,7 @@ class AquaOpenAIMixin:
|
|
71
71
|
"Route suffix does not start with a version prefix (e.g., '/v1'). "
|
72
72
|
"This may lead to compatibility issues with OpenAI-style endpoints. "
|
73
73
|
"Consider updating the URL to include a version prefix, "
|
74
|
-
"such as '/predict/v1' or '/
|
74
|
+
"such as '/predict/v1' or '/predictWithResponseStream/v1'."
|
75
75
|
)
|
76
76
|
# route_suffix = f"v1/{route_suffix}"
|
77
77
|
|
@@ -124,13 +124,13 @@ class AquaOpenAIMixin:
|
|
124
124
|
|
125
125
|
def _patch_url(self) -> httpx.URL:
|
126
126
|
"""
|
127
|
-
Strips any suffixes from the base URL to retain only the `/predict` or `/
|
127
|
+
Strips any suffixes from the base URL to retain only the `/predict` or `/predictWithResponseStream` path.
|
128
128
|
|
129
129
|
Returns:
|
130
130
|
httpx.URL: The normalized base URL with the correct model deployment path.
|
131
131
|
"""
|
132
|
-
base_path = f"{self.base_url.path.
|
133
|
-
match = re.search(r"/predict(
|
132
|
+
base_path = f"{self.base_url.path.rstrip('/')}/"
|
133
|
+
match = re.search(r"/predict(WithResponseStream)?/", base_path)
|
134
134
|
if match:
|
135
135
|
trimmed = base_path[: match.end() - 1]
|
136
136
|
return self.base_url.copy_with(path=trimmed)
|
@@ -144,7 +144,7 @@ class AquaOpenAIMixin:
|
|
144
144
|
|
145
145
|
This includes:
|
146
146
|
- Patching headers with streaming and routing info.
|
147
|
-
- Normalizing the URL path to include only `/predict` or `/
|
147
|
+
- Normalizing the URL path to include only `/predict` or `/predictWithResponseStream`.
|
148
148
|
|
149
149
|
Args:
|
150
150
|
request (httpx.Request): The outgoing HTTPX request.
|
@@ -176,6 +176,7 @@ class OpenAI(openai.OpenAI, AquaOpenAIMixin):
|
|
176
176
|
http_client: Optional[httpx.Client] = None,
|
177
177
|
http_client_kwargs: Optional[Dict[str, Any]] = None,
|
178
178
|
_strict_response_validation: bool = False,
|
179
|
+
patch_headers: bool = False,
|
179
180
|
**kwargs: Any,
|
180
181
|
) -> None:
|
181
182
|
"""
|
@@ -196,6 +197,7 @@ class OpenAI(openai.OpenAI, AquaOpenAIMixin):
|
|
196
197
|
http_client (httpx.Client, optional): Custom HTTP client; if not provided, one will be auto-created.
|
197
198
|
http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
|
198
199
|
_strict_response_validation (bool, optional): Enable strict response validation.
|
200
|
+
patch_headers (bool, optional): If True, redirects the requests by modifying the headers.
|
199
201
|
**kwargs: Additional keyword arguments passed to the parent __init__.
|
200
202
|
"""
|
201
203
|
if http_client is None:
|
@@ -207,6 +209,8 @@ class OpenAI(openai.OpenAI, AquaOpenAIMixin):
|
|
207
209
|
logger.debug("API key not provided; using default placeholder for OCI.")
|
208
210
|
api_key = "OCI"
|
209
211
|
|
212
|
+
self.patch_headers = patch_headers
|
213
|
+
|
210
214
|
super().__init__(
|
211
215
|
api_key=api_key,
|
212
216
|
organization=organization,
|
@@ -229,7 +233,8 @@ class OpenAI(openai.OpenAI, AquaOpenAIMixin):
|
|
229
233
|
Args:
|
230
234
|
request (httpx.Request): The outgoing HTTP request.
|
231
235
|
"""
|
232
|
-
self.
|
236
|
+
if self.patch_headers:
|
237
|
+
self._prepare_request_common(request)
|
233
238
|
|
234
239
|
|
235
240
|
class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
|
@@ -248,6 +253,7 @@ class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
|
|
248
253
|
http_client: Optional[httpx.Client] = None,
|
249
254
|
http_client_kwargs: Optional[Dict[str, Any]] = None,
|
250
255
|
_strict_response_validation: bool = False,
|
256
|
+
patch_headers: bool = False,
|
251
257
|
**kwargs: Any,
|
252
258
|
) -> None:
|
253
259
|
"""
|
@@ -269,6 +275,7 @@ class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
|
|
269
275
|
http_client (httpx.AsyncClient, optional): Custom asynchronous HTTP client; if not provided, one will be auto-created.
|
270
276
|
http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
|
271
277
|
_strict_response_validation (bool, optional): Enable strict response validation.
|
278
|
+
patch_headers (bool, optional): If True, redirects the requests by modifying the headers.
|
272
279
|
**kwargs: Additional keyword arguments passed to the parent __init__.
|
273
280
|
"""
|
274
281
|
if http_client is None:
|
@@ -280,6 +287,8 @@ class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
|
|
280
287
|
logger.debug("API key not provided; using default placeholder for OCI.")
|
281
288
|
api_key = "OCI"
|
282
289
|
|
290
|
+
self.patch_headers = patch_headers
|
291
|
+
|
283
292
|
super().__init__(
|
284
293
|
api_key=api_key,
|
285
294
|
organization=organization,
|
@@ -302,4 +311,5 @@ class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
|
|
302
311
|
Args:
|
303
312
|
request (httpx.Request): The outgoing HTTP request.
|
304
313
|
"""
|
305
|
-
self.
|
314
|
+
if self.patch_headers:
|
315
|
+
self._prepare_request_common(request)
|
ads/aqua/common/entities.py
CHANGED
@@ -3,10 +3,10 @@
|
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
|
5
5
|
import re
|
6
|
-
from typing import Any, Dict, Optional
|
6
|
+
from typing import Any, Dict, List, Optional
|
7
7
|
|
8
8
|
from oci.data_science.models import Model
|
9
|
-
from pydantic import BaseModel, Field, model_validator
|
9
|
+
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
10
10
|
|
11
11
|
from ads.aqua import logger
|
12
12
|
from ads.aqua.config.utils.serializer import Serializable
|
@@ -80,24 +80,29 @@ class GPUShapesIndex(Serializable):
|
|
80
80
|
|
81
81
|
class ComputeShapeSummary(Serializable):
|
82
82
|
"""
|
83
|
-
Represents the specifications of a compute instance
|
83
|
+
Represents the specifications of a compute instance shape,
|
84
|
+
including CPU, memory, and optional GPU characteristics.
|
84
85
|
"""
|
85
86
|
|
86
87
|
core_count: Optional[int] = Field(
|
87
|
-
default=None,
|
88
|
+
default=None,
|
89
|
+
description="Total number of CPU cores available for the compute shape.",
|
88
90
|
)
|
89
91
|
memory_in_gbs: Optional[int] = Field(
|
90
|
-
default=None,
|
92
|
+
default=None,
|
93
|
+
description="Amount of memory (in GB) available for the compute shape.",
|
91
94
|
)
|
92
95
|
name: Optional[str] = Field(
|
93
|
-
default=None,
|
96
|
+
default=None,
|
97
|
+
description="Full name of the compute shape, e.g., 'VM.GPU.A10.2'.",
|
94
98
|
)
|
95
99
|
shape_series: Optional[str] = Field(
|
96
|
-
default=None,
|
100
|
+
default=None,
|
101
|
+
description="Shape family or series, e.g., 'GPU', 'Standard', etc.",
|
97
102
|
)
|
98
103
|
gpu_specs: Optional[GPUSpecs] = Field(
|
99
104
|
default=None,
|
100
|
-
description="
|
105
|
+
description="Optional GPU specifications associated with the shape.",
|
101
106
|
)
|
102
107
|
|
103
108
|
@model_validator(mode="after")
|
@@ -136,6 +141,47 @@ class ComputeShapeSummary(Serializable):
|
|
136
141
|
return model
|
137
142
|
|
138
143
|
|
144
|
+
class LoraModuleSpec(BaseModel):
    """
    Descriptor for a LoRA (Low-Rank Adaptation) module used in fine-tuning base models.

    This class is used to define a single fine-tuned module that can be loaded during
    multi-model deployment alongside a base model.

    Attributes
    ----------
    model_id : str
        The OCID of the fine-tuned model registered in the OCI Model Catalog.
    model_name : Optional[str]
        The unique name used to route inference requests to this model variant.
    model_path : Optional[str]
        The relative path within the artifact pointing to the LoRA adapter weights.
    """

    # protected_namespaces=() allows the `model_`-prefixed field names below;
    # extra="allow" keeps unknown keys instead of rejecting them.
    model_config = ConfigDict(protected_namespaces=(), extra="allow")

    model_id: str = Field(
        ...,
        description="OCID of the fine-tuned model (must be registered in the Model Catalog).",
    )
    model_name: Optional[str] = Field(
        default=None,
        description="Name assigned to the fine-tuned model for serving (used as inference route).",
    )
    model_path: Optional[str] = Field(
        default=None,
        description="Relative path to the LoRA weights inside the model artifact.",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_lora_module(cls, data: Any) -> Any:
        """Validates that a non-empty `model_id` is present for a LoRA module.

        Raises
        ------
        ValueError
            If `data` is a dict whose `model_id` is missing or empty.
        """
        # A "before" validator receives the raw input, which is not guaranteed
        # to be a dict. Only inspect dicts and leave any other input type for
        # pydantic to report.
        if isinstance(data, dict) and not data.get("model_id"):
            raise ValueError("Missing required field: 'model_id' for fine-tuned model.")
        return data
|
183
|
+
|
184
|
+
|
139
185
|
class AquaMultiModelRef(Serializable):
|
140
186
|
"""
|
141
187
|
Lightweight model descriptor used for multi-model deployment.
|
@@ -157,7 +203,7 @@ class AquaMultiModelRef(Serializable):
|
|
157
203
|
Optional environment variables to override during deployment.
|
158
204
|
artifact_location : Optional[str]
|
159
205
|
Artifact path of model in the multimodel group.
|
160
|
-
|
206
|
+
fine_tune_weights : Optional[List[LoraModuleSpec]]
|
161
207
|
For fine tuned models, the artifact path of the modified model weights
|
162
208
|
"""
|
163
209
|
|
@@ -166,17 +212,37 @@ class AquaMultiModelRef(Serializable):
|
|
166
212
|
gpu_count: Optional[int] = Field(
|
167
213
|
None, description="The gpu count allocation for the model."
|
168
214
|
)
|
169
|
-
model_task: Optional[str] = Field(
|
215
|
+
model_task: Optional[str] = Field(
|
216
|
+
None,
|
217
|
+
description="The task that model operates on. Supported tasks are in MultiModelSupportedTaskType",
|
218
|
+
)
|
170
219
|
env_var: Optional[dict] = Field(
|
171
220
|
default_factory=dict, description="The environment variables of the model."
|
172
221
|
)
|
173
222
|
artifact_location: Optional[str] = Field(
|
174
223
|
None, description="Artifact path of model in the multimodel group."
|
175
224
|
)
|
176
|
-
|
177
|
-
None,
|
225
|
+
fine_tune_weights: Optional[List[LoraModuleSpec]] = Field(
|
226
|
+
None,
|
227
|
+
description="For fine tuned models, the artifact path of the modified model weights",
|
178
228
|
)
|
179
229
|
|
230
|
+
def all_model_ids(self) -> List[str]:
    """
    Returns all associated model OCIDs, including the base model and any fine-tuned models.

    Returns
    -------
    List[str]
        The base model OCID first, followed by the OCIDs of the fine-tuned
        modules in their listed order. Duplicates and empty fine-tuned OCIDs
        are dropped.
    """
    # A dict is used as an ordered set: it de-duplicates while keeping
    # insertion order, so the result is stable across runs (a plain set
    # produced an arbitrary order).
    ordered = {self.model_id: None}
    for module in self.fine_tune_weights or ():
        if module.model_id:
            ordered.setdefault(module.model_id, None)
    return list(ordered)
|
245
|
+
|
180
246
|
class Config:
|
181
247
|
extra = "ignore"
|
182
248
|
protected_namespaces = ()
|
ads/aqua/common/utils.py
CHANGED
@@ -870,6 +870,41 @@ def get_combined_params(params1: str = None, params2: str = None) -> str:
|
|
870
870
|
return " ".join(combined_params)
|
871
871
|
|
872
872
|
|
873
|
+
def find_restricted_params(
    default_params: Union[str, List[str]],
    user_params: Union[str, List[str]],
    container_family: str,
) -> List[str]:
    """Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
    The default parameters coming from the container index json file cannot be overridden.

    Parameters
    ----------
    default_params:
        Inference container parameter string with default values.
    user_params:
        Inference container parameter string with user provided values.
    container_family: str
        The image family of model deployment container runtime.

    Returns
    -------
    List[str]
        The keys of `user_params`, with leading dashes stripped, that also
        appear in `default_params` or in the restricted set for
        `container_family`. Empty when either params argument is empty.
    """
    # Without both parameter sets there is nothing to compare.
    if not (default_params and user_params):
        return []

    default_params_dict = get_params_dict(default_params)
    user_params_dict = get_params_dict(user_params)
    restricted_params_set = get_restricted_params_by_container(container_family)

    return [
        key.lstrip("-")
        for key in user_params_dict
        if key in default_params_dict or key in restricted_params_set
    ]
|
906
|
+
|
907
|
+
|
873
908
|
def build_params_string(params: dict) -> str:
|
874
909
|
"""Builds params string from params dict
|
875
910
|
|
ads/aqua/constants.py
CHANGED
@@ -55,6 +55,8 @@ SERVICE_MANAGED_CONTAINER_URI_SCHEME = "dsmc://"
|
|
55
55
|
SUPPORTED_FILE_FORMATS = ["jsonl"]
|
56
56
|
MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location"
|
57
57
|
|
58
|
+
AQUA_CHAT_TEMPLATE_METADATA_KEY = "chat_template"
|
59
|
+
|
58
60
|
CONSOLE_LINK_RESOURCE_TYPE_MAPPING = {
|
59
61
|
"datasciencemodel": "models",
|
60
62
|
"datasciencemodeldeployment": "model-deployments",
|
@@ -727,10 +727,11 @@ class AquaEvaluationApp(AquaApp):
|
|
727
727
|
raise AquaRuntimeError(error_message) from ex
|
728
728
|
|
729
729
|
# Build the list of valid model names from custom metadata.
|
730
|
-
model_names = [
|
731
|
-
|
732
|
-
|
733
|
-
|
730
|
+
model_names = []
|
731
|
+
for metadata in multi_model_metadata:
|
732
|
+
model = AquaMultiModelRef(**metadata)
|
733
|
+
model_names.append(model.model_name)
|
734
|
+
model_names.extend(ft.model_name for ft in (model.fine_tune_weights or []) if ft.model_name)
|
734
735
|
|
735
736
|
# Check if the provided model name is among the valid names.
|
736
737
|
if user_model_name not in model_names:
|