oracle-ads 2.10.1__py3-none-any.whl → 2.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +12 -0
- ads/aqua/base.py +324 -0
- ads/aqua/cli.py +19 -0
- ads/aqua/config/deployment_config_defaults.json +9 -0
- ads/aqua/config/resource_limit_names.json +7 -0
- ads/aqua/constants.py +45 -0
- ads/aqua/data.py +40 -0
- ads/aqua/decorator.py +101 -0
- ads/aqua/deployment.py +643 -0
- ads/aqua/dummy_data/icon.txt +1 -0
- ads/aqua/dummy_data/oci_model_deployments.json +56 -0
- ads/aqua/dummy_data/oci_models.json +1 -0
- ads/aqua/dummy_data/readme.md +26 -0
- ads/aqua/evaluation.py +1751 -0
- ads/aqua/exception.py +82 -0
- ads/aqua/extension/__init__.py +40 -0
- ads/aqua/extension/base_handler.py +138 -0
- ads/aqua/extension/common_handler.py +21 -0
- ads/aqua/extension/deployment_handler.py +202 -0
- ads/aqua/extension/evaluation_handler.py +135 -0
- ads/aqua/extension/finetune_handler.py +66 -0
- ads/aqua/extension/model_handler.py +59 -0
- ads/aqua/extension/ui_handler.py +201 -0
- ads/aqua/extension/utils.py +23 -0
- ads/aqua/finetune.py +579 -0
- ads/aqua/job.py +29 -0
- ads/aqua/model.py +819 -0
- ads/aqua/training/__init__.py +4 -0
- ads/aqua/training/exceptions.py +459 -0
- ads/aqua/ui.py +453 -0
- ads/aqua/utils.py +715 -0
- ads/cli.py +37 -6
- ads/common/decorator/__init__.py +7 -3
- ads/common/decorator/require_nonempty_arg.py +65 -0
- ads/common/object_storage_details.py +166 -7
- ads/common/oci_client.py +18 -1
- ads/common/oci_logging.py +2 -2
- ads/common/oci_mixin.py +4 -5
- ads/common/serializer.py +34 -5
- ads/common/utils.py +75 -10
- ads/config.py +40 -1
- ads/jobs/ads_job.py +43 -25
- ads/jobs/builders/infrastructure/base.py +4 -2
- ads/jobs/builders/infrastructure/dsc_job.py +49 -39
- ads/jobs/builders/runtimes/base.py +71 -1
- ads/jobs/builders/runtimes/container_runtime.py +4 -4
- ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
- ads/jobs/templates/driver_pytorch.py +27 -10
- ads/model/artifact_downloader.py +84 -14
- ads/model/artifact_uploader.py +25 -23
- ads/model/datascience_model.py +388 -38
- ads/model/deployment/model_deployment.py +10 -2
- ads/model/generic_model.py +8 -0
- ads/model/model_file_description_schema.json +68 -0
- ads/model/model_metadata.py +1 -1
- ads/model/service/oci_datascience_model.py +34 -5
- ads/opctl/operator/lowcode/anomaly/README.md +2 -1
- ads/opctl/operator/lowcode/anomaly/__main__.py +10 -4
- ads/opctl/operator/lowcode/anomaly/environment.yaml +2 -1
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +12 -6
- ads/opctl/operator/lowcode/forecast/README.md +3 -2
- ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
- ads/opctl/operator/lowcode/forecast/model/automlx.py +12 -23
- ads/telemetry/base.py +62 -0
- ads/telemetry/client.py +105 -0
- ads/telemetry/telemetry.py +6 -3
- {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/METADATA +37 -7
- {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/RECORD +71 -36
- {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/WHEEL +0 -0
- {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/entry_points.txt +0 -0
ads/aqua/utils.py
ADDED
@@ -0,0 +1,715 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
"""AQUA utils and constants."""
|
6
|
+
import asyncio
|
7
|
+
import base64
|
8
|
+
import json
|
9
|
+
import logging
|
10
|
+
import os
|
11
|
+
import random
|
12
|
+
import re
|
13
|
+
import sys
|
14
|
+
from enum import Enum
|
15
|
+
from functools import wraps
|
16
|
+
from pathlib import Path
|
17
|
+
from string import Template
|
18
|
+
from typing import List, Union
|
19
|
+
|
20
|
+
import fsspec
|
21
|
+
import oci
|
22
|
+
from oci.data_science.models import JobRun, Model
|
23
|
+
|
24
|
+
from ads.aqua.constants import RqsAdditionalDetails
|
25
|
+
from ads.aqua.data import AquaResourceIdentifier, Tags
|
26
|
+
from ads.aqua.exception import AquaFileNotFoundError, AquaRuntimeError, AquaValueError
|
27
|
+
from ads.common.auth import default_signer
|
28
|
+
from ads.common.object_storage_details import ObjectStorageDetails
|
29
|
+
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
|
30
|
+
from ads.common.utils import get_console_link, upload_to_os
|
31
|
+
from ads.config import (
|
32
|
+
AQUA_CONFIG_FOLDER,
|
33
|
+
AQUA_SERVICE_MODELS_BUCKET,
|
34
|
+
TENANCY_OCID,
|
35
|
+
CONDA_BUCKET_NS,
|
36
|
+
)
|
37
|
+
from ads.model import DataScienceModel, ModelVersionSet
|
38
|
+
|
39
|
+
# TODO: allow the user to setup the logging level?
# NOTE(review): ``logging.basicConfig`` configures the *root* logger (when it has
# no handlers yet) as an import-time side effect of this module.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger("ODSC_AQUA")

# Placeholder values used when a piece of information is unavailable.
UNKNOWN = ""
# NOTE(review): shared module-level dict; callers should not mutate it.
UNKNOWN_DICT = {}
UNKNOWN_JSON_STR = "{}"

# File names used for model artifacts, configs and evaluation reports.
README = "README.md"
LICENSE_TXT = "config/LICENSE.txt"
DEPLOYMENT_CONFIG = "deployment_config.json"
CONTAINER_INDEX = "container_index.json"
EVALUATION_REPORT_JSON = "report.json"
EVALUATION_REPORT_MD = "report.md"
EVALUATION_REPORT = "report.html"

# OCID resource type -> path segment used when building OCI console links.
CONSOLE_LINK_RESOURCE_TYPE_MAPPING = {
    "datasciencemodel": "models",
    "datasciencemodeldeployment": "model-deployments",
    "datasciencemodeldeploymentdev": "model-deployments",
    "datasciencemodeldeploymentint": "model-deployments",
    "datasciencemodeldeploymentpre": "model-deployments",
    "datasciencejob": "jobs",
    "datasciencejobrun": "job-runs",
    "datasciencejobrundev": "job-runs",
    "datasciencejobrunint": "job-runs",
    "datasciencejobrunpre": "job-runs",
    "datasciencemodelversionset": "model-version-sets",
    "datasciencemodelversionsetpre": "model-version-sets",
    "datasciencemodelversionsetint": "model-version-sets",
    "datasciencemodelversionsetdev": "model-version-sets",
}

# Defaults for fine-tuning (FT) jobs.
FINE_TUNING_RUNTIME_CONTAINER = "iad.ocir.io/ociodscdev/aqua_ft_cuda121:0.3.17.20"
DEFAULT_FT_BLOCK_STORAGE_SIZE = 256
DEFAULT_FT_REPLICA = 1
DEFAULT_FT_BATCH_SIZE = 1
DEFAULT_FT_VALIDATION_SET_SIZE = 0.1

HF_MODELS = "/home/datascience/conda/pytorch21_p39_gpu_v1/"
MAXIMUM_ALLOWED_DATASET_IN_BYTE = 50 * 1024 * 1024  # 50 MB
JOB_INFRASTRUCTURE_TYPE_DEFAULT_NETWORKING = "ME_STANDALONE"
NB_SESSION_IDENTIFIER = "NB_SESSION_OCID"
LIFECYCLE_DETAILS_MISSING_JOBRUN = "The asscociated JobRun resource has been deleted."
READY_TO_DEPLOY_STATUS = "ACTIVE"
|
81
|
+
|
82
|
+
|
83
|
+
class LifecycleStatus(Enum):
    """Lifecycle status helpers for AQUA evaluations."""

    UNKNOWN = ""

    @property
    def detail(self) -> str:
        """Message for this member, looked up by member name in ``LIFECYCLE_DETAILS_MAPPING``."""
        fallback = f"No detail available for the status {self.name}."
        return LIFECYCLE_DETAILS_MAPPING.get(self.name, fallback)

    @staticmethod
    def get_status(evaluation_status: str, job_run_status: str = None):
        """
        Maps the combination of evaluation status and job run status to a standard status.

        Parameters
        ----------
        evaluation_status (str):
            The lifecycle state of the evaluation model.
        job_run_status (str):
            The lifecycle state of the associated job run, if known.

        Returns
        -------
        str
            A lifecycle state string:

            - ``JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION`` when the job run state is missing;
            - ``evaluation_status`` unchanged when the model is not ACTIVE;
            - otherwise ``IN_PROGRESS`` for accepted/in-progress runs, ``FAILED`` for
              failed/needs-attention runs, or the job run state itself.
        """
        if not job_run_status:
            logger.error("Failed to get jobrun state.")
            # case1: failed to create jobrun
            # case2: jobrun is deleted - rqs cannot retrieve deleted resource
            return JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION

        if evaluation_status != Model.LIFECYCLE_STATE_ACTIVE:
            return evaluation_status

        running_states = (
            JobRun.LIFECYCLE_STATE_IN_PROGRESS,
            JobRun.LIFECYCLE_STATE_ACCEPTED,
        )
        failed_states = (
            JobRun.LIFECYCLE_STATE_FAILED,
            JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION,
        )
        if job_run_status in running_states:
            return JobRun.LIFECYCLE_STATE_IN_PROGRESS
        if job_run_status in failed_states:
            return JobRun.LIFECYCLE_STATE_FAILED
        return job_run_status
|
134
|
+
|
135
|
+
|
136
|
+
# Job-run lifecycle state -> message surfaced for an evaluation.
LIFECYCLE_DETAILS_MAPPING = dict(
    [
        (JobRun.LIFECYCLE_STATE_SUCCEEDED, "The evaluation ran successfully."),
        (JobRun.LIFECYCLE_STATE_IN_PROGRESS, "The evaluation is running."),
        (JobRun.LIFECYCLE_STATE_FAILED, "The evaluation failed."),
        (JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION, "Missing jobrun information."),
    ]
)
# File extensions accepted for locally uploaded datasets (see upload_local_to_os).
SUPPORTED_FILE_FORMATS = ["jsonl"]
# Custom-metadata key holding the artifact location of a model created by reference.
MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location"
|
144
|
+
|
145
|
+
|
146
|
+
def get_logger():
    """Return the module-level ``ODSC_AQUA`` logger.

    Returns
    -------
    logging.Logger
        The logger shared by the AQUA utilities in this module.
    """
    shared_logger = logger
    return shared_logger
|
148
|
+
|
149
|
+
|
150
|
+
def random_color_generator(word: str):
    """Derive a deterministic background/text color pair from ``word``.

    The RGB channels are drawn from a private ``random.Random`` seeded with the
    sum of the word's code points modulo 13. The same word therefore always
    yields the same color, and the process-wide ``random`` state is left
    untouched.

    Parameters
    ----------
    word: str
        Text used to seed the color choice.

    Returns
    -------
    tuple of (str, str)
        The background color as ``#rrggbb`` and a contrasting text color
        (``"black"`` or ``"white"``).
    """
    seed = sum(ord(c) for c in word) % 13
    # A dedicated generator produces the same sequence as reseeding the global
    # one, without clobbering random state that other code may rely on.
    rng = random.Random(seed)
    r = rng.randint(10, 245)
    g = rng.randint(10, 245)
    b = rng.randint(10, 245)

    # Perceived luminance (BT.601 weights) decides which text color contrasts best.
    luminance = (0.299 * r + 0.587 * g + 0.114 * b) / 255
    text_color = "black" if luminance > 0.5 else "white"

    return f"#{r:02x}{g:02x}{b:02x}", text_color
|
160
|
+
|
161
|
+
|
162
|
+
def svg_to_base64_datauri(svg_contents: str):
    """Wrap SVG markup in a base64-encoded ``data:image/svg+xml`` URI.

    Parameters
    ----------
    svg_contents: str
        The SVG document as text (encoded as UTF-8 before base64).

    Returns
    -------
    str
        The data URI string.
    """
    prefix = "data:image/svg+xml;base64,"
    payload = base64.b64encode(svg_contents.encode()).decode()
    return f"{prefix}{payload}"
|
165
|
+
|
166
|
+
|
167
|
+
def create_word_icon(label: str, width: int = 150, return_as_datauri=True):
    """Render a circular SVG icon showing a short abbreviation of ``label``.

    The icon text is the label's leading letter followed by its first run of
    digits (e.g. ``"llama2-7b"`` -> ``"l2"``). When the label does not match
    that pattern, only its first character is used (an empty label yields an
    empty icon text).

    Parameters
    ----------
    label: str
        Text the icon is derived from; it also seeds the icon color.
    width: int
        Width and height of the square SVG canvas, in pixels.
    return_as_datauri: bool
        Whether to return a base64 ``data:`` URI instead of raw SVG markup.

    Returns
    -------
    str
        The icon as a data URI, or as SVG markup when ``return_as_datauri`` is false.
    """
    from xml.sax.saxutils import escape

    match = re.findall(r"(^[a-zA-Z]{1}).*?(\d+[a-z]?)", label)
    # ``match[0]`` is a (letter, digits) tuple. ``label[:1]`` avoids an
    # IndexError on an empty label.
    icon_text = "".join(match[0]) if match else label[:1]
    icon_color, text_color = random_color_generator(label)
    cx = cy = r = width / 2
    fs = int(r / 25)

    template = Template(
        """
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" width="${width}" height="${width}">

<style>
text {
font-size: ${fs}em;
font-family: lucida console, Fira Mono, monospace;
text-anchor: middle;
stroke-width: 1px;
font-weight: bold;
alignment-baseline: central;
}

</style>

<circle cx="${cx}" cy="${cy}" r="${r}" fill="${icon_color}" />
<text x="50%" y="50%" fill="${text_color}">${icon_text}</text>
</svg>
""".strip()
    )

    icon_svg = template.substitute(
        width=width,
        fs=fs,
        cx=cx,
        cy=cy,
        r=r,
        icon_color=icon_color,
        text_color=text_color,
        # Escape so characters such as "<" or "&" cannot break the SVG markup.
        icon_text=escape(icon_text),
    )
    if return_as_datauri:
        return svg_to_base64_datauri(icon_svg)
    return icon_svg
|
203
|
+
|
204
|
+
|
205
|
+
def get_artifact_path(custom_metadata_list: List) -> str:
    """Get the artifact path from the custom metadata list of model.

    The first entry whose key is ``MODEL_BY_REFERENCE_OSS_PATH_KEY`` wins.
    If its value is already an ``oci://`` path, it is returned as-is.
    Otherwise the value is treated as a prefix inside the AQUA service models
    bucket.

    Parameters
    ----------
    custom_metadata_list: List
        A list of custom metadata items of model (objects with ``key``/``value``).

    Returns
    -------
    str:
        The artifact path from model, or ``UNKNOWN`` when no matching entry exists.
    """
    for item in custom_metadata_list:
        if item.key != MODEL_BY_REFERENCE_OSS_PATH_KEY:
            continue
        location = item.value
        if ObjectStorageDetails.is_oci_path(location):
            return location
        return ObjectStorageDetails(
            AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, location
        ).path

    logger.debug("Failed to get artifact path from custom metadata.")
    return UNKNOWN
|
229
|
+
|
230
|
+
|
231
|
+
def read_file(file_path: str, **kwargs) -> str:
    """Read a text file through ``fsspec`` (local paths or remote URIs).

    Parameters
    ----------
    file_path: str
        Path or URI of the file to read.
    **kwargs:
        Only ``auth`` is used: a mapping unpacked into ``fsspec.open`` as
        storage options. Other keyword arguments are ignored.

    Returns
    -------
    str
        The file contents, or ``UNKNOWN`` if reading failed (the error is logged).
    """
    storage_options = kwargs.get("auth", {})
    try:
        with fsspec.open(file_path, "r", **storage_options) as handle:
            contents = handle.read()
    except Exception as e:
        logger.error(f"Failed to read file {file_path}. {e}")
        return UNKNOWN
    return contents
|
238
|
+
|
239
|
+
|
240
|
+
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
    """Load and parse a JSON config file located under ``file_path``.

    Parameters
    ----------
    file_path: str
        Directory (local or ``oci://``) containing the config file.
    config_file_name: str
        Name of the JSON file inside ``file_path``.
    **kwargs:
        Forwarded to ``read_file``.

    Returns
    -------
    dict
        The parsed config.

    Raises
    ------
    AquaFileNotFoundError
        When the file is missing, unreadable, or parses to an empty value.
    """
    artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
    # Object Storage paths need OCI credentials; local files need none.
    signer = default_signer() if artifact_path.startswith("oci://") else {}
    raw_contents = read_file(file_path=artifact_path, auth=signer, **kwargs)
    config = json.loads(raw_contents or UNKNOWN_JSON_STR)
    if config:
        return config
    raise AquaFileNotFoundError(
        f"Config file `{config_file_name}` is either empty or missing at {artifact_path}",
        500,
    )
|
255
|
+
|
256
|
+
|
257
|
+
def is_valid_ocid(ocid: str) -> bool:
    """Checks if the given ocid is valid.

    The expected shape is
    ``ocid1.<resource type>.<realm>.[region][.future use].<unique id>``.

    Parameters
    ----------
    ocid: str
        Oracle Cloud Identifier (OCID).

    Returns
    -------
    bool:
        Whether the given ocid is valid. Non-string input (e.g. ``None``) is
        reported as invalid rather than raising ``TypeError``.
    """
    if not isinstance(ocid, str):
        return False
    pattern = r"^ocid1\.([a-z0-9_]+)\.([a-z0-9]+)\.([a-z0-9]*)(\.[^.]+)?\.([a-z0-9_]+)$"
    # fullmatch: with re.match, ``$`` would also accept a trailing newline.
    return re.fullmatch(pattern, ocid) is not None
|
273
|
+
|
274
|
+
|
275
|
+
def get_resource_type(ocid: str) -> str:
    """Gets resource type based on the given ocid.

    Parameters
    ----------
    ocid: str
        Oracle Cloud Identifier (OCID).

    Returns
    -------
    str:
        The resource type indicated in the given ocid (its second dot-separated field).

    Raises
    -------
    ValueError:
        When the given ocid is not a valid ocid.
    """
    if not is_valid_ocid(ocid):
        raise ValueError(
            f"The given ocid {ocid} is not a valid ocid. "
            "Check out this page https://docs.oracle.com/en-us/iaas/Content/General/Concepts/identifiers.htm to see more details."
        )
    # ocid1.<resource type>.<realm>...: the type is the field after "ocid1".
    return ocid.split(".")[1]
|
299
|
+
|
300
|
+
|
301
|
+
def query_resource(
    ocid, return_all: bool = True
) -> "oci.resource_search.models.ResourceSummary":
    """Use Search service to find a single resource within a tenancy.

    Parameters
    ----------
    ocid: str
        Oracle Cloud Identifier (OCID).
    return_all: bool
        Whether to return allAdditionalFields.

    Returns
    -------
    oci.resource_search.models.ResourceSummary:
        The first retrieved resource.

    Raises
    ------
    ValueError
        When ``ocid`` is not a valid OCID.
    AquaRuntimeError
        When the search returns no resources.
    """
    resource_type = get_resource_type(ocid)
    return_clause = " return allAdditionalFields " if return_all else " "
    # NOTE(review): the ocid is embedded verbatim in the query string; it has
    # only passed is_valid_ocid, whose optional segment allows any non-dot chars.
    query = f"query {resource_type} resources{return_clause}where (identifier = '{ocid}')"
    logger.debug(query)

    resources = OCIResource.search(
        query, type=SEARCH_TYPE.STRUCTURED, tenant_id=TENANCY_OCID
    )
    if len(resources) == 0:
        raise AquaRuntimeError(
            f"Failed to retreive {resource_type}'s information.",
            service_payload={"query": query, "tenant_id": TENANCY_OCID},
        )
    return resources[0]
|
335
|
+
|
336
|
+
|
337
|
+
def query_resources(
    compartment_id,
    resource_type: str,
    return_all: bool = True,
    tag_list: list = None,
    status_list: list = None,
    connect_by_ampersands: bool = True,
    **kwargs,
) -> List["oci.resource_search.models.ResourceSummary"]:
    """Use Search service to find resources within compartment.

    Parameters
    ----------
    compartment_id: str
        The compartment ocid.
    resource_type: str
        The type of the target resources.
    return_all: bool
        Whether to return allAdditionalFields.
    tag_list: list
        Freeform tag keys applied for filtering.
    status_list: list
        lifecycleState values applied for filtering (always OR-ed together).
    connect_by_ampersands: bool
        Whether to use `&&` to group multiple tag conditions.
        if `connect_by_ampersands=False`, `||` will be used.
    **kwargs:
        Additional arguments forwarded to ``OCIResource.search``.

    Returns
    -------
    List[oci.resource_search.models.ResourceSummary]:
        The retrieved resources.
    """
    return_clause = " return allAdditionalFields " if return_all else " "
    lifecycle_filter = _construct_condition(
        field_name="lifecycleState",
        allowed_values=status_list,
        connect_by_ampersands=False,
    )
    tag_filter = _construct_condition(
        field_name="freeformTags.key",
        allowed_values=tag_list,
        connect_by_ampersands=connect_by_ampersands,
    )
    filters = f"compartmentId = '{compartment_id}'{lifecycle_filter}{tag_filter}"
    query = f"query {resource_type} resources{return_clause}where ({filters})"
    logger.debug(query)
    logger.debug(f"tenant_id=`{TENANCY_OCID}`")

    return OCIResource.search(
        query, type=SEARCH_TYPE.STRUCTURED, tenant_id=TENANCY_OCID, **kwargs
    )
|
389
|
+
|
390
|
+
|
391
|
+
def _construct_condition(
|
392
|
+
field_name: str, allowed_values: list = None, connect_by_ampersands: bool = True
|
393
|
+
) -> str:
|
394
|
+
"""Returns tag condition applied in query statement.
|
395
|
+
|
396
|
+
Parameters
|
397
|
+
----------
|
398
|
+
field_name: str
|
399
|
+
The field_name keyword is the resource attribute against which the
|
400
|
+
operation and chosen value of that attribute are evaluated.
|
401
|
+
allowed_values: list
|
402
|
+
List of value will be applied for filtering.
|
403
|
+
connect_by_ampersands: bool
|
404
|
+
Whether to use `&&` to group multiple tag conditions.
|
405
|
+
if `connect_by_ampersands=False`, `||` will be used.
|
406
|
+
|
407
|
+
Returns
|
408
|
+
-------
|
409
|
+
str:
|
410
|
+
The tag condition.
|
411
|
+
"""
|
412
|
+
if not allowed_values:
|
413
|
+
return ""
|
414
|
+
|
415
|
+
joint = "&&" if connect_by_ampersands else "||"
|
416
|
+
formatted_tags = [f"{field_name} = '{value}'" for value in allowed_values]
|
417
|
+
joined_tags = f" {joint} ".join(formatted_tags)
|
418
|
+
condition = f" && ({joined_tags})" if joined_tags else ""
|
419
|
+
return condition
|
420
|
+
|
421
|
+
|
422
|
+
def upload_local_to_os(
    src_uri: str, dst_uri: str, auth: dict = None, force_overwrite: bool = False
):
    """Validate a local dataset file and upload it to Object Storage.

    Parameters
    ----------
    src_uri: str
        Local file path; ``~`` is expanded.
    dst_uri: str
        Destination URI passed to ``upload_to_os``.
    auth: dict
        Authentication passed to ``upload_to_os``.
    force_overwrite: bool
        Passed to ``upload_to_os``.

    Raises
    ------
    AquaFileNotFoundError
        When ``src_uri`` is not an existing file.
    AquaValueError
        When the extension is unsupported, or the file is empty or larger than
        ``MAXIMUM_ALLOWED_DATASET_IN_BYTE``.
    """
    local_path = os.path.expanduser(src_uri)
    if not os.path.isfile(local_path):
        raise AquaFileNotFoundError("Invalid input file path. Specify a valid one.")

    extension = Path(local_path).suffix.lstrip(".")
    if extension not in SUPPORTED_FILE_FORMATS:
        raise AquaValueError(
            f"Invalid input file. Only {', '.join(SUPPORTED_FILE_FORMATS)} files are supported."
        )

    size_in_bytes = os.path.getsize(local_path)
    if size_in_bytes == 0:
        raise AquaValueError("Empty input file. Specify a valid file path.")
    if size_in_bytes > MAXIMUM_ALLOWED_DATASET_IN_BYTE:
        raise AquaValueError(
            f"Local dataset file can't exceed {MAXIMUM_ALLOWED_DATASET_IN_BYTE} bytes."
        )

    upload_to_os(
        src_uri=local_path,
        dst_uri=dst_uri,
        auth=auth,
        force_overwrite=force_overwrite,
    )
|
445
|
+
|
446
|
+
|
447
|
+
def sanitize_response(oci_client, response: list):
    """Serialize OCI client results into JSON-friendly data.

    Parameters
    ----------
    oci_client
        OCI client object; its ``base_client`` performs the serialization.
    response
        List of results from the OCI client.

    Returns
    -------
    The serialized form of data.
    """
    serializer = oci_client.base_client
    return serializer.sanitize_for_serialization(response)
|
464
|
+
|
465
|
+
|
466
|
+
def _build_resource_identifier(
    id: str = None, name: str = None, region: str = None
) -> AquaResourceIdentifier:
    """Constructs AquaResourceIdentifier based on the given ocid and display name.

    Parameters
    ----------
    id: str
        Resource OCID; its resource type selects the console link path.
    name: str
        Display name of the resource.
    region: str
        Region passed to ``get_console_link``.

    Returns
    -------
    AquaResourceIdentifier
        The identifier with a console URL, or an empty identifier if anything
        fails (the error is logged).
    """
    try:
        console_resource = CONSOLE_LINK_RESOURCE_TYPE_MAPPING.get(get_resource_type(id))
        console_url = get_console_link(
            resource=console_resource,
            ocid=id,
            region=region,
        )
        return AquaResourceIdentifier(id=id, name=name, url=console_url)
    except Exception as e:
        logger.error(
            f"Failed to construct AquaResourceIdentifier from given id=`{id}`, and name=`{name}`, {str(e)}"
        )
        return AquaResourceIdentifier()
|
487
|
+
|
488
|
+
|
489
|
+
def _get_experiment_info(
    model: Union[oci.resource_search.models.ResourceSummary, DataScienceModel]
) -> tuple:
    """Returns ocid and name of the experiment (model version set).

    Parameters
    ----------
    model: ResourceSummary or DataScienceModel
        A Resource Search summary, whose values are read from
        ``additional_details``, or a ``DataScienceModel``, whose values are
        read from its attributes.

    Returns
    -------
    tuple
        ``(model_version_set_id, model_version_set_name)``. For a search
        summary, missing entries come back as ``None``.
    """
    if isinstance(model, oci.resource_search.models.ResourceSummary):
        # additional_details can presumably be None (e.g. when the search did
        # not request extra fields); treat that as "no details" instead of failing.
        details = model.additional_details or {}
        return (
            details.get(RqsAdditionalDetails.MODEL_VERSION_SET_ID),
            details.get(RqsAdditionalDetails.MODEL_VERSION_SET_NAME),
        )
    return model.model_version_set_id, model.model_version_set_name
|
501
|
+
|
502
|
+
|
503
|
+
def _build_job_identifier(
    job_run_details: Union[
        oci.data_science.models.JobRun, oci.resource_search.models.ResourceSummary
    ] = None,
    **kwargs,
) -> AquaResourceIdentifier:
    """Build an ``AquaResourceIdentifier`` for a job run.

    Parameters
    ----------
    job_run_details: JobRun or ResourceSummary, optional
        A ``JobRun`` (its ``id`` is used) or a Resource Search summary (its
        ``identifier`` is used); ``display_name`` supplies the name.
    **kwargs:
        Forwarded to ``_build_resource_identifier`` (``region`` is the only
        other parameter it accepts).

    Returns
    -------
    AquaResourceIdentifier
        The identifier, or an empty one when no details are given or they
        cannot be read.
    """
    if job_run_details is None:
        return AquaResourceIdentifier()

    try:
        job_id = (
            job_run_details.id
            if isinstance(job_run_details, oci.data_science.models.JobRun)
            else job_run_details.identifier
        )
        return _build_resource_identifier(
            id=job_id, name=job_run_details.display_name, **kwargs
        )
    except Exception as e:
        logger.debug(
            f"Failed to get job details from job_run_details: {job_run_details}. "
            f"DEBUG INFO: {str(e)}"
        )
        return AquaResourceIdentifier()
|
525
|
+
|
526
|
+
|
527
|
+
def get_container_image(
    config_file_name: str = None, container_type: str = None
) -> str:
    """Gets the latest container image for the given container type.

    Parameters
    ----------
    config_file_name: str
        Currently ignored: the container index is always read from the
        ``service_models/config`` prefix of the AQUA service models bucket.
        Kept for backward compatibility.
    container_type: str
        type of container, can be either deployment-container, finetune-container, evaluation-container

    Returns
    -------
    str:
        The image as ``<name>:<version>`` for the highest numbered version listed.

    Raises
    ------
    AquaValueError
        When the index has no entry for ``container_type``, or no entry that
        carries both a name and a version.
    """
    # NOTE(review): the argument is overridden here; callers' values have no effect.
    config_file_name = (
        f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
    )

    config = load_config(
        file_path=config_file_name,
        config_file_name=CONTAINER_INDEX,
    )

    if container_type not in config:
        raise AquaValueError(
            f"{config_file_name} does not have config details for model: {container_type}"
        )

    # Only entries with both a name and a version can form an image URI. Skipping
    # incomplete ones lets the AquaValueError below fire instead of a bare KeyError.
    candidates = [
        obj for obj in config[container_type] if obj.get("name") and obj.get("version")
    ]
    # assumes numbered versions, update if `latest` is used
    latest = get_max_version([obj["version"] for obj in candidates])
    container_image = next(
        (
            f"{obj['name']}:{obj['version']}"
            for obj in candidates
            if obj["version"] == str(latest)
        ),
        None,
    )

    if not container_image:
        raise AquaValueError(
            f"{config_file_name} is missing name and/or version details."
        )

    return container_image
|
574
|
+
|
575
|
+
|
576
|
+
def get_max_version(versions):
    """Return the highest dotted numeric version in ``versions``.

    Versions are compared component-wise as integers, so ``"1.10"`` beats
    ``"1.9"``. When one version is a prefix of the other (``"1.0"`` vs
    ``"1.0.0"``), the longer one wins. For numerically equal versions, the
    later entry in the list is kept.

    Parameters
    ----------
    versions: list of str
        Dotted version strings; every component of a compared version must
        parse as ``int``.

    Returns
    -------
    str
        The highest version, or ``UNKNOWN`` when ``versions`` is empty.

    Raises
    ------
    ValueError
        If a compared version has a non-integer component.
    """
    if not versions:
        return UNKNOWN

    def as_parts(version):
        return [int(piece) for piece in version.split(".")]

    highest = versions[0]
    for candidate in versions[1:]:
        # List comparison is lexicographic and ranks a strict prefix lower, so
        # the current winner survives only when it is strictly greater.
        if not as_parts(highest) > as_parts(candidate):
            highest = candidate
    return highest
|
601
|
+
|
602
|
+
|
603
|
+
def fire_and_forget(func):
    """Decorator to push execution of methods to the background.

    Each call schedules ``func(*args, **kwargs)`` on the event loop's default
    executor. It returns immediately with the resulting ``asyncio.Future``,
    which callers may ignore or await.
    """
    from functools import partial

    @wraps(func)
    def wrapped(*args, **kwargs):
        # run_in_executor forwards only positional arguments, so keyword
        # arguments are bound up front with functools.partial.
        return asyncio.get_event_loop().run_in_executor(
            None, partial(func, *args, **kwargs)
        )

    return wrapped
|
611
|
+
|
612
|
+
|
613
|
+
def get_base_model_from_tags(tags):
    """Extract the base model OCID and name from a fine-tuned model's tags.

    The ``Tags.AQUA_FINE_TUNED_MODEL_TAG`` value is expected in the format
    ``<service model OCID>#<model name>``.

    Parameters
    ----------
    tags: dict
        Freeform tags of the model.

    Returns
    -------
    tuple of (str, str)
        ``(base_model_ocid, base_model_name)``; both empty strings when the
        fine-tuned tag is absent.

    Raises
    ------
    AquaValueError
        If the tag is present but does not follow the ``OCID#Name`` format.
    """
    base_model_ocid = ""
    base_model_name = ""
    if Tags.AQUA_FINE_TUNED_MODEL_TAG.value in tags:
        tag = tags[Tags.AQUA_FINE_TUNED_MODEL_TAG.value]
        if "#" in tag:
            # Split on the first "#" only, so a model name containing "#" is
            # kept intact instead of raising an unpacking ValueError.
            base_model_ocid, base_model_name = tag.split("#", 1)

        if not (is_valid_ocid(base_model_ocid) and base_model_name):
            raise AquaValueError(
                f"{Tags.AQUA_FINE_TUNED_MODEL_TAG.value} tag should have the format `Service Model OCID#Model Name`."
            )

    return base_model_ocid, base_model_name
|
627
|
+
|
628
|
+
|
629
|
+
def get_resource_name(ocid: str) -> str:
    """Gets resource name based on the given ocid.

    Parameters
    ----------
    ocid: str
        Oracle Cloud Identifier (OCID).

    Returns
    -------
    str:
        The display name of the resource, or ``UNKNOWN`` when the lookup fails
        or returns nothing.

    Raises
    -------
    ValueError:
        When the given ocid is not a valid ocid.
    """
    if not is_valid_ocid(ocid):
        raise ValueError(
            f"The given ocid {ocid} is not a valid ocid. "
            "Check out this page https://docs.oracle.com/en-us/iaas/Content/General/Concepts/identifiers.htm to see more details."
        )
    try:
        resource = query_resource(ocid, return_all=False)
        return resource.display_name if resource else UNKNOWN
    except Exception as e:
        # Best effort: a failed lookup yields UNKNOWN. Catch Exception rather than
        # using a bare except, so KeyboardInterrupt/SystemExit still propagate.
        logger.debug(f"Failed to get the resource name for `{ocid}`. {e}")
        return UNKNOWN
|
658
|
+
|
659
|
+
|
660
|
+
def get_model_by_reference_paths(model_file_description: dict):
    """Reads the model file description json dict and returns the base model path and fine-tuned path for
    models created by reference.

    Entries stored in the AQUA service models bucket are treated as the base
    model; any other entry is treated as fine-tuning output. If several
    entries fall into the same group, the last one wins.

    Parameters
    ----------
    model_file_description: dict
        json dict whose ``"models"`` list holds ``namespace``/``bucketName``/``prefix`` entries.

    Returns
    -------
    a tuple with base_model_path and fine_tune_output_path

    Raises
    ------
    AquaValueError
        When no entry comes from the service models bucket.
    """
    base_model_path = UNKNOWN
    fine_tune_output_path = UNKNOWN

    for entry in model_file_description["models"]:
        namespace, bucket, prefix = (
            entry[field] for field in ("namespace", "bucketName", "prefix")
        )
        uri = f"oci://{bucket}@{namespace}/{prefix}".rstrip("/")
        if bucket == AQUA_SERVICE_MODELS_BUCKET:
            base_model_path = uri
        else:
            fine_tune_output_path = uri

    if not base_model_path:
        raise AquaValueError(
            f"Base Model should come from the bucket {AQUA_SERVICE_MODELS_BUCKET}. "
            "Other paths are not supported by Aqua."
        )
    return base_model_path, fine_tune_output_path
|
695
|
+
|
696
|
+
|
697
|
+
def _is_valid_mvs(mvs: ModelVersionSet, target_tag: str) -> bool:
    """Returns whether the given model version set has the target tag.

    Parameters
    ----------
    mvs: ModelVersionSet
        The instance of `ads.model.ModelVersionSet`.
    target_tag: str
        Target tag expected to be in the MVS's freeform tags.

    Returns
    -------
    bool:
        True if ``target_tag`` is in ``mvs.freeform_tags``; False when it is
        absent or the MVS has no freeform tags.
    """
    tags = mvs.freeform_tags
    return tags is not None and target_tag in tags
|