oracle-ads 2.13.4__py3-none-any.whl → 2.13.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/app.py +6 -0
- ads/aqua/client/openai_client.py +305 -0
- ads/aqua/common/entities.py +224 -2
- ads/aqua/common/enums.py +17 -0
- ads/aqua/common/utils.py +143 -3
- ads/aqua/config/container_config.py +9 -0
- ads/aqua/constants.py +29 -1
- ads/aqua/evaluation/entities.py +6 -1
- ads/aqua/evaluation/evaluation.py +191 -7
- ads/aqua/extension/aqua_ws_msg_handler.py +6 -36
- ads/aqua/extension/base_handler.py +13 -71
- ads/aqua/extension/deployment_handler.py +67 -76
- ads/aqua/extension/errors.py +19 -0
- ads/aqua/extension/utils.py +114 -2
- ads/aqua/finetuning/finetuning.py +50 -1
- ads/aqua/model/constants.py +3 -0
- ads/aqua/model/enums.py +5 -0
- ads/aqua/model/model.py +247 -24
- ads/aqua/modeldeployment/deployment.py +671 -152
- ads/aqua/modeldeployment/entities.py +551 -42
- ads/aqua/modeldeployment/inference.py +4 -5
- ads/aqua/modeldeployment/utils.py +525 -0
- ads/aqua/resources/gpu_shapes_index.json +94 -0
- ads/dataset/recommendation.py +11 -20
- ads/opctl/operator/lowcode/pii/model/report.py +9 -16
- ads/opctl/utils.py +1 -1
- {oracle_ads-2.13.4.dist-info → oracle_ads-2.13.6.dist-info}/METADATA +1 -1
- {oracle_ads-2.13.4.dist-info → oracle_ads-2.13.6.dist-info}/RECORD +31 -28
- {oracle_ads-2.13.4.dist-info → oracle_ads-2.13.6.dist-info}/WHEEL +0 -0
- {oracle_ads-2.13.4.dist-info → oracle_ads-2.13.6.dist-info}/entry_points.txt +0 -0
- {oracle_ads-2.13.4.dist-info → oracle_ads-2.13.6.dist-info}/licenses/LICENSE.txt +0 -0
ads/aqua/app.py
CHANGED
@@ -6,9 +6,11 @@ import json
|
|
6
6
|
import os
|
7
7
|
import traceback
|
8
8
|
from dataclasses import fields
|
9
|
+
from datetime import datetime, timedelta
|
9
10
|
from typing import Any, Dict, Optional, Union
|
10
11
|
|
11
12
|
import oci
|
13
|
+
from cachetools import TTLCache, cached
|
12
14
|
from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails
|
13
15
|
|
14
16
|
from ads import set_auth
|
@@ -269,6 +271,7 @@ class AquaApp:
|
|
269
271
|
logger.info(f"Artifact not found in model {model_id}.")
|
270
272
|
return False
|
271
273
|
|
274
|
+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
|
272
275
|
def get_config(
|
273
276
|
self,
|
274
277
|
model_id: str,
|
@@ -337,6 +340,9 @@ class AquaApp:
|
|
337
340
|
config_file_path = os.path.join(config_path, config_file_name)
|
338
341
|
if is_path_exists(config_file_path):
|
339
342
|
try:
|
343
|
+
logger.debug(
|
344
|
+
f"Loading config: `{config_file_name}` from `{config_path}`"
|
345
|
+
)
|
340
346
|
config = load_config(
|
341
347
|
config_path,
|
342
348
|
config_file_name=config_file_name,
|
@@ -0,0 +1,305 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# Copyright (c) 2025 Oracle and/or its affiliates.
|
3
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
|
+
|
5
|
+
import json
|
6
|
+
import logging
|
7
|
+
import re
|
8
|
+
from typing import Any, Dict, Optional, Union

import httpx
|
12
|
+
|
13
|
+
from ads.aqua.client.client import get_async_httpx_client, get_httpx_client
|
14
|
+
from ads.common.extended_enum import ExtendedEnum
|
15
|
+
|
16
|
+
logger = logging.getLogger(__name__)

# Overall request timeout of 600 seconds; the connection itself must be
# established within 5 seconds.
DEFAULT_TIMEOUT = httpx.Timeout(600, connect=5.0)
# Default number of retries passed to the OpenAI SDK client (`max_retries`).
DEFAULT_MAX_RETRIES = 2
|
20
|
+
|
21
|
+
|
22
|
+
try:
|
23
|
+
import openai
|
24
|
+
except ImportError as e:
|
25
|
+
raise ModuleNotFoundError(
|
26
|
+
"The custom OpenAI client requires the `openai-python` package. "
|
27
|
+
"Please install it with `pip install openai`."
|
28
|
+
) from e
|
29
|
+
|
30
|
+
|
31
|
+
class ModelDeploymentBaseEndpoint(ExtendedEnum):
    """
    Endpoint path segments that an OCI model deployment exposes and that this
    module recognizes when rewriting OpenAI-style request URLs.
    """

    # Plain inference endpoint segment.
    PREDICT = "predict"
    # Inference endpoint segment used for streamed responses.
    PREDICT_WITH_RESPONSE_STREAM = "predictwithresponsestream"
|
36
|
+
|
37
|
+
|
38
|
+
class AquaOpenAIMixin:
    """
    Mixin that provides common logic to patch HTTP request headers and URLs
    for compatibility with the OCI Model Deployment service using the OpenAI API schema.

    For every outgoing request the mixin:

    - sets an ``enable-streaming`` header derived from the JSON body's ``stream`` flag,
    - sets a ``route`` header with the OpenAI route found after the
      ``/predict`` or ``/predictwithresponsestream`` path segment
      (e.g. ``/v1/chat/completions``),
    - replaces the request URL with the client's ``base_url`` trimmed back to
      that endpoint segment.

    The host class must provide a ``base_url`` attribute of type ``httpx.URL``
    (it is accessed via ``.path`` and ``.copy_with``).
    """

    # Matches the deployment endpoint only as a complete path segment, so that
    # e.g. "/predictions" is not mistaken for "/predict". Matching is
    # case-insensitive so the original casing of the path can be preserved.
    _ENDPOINT_PATTERN = re.compile(
        r"/predict(withresponsestream)?(?=/|$)", re.IGNORECASE
    )

    def _patch_route(self, original_path: str) -> str:
        """
        Extracts and formats the OpenAI-style route path from a full request path.

        The route is the part of the path that follows the ``/predict`` or
        ``/predictwithresponsestream`` segment, lower-cased and stripped of
        surrounding slashes.

        Args:
            original_path (str): The full URL path from the incoming request.

        Returns:
            str: The normalized OpenAI-compatible route path (e.g., '/v1/chat/completions'),
            or an empty string when no endpoint segment or no route suffix is found.
        """
        normalized_path = original_path.lower().rstrip("/")

        match = self._ENDPOINT_PATTERN.search(normalized_path)
        if not match:
            logger.debug("Route header cannot be resolved from path: %s", original_path)
            return ""

        route_suffix = normalized_path[match.end() :].lstrip("/")
        if not route_suffix:
            logger.warning(
                "Missing OpenAI route suffix after '/predict'. "
                "Expected something like '/v1/completions'."
            )
            return ""

        if not route_suffix.startswith("v"):
            # Only a warning: the unversioned suffix is still forwarded as-is.
            logger.warning(
                "Route suffix does not start with a version prefix (e.g., '/v1'). "
                "This may lead to compatibility issues with OpenAI-style endpoints. "
                "Consider updating the URL to include a version prefix, "
                "such as '/predict/v1' or '/predictwithresponsestream/v1'."
            )

        return f"/{route_suffix}"

    def _patch_streaming(self, request: httpx.Request) -> None:
        """
        Sets the 'enable-streaming' header based on the JSON request body contents.

        If the request body is a JSON object containing `"stream": true`, the
        `enable-streaming` header is set to "true"; otherwise it defaults to "false".
        An `enable-streaming` header that is already present is never overwritten.
        Body inspection is best-effort: any failure is logged and "false" is used.

        Args:
            request (httpx.Request): The outgoing HTTPX request.
        """
        streaming_enabled = "false"
        content_type = request.headers.get("Content-Type", "")

        if "application/json" in content_type:
            try:
                # Accessing `content` can raise httpx.RequestNotRead for an
                # unread streaming body, so it belongs inside the try block.
                raw_body = request.content
                if raw_body:
                    body = (
                        raw_body.decode("utf-8")
                        if isinstance(raw_body, bytes)
                        else raw_body
                    )
                    payload = json.loads(body)
                    # Only a JSON object can carry a "stream" flag.
                    if isinstance(payload, dict) and payload.get("stream") is True:
                        streaming_enabled = "true"
            except Exception as e:
                logger.exception(
                    "Failed to parse request JSON body for streaming flag: %s", e
                )

        request.headers.setdefault("enable-streaming", streaming_enabled)
        logger.debug("Patched 'enable-streaming' header: %s", streaming_enabled)

    def _patch_headers(self, request: httpx.Request) -> None:
        """
        Patches request headers by injecting OpenAI-compatible values:
        - `enable-streaming` for streaming-aware endpoints
        - `route` for backend routing

        Existing headers with these names are left untouched.

        Args:
            request (httpx.Request): The outgoing HTTPX request.
        """
        self._patch_streaming(request)
        route_header = self._patch_route(request.url.path)
        request.headers.setdefault("route", route_header)
        logger.debug("Patched 'route' header: %s", route_header)

    def _patch_url(self) -> httpx.URL:
        """
        Strips any suffixes from the base URL to retain only the `/predict` or `/predictwithresponsestream` path.

        The endpoint segment is located case-insensitively, but the casing of
        the retained path is preserved.

        Returns:
            httpx.URL: The base URL with its path trimmed to the model deployment
            endpoint, or the unchanged base URL when no endpoint segment is found.
        """
        base_path = self.base_url.path.rstrip("/")
        match = self._ENDPOINT_PATTERN.search(base_path)
        if match:
            return self.base_url.copy_with(path=base_path[: match.end()])

        logger.debug("Could not determine a valid endpoint from path: %s", base_path)
        return self.base_url

    def _prepare_request_common(self, request: httpx.Request) -> None:
        """
        Common preparation routine for all requests.

        This includes:
        - Patching headers with streaming and routing info.
        - Normalizing the URL path to include only `/predict` or `/predictwithresponsestream`.

        Headers are patched first because the route is derived from the
        request's original URL path, which is replaced afterwards.

        Args:
            request (httpx.Request): The outgoing HTTPX request.
        """
        # Patch headers
        logger.debug("Original headers: %s", request.headers)
        self._patch_headers(request)
        logger.debug("Headers after patching: %s", request.headers)

        # Patch URL
        # NOTE(review): the new URL is built from `self.base_url`, so a query
        # string present only on `request.url` is not carried over — confirm intended.
        logger.debug("URL before patching: %s", request.url)
        request.url = self._patch_url()
        logger.debug("URL after patching: %s", request.url)
|
161
|
+
|
162
|
+
|
163
|
+
class OpenAI(openai.OpenAI, AquaOpenAIMixin):
    """
    Synchronous OpenAI SDK client wired for OCI model deployments.

    Each outgoing request is routed through the Aqua mixin, which adds the
    ``route`` and ``enable-streaming`` headers and trims the URL to the
    deployment endpoint before the request is sent.
    """

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        organization: Optional[str] = None,
        project: Optional[str] = None,
        base_url: Optional[Union[str, httpx.URL]] = None,
        websocket_base_url: Optional[Union[str, httpx.URL]] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = DEFAULT_TIMEOUT,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Optional[Dict[str, str]] = None,
        default_query: Optional[Dict[str, object]] = None,
        http_client: Optional[httpx.Client] = None,
        http_client_kwargs: Optional[Dict[str, Any]] = None,
        _strict_response_validation: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Create a synchronous client.

        Args:
            api_key (str, optional): API key sent with each request. When omitted or
                empty, the placeholder "OCI" is used instead, so the OpenAI SDK's
                OPENAI_API_KEY environment fallback does not apply.
            organization (str, optional): Organization ID (the OpenAI SDK default applies when None).
            project (str, optional): Project ID (the OpenAI SDK default applies when None).
            base_url (str | httpx.URL, optional): Base URL of the model deployment endpoint.
            websocket_base_url (str | httpx.URL, optional): Base URL for WebSocket connections.
            timeout (float | httpx.Timeout, optional): Request timeout; defaults to DEFAULT_TIMEOUT.
            max_retries (int, optional): Retry count; defaults to DEFAULT_MAX_RETRIES.
            default_headers (dict[str, str], optional): Headers added to every request.
            default_query (dict[str, object], optional): Query parameters added to every request.
            http_client (httpx.Client, optional): HTTP client to use. When None, one is
                created with ``get_httpx_client(**http_client_kwargs)``.
            http_client_kwargs (dict[str, Any], optional): Keyword arguments for the
                auto-created HTTP client; ignored when ``http_client`` is given.
            _strict_response_validation (bool, optional): Enable strict response validation.
            **kwargs: Forwarded unchanged to ``openai.OpenAI.__init__``.
        """
        if http_client is None:
            logger.debug(
                "No http_client provided; auto-creating one using ads.aqua.get_httpx_client()"
            )
            http_client = get_httpx_client(**(http_client_kwargs or {}))

        if not api_key:
            logger.debug("API key not provided; using default placeholder for OCI.")

        super().__init__(
            api_key=api_key or "OCI",
            organization=organization,
            project=project,
            base_url=base_url,
            websocket_base_url=websocket_base_url,
            timeout=timeout,
            max_retries=max_retries,
            default_headers=default_headers,
            default_query=default_query,
            http_client=http_client,
            _strict_response_validation=_strict_response_validation,
            **kwargs,
        )

    def _prepare_request(self, request: httpx.Request) -> None:
        """
        Hook invoked before each synchronous request; applies the shared Aqua patches.

        Args:
            request (httpx.Request): The outgoing HTTP request, modified in place.
        """
        self._prepare_request_common(request)
|
233
|
+
|
234
|
+
|
235
|
+
class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
    """
    Asynchronous OpenAI SDK client wired for OCI model deployments.

    Each outgoing request is routed through the Aqua mixin, which adds the
    ``route`` and ``enable-streaming`` headers and trims the URL to the
    deployment endpoint before the request is sent.
    """

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        organization: Optional[str] = None,
        project: Optional[str] = None,
        base_url: Optional[Union[str, httpx.URL]] = None,
        websocket_base_url: Optional[Union[str, httpx.URL]] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = DEFAULT_TIMEOUT,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Optional[Dict[str, str]] = None,
        default_query: Optional[Dict[str, object]] = None,
        http_client: Optional[httpx.AsyncClient] = None,
        http_client_kwargs: Optional[Dict[str, Any]] = None,
        _strict_response_validation: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Construct a new asynchronous AsyncOpenAI client instance.

        If no http_client is provided, one will be automatically created using
        ads.aqua.get_async_httpx_client().

        Args:
            api_key (str, optional): API key sent with each request. When omitted or
                empty, the placeholder "OCI" is used instead, so the OpenAI SDK's
                OPENAI_API_KEY environment fallback does not apply.
            organization (str, optional): Organization ID.
            project (str, optional): Project ID.
            base_url (str | httpx.URL, optional): Base URL for the API.
            websocket_base_url (str | httpx.URL, optional): Base URL for WebSocket connections.
            timeout (float | httpx.Timeout, optional): Timeout for API requests.
            max_retries (int, optional): Maximum number of retries for API requests.
            default_headers (dict[str, str], optional): Additional headers.
            default_query (dict[str, object], optional): Additional query parameters.
            http_client (httpx.AsyncClient, optional): Custom asynchronous HTTP client; if not provided, one will be auto-created.
            http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
            _strict_response_validation (bool, optional): Enable strict response validation.
            **kwargs: Additional keyword arguments passed to the parent __init__.
        """
        if http_client is None:
            logger.debug(
                "No async http_client provided; auto-creating one using ads.aqua.get_async_httpx_client()"
            )
            http_client = get_async_httpx_client(**(http_client_kwargs or {}))
        if not api_key:
            logger.debug("API key not provided; using default placeholder for OCI.")
            api_key = "OCI"

        super().__init__(
            api_key=api_key,
            organization=organization,
            project=project,
            base_url=base_url,
            websocket_base_url=websocket_base_url,
            timeout=timeout,
            max_retries=max_retries,
            default_headers=default_headers,
            default_query=default_query,
            http_client=http_client,
            _strict_response_validation=_strict_response_validation,
            **kwargs,
        )

    async def _prepare_request(self, request: httpx.Request) -> None:
        """
        Asynchronously prepare the HTTP request by applying common modifications.

        The patching itself is synchronous; this coroutine only exists to match
        the asynchronous hook signature of the parent client.

        Args:
            request (httpx.Request): The outgoing HTTP request, modified in place.
        """
        self._prepare_request_common(request)
|
ads/aqua/common/entities.py
CHANGED
@@ -2,10 +2,14 @@
|
|
2
2
|
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
|
5
|
+
import re
|
5
6
|
from typing import Any, Dict, Optional
|
6
7
|
|
7
8
|
from oci.data_science.models import Model
|
8
|
-
from pydantic import BaseModel, Field
|
9
|
+
from pydantic import BaseModel, Field, model_validator
|
10
|
+
|
11
|
+
from ads.aqua import logger
|
12
|
+
from ads.aqua.config.utils.serializer import Serializable
|
9
13
|
|
10
14
|
|
11
15
|
class ContainerSpec:
|
@@ -25,7 +29,6 @@ class ContainerSpec:
|
|
25
29
|
class ModelConfigResult(BaseModel):
|
26
30
|
"""
|
27
31
|
Represents the result of getting the AQUA model configuration.
|
28
|
-
|
29
32
|
Attributes:
|
30
33
|
model_details (Dict[str, Any]): A dictionary containing model details extracted from OCI.
|
31
34
|
config (Dict[str, Any]): A dictionary of the loaded configuration.
|
@@ -42,3 +45,222 @@ class ModelConfigResult(BaseModel):
|
|
42
45
|
extra = "ignore"
|
43
46
|
arbitrary_types_allowed = True
|
44
47
|
protected_namespaces = ()
|
48
|
+
|
49
|
+
|
50
|
+
class GPUSpecs(Serializable):
    """
    GPU details of a compute instance.

    Every field is optional and defaults to ``None`` when unknown.
    """

    gpu_memory_in_gbs: Optional[int] = Field(
        None, description="The amount of GPU memory available (in GB)."
    )
    gpu_count: Optional[int] = Field(None, description="The number of GPUs available.")
    gpu_type: Optional[str] = Field(
        None, description="The type of GPU (e.g., 'V100, A100, H100')."
    )
|
64
|
+
|
65
|
+
|
66
|
+
class GPUShapesIndex(Serializable):
    """
    Lookup table of GPU specifications keyed by compute shape name.

    Attributes
    ----------
    shapes (Dict[str, GPUSpecs]): Shape name -> GPU specification; empty by default.
    """

    shapes: Dict[str, GPUSpecs] = Field(
        description="Mapping of shape names to GPU specifications.",
        default_factory=dict,
    )
|
79
|
+
|
80
|
+
|
81
|
+
class ComputeShapeSummary(Serializable):
    """
    Represents the specifications of a compute instance's shape.
    """

    core_count: Optional[int] = Field(
        default=None, description="The number of CPU cores available."
    )
    memory_in_gbs: Optional[int] = Field(
        default=None, description="The amount of memory (in GB) available."
    )
    name: Optional[str] = Field(
        default=None, description="The name identifier of the compute shape."
    )
    shape_series: Optional[str] = Field(
        default=None, description="The series or category of the compute shape."
    )
    gpu_specs: Optional[GPUSpecs] = Field(
        default=None,
        description="The GPU specifications associated with the compute shape.",
    )

    @model_validator(mode="after")
    @classmethod
    def set_gpu_specs(cls, model: "ComputeShapeSummary") -> "ComputeShapeSummary":
        """
        Populates ``gpu_specs`` for GPU-based shapes when it was not provided.

        The validator acts only when all of these hold:

        - ``shape_series`` contains "GPU" (case-insensitive),
        - ``name`` is set,
        - ``gpu_specs`` is empty.

        It then reads the GPU count from the trailing numeric segment of the shape name.
        For example, "VM.GPU3.2" gives gpu_count=2. Only ``gpu_count`` is filled in;
        GPU memory and GPU type stay ``None``. Names without a trailing numeric
        segment are left without ``gpu_specs``.

        The information about shapes is taken from: https://docs.oracle.com/en-us/iaas/data-science/using/supported-shapes.htm

        Returns:
            ComputeShapeSummary: The updated instance with gpu_specs populated if applicable.
        """
        try:
            if (
                model.shape_series
                and "GPU" in model.shape_series.upper()
                and model.name
                and not model.gpu_specs
            ):
                # Extract gpu_count from the shape name (e.g., "VM.GPU3.2" -> gpu_count=2)
                match = re.search(r"\.(\d+)$", model.name)
                if match:
                    gpu_count = int(match.group(1))
                    model.gpu_specs = GPUSpecs(gpu_count=gpu_count)
        except Exception as err:
            # Best-effort enrichment: a failure here must not invalidate the shape summary.
            logger.debug(
                f"Error occurred in attempt to extract GPU specification for the shape {model.name}. "
                f"Details: {err}"
            )
        return model
|
137
|
+
|
138
|
+
|
139
|
+
class AquaMultiModelRef(Serializable):
    """
    Minimal reference to one model taking part in a multi-model deployment.

    It carries only the identifiers and per-model overrides needed to fetch
    the full model metadata and deploy the model.

    Attributes
    ----------
    model_id : str
        OCID of the model to deploy (required).
    model_name : Optional[str]
        Name of the model.
    gpu_count : Optional[int]
        Number of GPUs allocated to this model.
    env_var : Optional[dict]
        Environment variable overrides for deployment; an empty dict by default.
    artifact_location : Optional[str]
        Artifact path of the model inside the multi-model group.
    """

    model_id: str = Field(..., description="The model OCID to deploy.")
    model_name: Optional[str] = Field(default=None, description="The name of model.")
    gpu_count: Optional[int] = Field(
        default=None, description="The gpu count allocation for the model."
    )
    env_var: Optional[dict] = Field(
        description="The environment variables of the model.", default_factory=dict
    )
    artifact_location: Optional[str] = Field(
        default=None, description="Artifact path of model in the multimodel group."
    )

    class Config:
        # Unknown input keys are dropped rather than rejected.
        extra = "ignore"
        # Allow field names starting with "model_" (model_id, model_name).
        protected_namespaces = ()
|
175
|
+
|
176
|
+
|
177
|
+
class ContainerPath(Serializable):
    """
    Represents a parsed container path, extracting the path, name, and version.

    This model is designed to parse a container path string of the format
    '<image_path>:<version>'. It extracts the following components:
    - `path`: The full path up to the version.
    - `name`: The last segment of the path, representing the image name.
    - `version`: The version following the final colon (letters, digits,
      underscores, dots and hyphens), or None when there is no such suffix.

    Example Usage:
    --------------
    >>> container = ContainerPath(full_path="iad.ocir.io/ociodscdev/odsc-llm-evaluate:0.1.2.9")
    >>> container.path
    'iad.ocir.io/ociodscdev/odsc-llm-evaluate'
    >>> container.name
    'odsc-llm-evaluate'
    >>> container.version
    '0.1.2.9'

    >>> container = ContainerPath(full_path="custom-scheme://path/to/versioned-model:2.5.1")
    >>> container.path
    'custom-scheme://path/to/versioned-model'
    >>> container.name
    'versioned-model'
    >>> container.version
    '2.5.1'

    >>> ContainerPath(full_path="odsc-vllm-serving:0.6.4-post1").version
    '0.6.4-post1'

    Attributes
    ----------
    full_path : str
        The complete container path string to be parsed.
    path : Optional[str]
        The full path up to the version (e.g., 'iad.ocir.io/ociodscdev/odsc-llm-evaluate').
    name : Optional[str]
        The image name, which is the last segment of `path` (e.g., 'odsc-llm-evaluate').
    version : Optional[str]
        The version number following the final colon in the path (e.g., '0.1.2.9').

    Methods
    -------
    validate(values: Any) -> Any
        Validates and parses the `full_path`, extracting `path`, `name`, and `version`.
    """

    full_path: str
    path: Optional[str] = None
    name: Optional[str] = None
    version: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def validate(cls, values: Any) -> Any:
        """
        Validates and parses the full container path, extracting the image path, image name, and version.

        Parameters
        ----------
        values : dict
            The dictionary of values being validated, containing 'full_path'.
            Non-dict input is returned unchanged for pydantic to validate.

        Returns
        -------
        dict
            A copy of the values dictionary with extracted 'path', 'name', and 'version'.

        Raises
        ------
        ValueError
            If 'full_path' is missing, empty, or cannot be parsed.
        """
        if not isinstance(values, dict):
            # Leave non-dict input (e.g. attribute objects) to pydantic's normal validation.
            return values

        # Work on a copy so the caller's dictionary is not mutated.
        values = dict(values)
        full_path = (values.get("full_path") or "").strip()

        # Regex to parse <image_path>:<version>; version tags may contain
        # word characters, dots and hyphens (as Docker tags do).
        match = re.match(
            r"^(?P<image_path>.+?)(?::(?P<image_version>[\w.-]+))?$", full_path
        )

        if not match:
            raise ValueError(
                "Invalid container path format. Expected format: '<image_path>:<version>'"
            )

        # Extract image_path and version
        values["path"] = match.group("image_path")
        values["version"] = match.group("image_version")

        # Extract image_name as the last segment of image_path
        values["name"] = values["path"].split("/")[-1]

        return values

    class Config:
        extra = "ignore"
        protected_namespaces = ()
|
ads/aqua/common/enums.py
CHANGED
@@ -2,6 +2,8 @@
|
|
2
2
|
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
|
5
|
+
from typing import Dict, List
|
6
|
+
|
5
7
|
from ads.common.extended_enum import ExtendedEnum
|
6
8
|
|
7
9
|
|
@@ -25,6 +27,7 @@ class Tags(ExtendedEnum):
|
|
25
27
|
AQUA_TAG = "OCI_AQUA"
|
26
28
|
AQUA_SERVICE_MODEL_TAG = "aqua_service_model"
|
27
29
|
AQUA_FINE_TUNED_MODEL_TAG = "aqua_fine_tuned_model"
|
30
|
+
AQUA_MODEL_ID_TAG = "aqua_model_id"
|
28
31
|
AQUA_MODEL_NAME_TAG = "aqua_model_name"
|
29
32
|
AQUA_EVALUATION = "aqua_evaluation"
|
30
33
|
AQUA_FINE_TUNING = "aqua_finetuning"
|
@@ -34,6 +37,7 @@ class Tags(ExtendedEnum):
|
|
34
37
|
AQUA_EVALUATION_MODEL_ID = "evaluation_model_id"
|
35
38
|
MODEL_FORMAT = "model_format"
|
36
39
|
MODEL_ARTIFACT_FILE = "model_file"
|
40
|
+
MULTIMODEL_TYPE_TAG = "aqua_multimodel"
|
37
41
|
|
38
42
|
|
39
43
|
class InferenceContainerType(ExtendedEnum):
|
@@ -44,6 +48,7 @@ class InferenceContainerType(ExtendedEnum):
|
|
44
48
|
|
45
49
|
class InferenceContainerTypeFamily(ExtendedEnum):
|
46
50
|
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
|
51
|
+
AQUA_VLLM_V1_CONTAINER_FAMILY = "odsc-vllm-serving-v1"
|
47
52
|
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
|
48
53
|
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
|
49
54
|
|
@@ -103,3 +108,15 @@ class ModelFormat(ExtendedEnum):
|
|
103
108
|
class Platform(ExtendedEnum):
|
104
109
|
ARM_CPU = "ARM_CPU"
|
105
110
|
NVIDIA_GPU = "NVIDIA_GPU"
|
111
|
+
|
112
|
+
|
113
|
+
# Compatibility groups for inference container families.
#   key   -> the family to prefer when several compatible families are selected;
#   value -> every family in that group (the preferred one included).
CONTAINER_FAMILY_COMPATIBILITY: Dict[str, List[str]] = dict(
    [
        (
            InferenceContainerTypeFamily.AQUA_VLLM_V1_CONTAINER_FAMILY,
            [
                InferenceContainerTypeFamily.AQUA_VLLM_V1_CONTAINER_FAMILY,
                InferenceContainerTypeFamily.AQUA_VLLM_CONTAINER_FAMILY,
            ],
        ),
    ]
)
|