oracle-ads 2.10.1__py3-none-any.whl → 2.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. ads/aqua/__init__.py +12 -0
  2. ads/aqua/base.py +324 -0
  3. ads/aqua/cli.py +19 -0
  4. ads/aqua/config/deployment_config_defaults.json +9 -0
  5. ads/aqua/config/resource_limit_names.json +7 -0
  6. ads/aqua/constants.py +45 -0
  7. ads/aqua/data.py +40 -0
  8. ads/aqua/decorator.py +101 -0
  9. ads/aqua/deployment.py +643 -0
  10. ads/aqua/dummy_data/icon.txt +1 -0
  11. ads/aqua/dummy_data/oci_model_deployments.json +56 -0
  12. ads/aqua/dummy_data/oci_models.json +1 -0
  13. ads/aqua/dummy_data/readme.md +26 -0
  14. ads/aqua/evaluation.py +1751 -0
  15. ads/aqua/exception.py +82 -0
  16. ads/aqua/extension/__init__.py +40 -0
  17. ads/aqua/extension/base_handler.py +138 -0
  18. ads/aqua/extension/common_handler.py +21 -0
  19. ads/aqua/extension/deployment_handler.py +202 -0
  20. ads/aqua/extension/evaluation_handler.py +135 -0
  21. ads/aqua/extension/finetune_handler.py +66 -0
  22. ads/aqua/extension/model_handler.py +59 -0
  23. ads/aqua/extension/ui_handler.py +201 -0
  24. ads/aqua/extension/utils.py +23 -0
  25. ads/aqua/finetune.py +579 -0
  26. ads/aqua/job.py +29 -0
  27. ads/aqua/model.py +819 -0
  28. ads/aqua/training/__init__.py +4 -0
  29. ads/aqua/training/exceptions.py +459 -0
  30. ads/aqua/ui.py +453 -0
  31. ads/aqua/utils.py +715 -0
  32. ads/cli.py +37 -6
  33. ads/common/decorator/__init__.py +7 -3
  34. ads/common/decorator/require_nonempty_arg.py +65 -0
  35. ads/common/object_storage_details.py +166 -7
  36. ads/common/oci_client.py +18 -1
  37. ads/common/oci_logging.py +2 -2
  38. ads/common/oci_mixin.py +4 -5
  39. ads/common/serializer.py +34 -5
  40. ads/common/utils.py +75 -10
  41. ads/config.py +40 -1
  42. ads/jobs/ads_job.py +43 -25
  43. ads/jobs/builders/infrastructure/base.py +4 -2
  44. ads/jobs/builders/infrastructure/dsc_job.py +49 -39
  45. ads/jobs/builders/runtimes/base.py +71 -1
  46. ads/jobs/builders/runtimes/container_runtime.py +4 -4
  47. ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
  48. ads/jobs/templates/driver_pytorch.py +27 -10
  49. ads/model/artifact_downloader.py +84 -14
  50. ads/model/artifact_uploader.py +25 -23
  51. ads/model/datascience_model.py +388 -38
  52. ads/model/deployment/model_deployment.py +10 -2
  53. ads/model/generic_model.py +8 -0
  54. ads/model/model_file_description_schema.json +68 -0
  55. ads/model/model_metadata.py +1 -1
  56. ads/model/service/oci_datascience_model.py +34 -5
  57. ads/opctl/operator/lowcode/anomaly/README.md +2 -1
  58. ads/opctl/operator/lowcode/anomaly/__main__.py +10 -4
  59. ads/opctl/operator/lowcode/anomaly/environment.yaml +2 -1
  60. ads/opctl/operator/lowcode/anomaly/model/automlx.py +12 -6
  61. ads/opctl/operator/lowcode/forecast/README.md +3 -2
  62. ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
  63. ads/opctl/operator/lowcode/forecast/model/automlx.py +12 -23
  64. ads/telemetry/base.py +62 -0
  65. ads/telemetry/client.py +105 -0
  66. ads/telemetry/telemetry.py +6 -3
  67. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/METADATA +37 -7
  68. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/RECORD +71 -36
  69. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/LICENSE.txt +0 -0
  70. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/WHEEL +0 -0
  71. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/entry_points.txt +0 -0
ads/aqua/utils.py ADDED
@@ -0,0 +1,715 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+ """AQUA utils and constants."""
6
+ import asyncio
7
+ import base64
8
+ import json
9
+ import logging
10
+ import os
11
+ import random
12
+ import re
13
+ import sys
14
+ from enum import Enum
15
+ from functools import wraps
16
+ from pathlib import Path
17
+ from string import Template
18
+ from typing import List, Union
19
+
20
+ import fsspec
21
+ import oci
22
+ from oci.data_science.models import JobRun, Model
23
+
24
+ from ads.aqua.constants import RqsAdditionalDetails
25
+ from ads.aqua.data import AquaResourceIdentifier, Tags
26
+ from ads.aqua.exception import AquaFileNotFoundError, AquaRuntimeError, AquaValueError
27
+ from ads.common.auth import default_signer
28
+ from ads.common.object_storage_details import ObjectStorageDetails
29
+ from ads.common.oci_resource import SEARCH_TYPE, OCIResource
30
+ from ads.common.utils import get_console_link, upload_to_os
31
+ from ads.config import (
32
+ AQUA_CONFIG_FOLDER,
33
+ AQUA_SERVICE_MODELS_BUCKET,
34
+ TENANCY_OCID,
35
+ CONDA_BUCKET_NS,
36
+ )
37
+ from ads.model import DataScienceModel, ModelVersionSet
38
+
39
# TODO: allow the user to setup the logging level?
# NOTE: `logging.basicConfig` configures the *root* logger (stdout handler at
# INFO level) as a side effect of importing this module; it is a no-op when the
# root logger already has handlers. Worth revisiting together with the TODO.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# Module logger; also exposed through `get_logger()`.
logger = logging.getLogger("ODSC_AQUA")
42
+
43
# Value returned by helpers in this module when something cannot be determined.
UNKNOWN = ""
# NOTE(review): module-level mutable dict shared by every importer; mutating it
# affects all other users — treat as read-only.
UNKNOWN_DICT = {}
README = "README.md"
LICENSE_TXT = "config/LICENSE.txt"
DEPLOYMENT_CONFIG = "deployment_config.json"
# File name of the container index read by `get_container_image`.
CONTAINER_INDEX = "container_index.json"
EVALUATION_REPORT_JSON = "report.json"
EVALUATION_REPORT_MD = "report.md"
EVALUATION_REPORT = "report.html"
# JSON text for an empty object; `load_config` parses this when `read_file`
# returns nothing, so a missing file ends up as an empty config.
UNKNOWN_JSON_STR = "{}"
# Maps the resource-type segment of an OCID (as returned by `get_resource_type`)
# to the console URL segment passed to `get_console_link`. The dev/int/pre
# suffixed keys presumably cover pre-production resource types — confirm.
CONSOLE_LINK_RESOURCE_TYPE_MAPPING = dict(
    datasciencemodel="models",
    datasciencemodeldeployment="model-deployments",
    datasciencemodeldeploymentdev="model-deployments",
    datasciencemodeldeploymentint="model-deployments",
    datasciencemodeldeploymentpre="model-deployments",
    datasciencejob="jobs",
    datasciencejobrun="job-runs",
    datasciencejobrundev="job-runs",
    datasciencejobrunint="job-runs",
    datasciencejobrunpre="job-runs",
    datasciencemodelversionset="model-version-sets",
    datasciencemodelversionsetpre="model-version-sets",
    datasciencemodelversionsetint="model-version-sets",
    datasciencemodelversionsetdev="model-version-sets",
)
# NOTE(review): this image lives under the "ociodscdev" registry namespace —
# confirm that is intended for released builds.
FINE_TUNING_RUNTIME_CONTAINER = "iad.ocir.io/ociodscdev/aqua_ft_cuda121:0.3.17.20"
# Fine-tuning defaults. Units are not stated here (storage size presumably GB,
# validation size presumably a fraction of the dataset) — confirm with callers.
DEFAULT_FT_BLOCK_STORAGE_SIZE = 256
DEFAULT_FT_REPLICA = 1
DEFAULT_FT_BATCH_SIZE = 1
DEFAULT_FT_VALIDATION_SET_SIZE = 0.1

HF_MODELS = "/home/datascience/conda/pytorch21_p39_gpu_v1/"
# Size limit enforced by `upload_local_to_os` on local dataset files.
MAXIMUM_ALLOWED_DATASET_IN_BYTE = 52428800  # 1024 x 1024 x 50 = 50MB
JOB_INFRASTRUCTURE_TYPE_DEFAULT_NETWORKING = "ME_STANDALONE"
# Presumably the name of the environment variable holding the notebook session
# OCID — confirm against callers.
NB_SESSION_IDENTIFIER = "NB_SESSION_OCID"
# NOTE(review): "asscociated" is misspelled; kept verbatim in case other code
# compares against this exact string.
LIFECYCLE_DETAILS_MISSING_JOBRUN = "The asscociated JobRun resource has been deleted."
READY_TO_DEPLOY_STATUS = "ACTIVE"
81
+
82
+
83
class LifecycleStatus(Enum):
    """Lifecycle-status helpers for AQUA evaluations.

    Only the ``UNKNOWN`` member is defined. :meth:`get_status` returns plain
    ``str`` lifecycle states (taken from ``JobRun``/``Model`` constants or
    passed through unchanged), not members of this enum.
    """

    UNKNOWN = ""

    @property
    def detail(self) -> str:
        """Returns the detail message corresponding to the status."""
        # The mapping is looked up by the member *name*.
        return LIFECYCLE_DETAILS_MAPPING.get(
            self.name, f"No detail available for the status {self.name}."
        )

    @staticmethod
    def get_status(evaluation_status: str, job_run_status: str = None) -> str:
        """
        Maps the combination of evaluation status and job run status to a single status.

        Parameters
        ----------
        evaluation_status (str):
            The lifecycle state of the evaluation model.
        job_run_status (str):
            The lifecycle state of the job run.

        Returns
        -------
        str
            - ``JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION`` when ``job_run_status``
              is empty/None (job run failed to be created or was deleted).
            - If the evaluation model is ACTIVE: ``IN_PROGRESS`` for accepted or
              in-progress job runs, ``FAILED`` for failed or needs-attention
              job runs, otherwise ``job_run_status`` unchanged.
            - Otherwise ``evaluation_status`` unchanged.
        """
        if not job_run_status:
            logger.error("Failed to get jobrun state.")
            # case1: failed to create jobrun
            # case2: jobrun is deleted - rqs cannot retrieve deleted resource
            return JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION

        # While the evaluation model is not ACTIVE, its own state is reported.
        if evaluation_status != Model.LIFECYCLE_STATE_ACTIVE:
            return evaluation_status

        if job_run_status in (
            JobRun.LIFECYCLE_STATE_IN_PROGRESS,
            JobRun.LIFECYCLE_STATE_ACCEPTED,
        ):
            return JobRun.LIFECYCLE_STATE_IN_PROGRESS
        if job_run_status in (
            JobRun.LIFECYCLE_STATE_FAILED,
            JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION,
        ):
            return JobRun.LIFECYCLE_STATE_FAILED
        return job_run_status
134
+
135
+
136
# Human-readable detail messages keyed by JobRun lifecycle-state strings;
# looked up by `LifecycleStatus.detail`.
LIFECYCLE_DETAILS_MAPPING = {
    JobRun.LIFECYCLE_STATE_SUCCEEDED: "The evaluation ran successfully.",
    JobRun.LIFECYCLE_STATE_IN_PROGRESS: "The evaluation is running.",
    JobRun.LIFECYCLE_STATE_FAILED: "The evaluation failed.",
    JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION: "Missing jobrun information.",
}
# File extensions (without the leading dot) accepted by `upload_local_to_os`.
SUPPORTED_FILE_FORMATS = ["jsonl"]
# Custom-metadata key holding a by-reference model's artifact location; read by
# `get_artifact_path`.
MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location"
144
+
145
+
146
def get_logger():
    """Return the module-level ``ODSC_AQUA`` logger."""
    return logger
148
+
149
+
150
def random_color_generator(word: str):
    """Derive a deterministic (background, text) color pair from ``word``.

    The seed is the sum of the code points of ``word`` modulo 13, so at most
    13 distinct colors are produced. A private ``random.Random`` instance is
    used so this function does not reseed the global ``random`` module for
    the rest of the process. The values match what seeding the global RNG
    produced.

    Parameters
    ----------
    word: str
        Text the color is derived from.

    Returns
    -------
    tuple[str, str]
        ``("#rrggbb", text_color)``. ``text_color`` is ``"black"`` for light
        backgrounds and ``"white"`` for dark ones.
    """
    seed = sum(ord(c) for c in word) % 13
    rng = random.Random(seed)
    r = rng.randint(10, 245)
    g = rng.randint(10, 245)
    b = rng.randint(10, 245)

    # Perceived brightness (BT.601 luma weights) picks a readable text color.
    text_color = "black" if (0.299 * r + 0.587 * g + 0.114 * b) / 255 > 0.5 else "white"

    return f"#{r:02x}{g:02x}{b:02x}", text_color
160
+
161
+
162
def svg_to_base64_datauri(svg_contents: str):
    """Wrap SVG markup in a base64-encoded ``data:image/svg+xml`` URI."""
    encoded = base64.b64encode(svg_contents.encode()).decode()
    return f"data:image/svg+xml;base64,{encoded}"
165
+
166
+
167
def create_word_icon(label: str, width: int = 150, return_as_datauri=True):
    """Render a circular SVG icon showing a short abbreviation of ``label``.

    The abbreviation is the leading letter of ``label`` followed by its first
    number and an optional lowercase letter after it (e.g. ``"mistral-7b"`` ->
    ``"m7b"``). Otherwise it is the label's first character, or nothing for
    an empty label.

    Parameters
    ----------
    label: str
        Text the icon is derived from; also seeds the icon color.
    width: int
        Width and height of the SVG in pixels.
    return_as_datauri: bool
        If True, return a base64 ``data:`` URI; otherwise the raw SVG text.

    Returns
    -------
    str
        The SVG markup, or its data-URI form.
    """
    matches = re.findall(r"(^[a-zA-Z]{1}).*?(\d+[a-z]?)", label)
    # label[:1] instead of label[0] so an empty label does not raise IndexError.
    icon_text = "".join(matches[0]) if matches else label[:1]
    # Escape XML-special characters so the fallback character cannot break the SVG.
    icon_text = (
        icon_text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
    )
    icon_color, text_color = random_color_generator(label)
    cx = width / 2
    cy = width / 2
    r = width / 2
    # At least 1em: for width < 50 the plain formula yields 0em (invisible text).
    fs = max(1, int(r / 25))

    template = Template(
        """
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" width="${width}" height="${width}">

    <style>
        text {
            font-size: ${fs}em;
            font-family: lucida console, Fira Mono, monospace;
            text-anchor: middle;
            stroke-width: 1px;
            font-weight: bold;
            alignment-baseline: central;
        }

    </style>

    <circle cx="${cx}" cy="${cy}" r="${r}" fill="${icon_color}" />
    <text x="50%" y="50%" fill="${text_color}">${icon_text}</text>
</svg>
""".strip()
    )

    # Explicit mapping instead of **locals() so renaming a local cannot silently
    # change the template inputs.
    icon_svg = template.substitute(
        width=width,
        fs=fs,
        cx=cx,
        cy=cy,
        r=r,
        icon_color=icon_color,
        text_color=text_color,
        icon_text=icon_text,
    )
    if return_as_datauri:
        return svg_to_base64_datauri(icon_svg)
    return icon_svg
203
+
204
+
205
def get_artifact_path(custom_metadata_list: List) -> str:
    """Return the artifact location recorded in a model's custom metadata.

    The first entry whose ``key`` equals ``MODEL_BY_REFERENCE_OSS_PATH_KEY`` is
    used. A value that is already an OCI path is returned unchanged. Any other
    value is combined with ``AQUA_SERVICE_MODELS_BUCKET`` / ``CONDA_BUCKET_NS``
    into an Object Storage path.

    Parameters
    ----------
    custom_metadata_list: List
        Custom metadata items of the model (objects with ``key``/``value``).

    Returns
    -------
    str:
        The artifact path, or ``UNKNOWN`` if no matching entry exists.
    """
    found = next(
        (
            item
            for item in custom_metadata_list
            if item.key == MODEL_BY_REFERENCE_OSS_PATH_KEY
        ),
        None,
    )
    if found is None:
        logger.debug("Failed to get artifact path from custom metadata.")
        return UNKNOWN

    location = found.value
    if ObjectStorageDetails.is_oci_path(location):
        return location
    return ObjectStorageDetails(
        AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, location
    ).path
229
+
230
+
231
def read_file(file_path: str, **kwargs) -> str:
    """Read a text file through ``fsspec``; return ``UNKNOWN`` on any failure.

    Parameters
    ----------
    file_path: str
        Local path or fsspec URI (``load_config`` passes ``oci://`` paths).
    **kwargs:
        Only ``auth`` is used. Its items are forwarded to ``fsspec.open`` as
        extra options; other keyword arguments are ignored.

    Returns
    -------
    str
        File contents, or ``UNKNOWN`` if opening/reading raised (the error is logged).
    """
    storage_options = kwargs.get("auth", {})
    try:
        with fsspec.open(file_path, "r", **storage_options) as handle:
            content = handle.read()
    except Exception as e:
        logger.error(f"Failed to read file {file_path}. {e}")
        return UNKNOWN
    return content
238
+
239
+
240
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
    """Load and parse the JSON file ``<file_path>/<config_file_name>``.

    Parameters
    ----------
    file_path: str
        Folder containing the config (local or ``oci://``); trailing ``/`` is ignored.
    config_file_name: str
        Name of the JSON file inside ``file_path``.
    **kwargs:
        Forwarded to ``read_file``.

    Returns
    -------
    dict
        The parsed config.

    Raises
    ------
    AquaFileNotFoundError
        If the file is missing, unreadable, or parses to an empty value.
    json.JSONDecodeError
        If the file content is not valid JSON.
    """
    artifact_path = "/".join([file_path.rstrip("/"), config_file_name])
    # Object Storage paths need OCI credentials; local paths need none.
    signer = default_signer() if artifact_path.startswith("oci://") else {}
    raw_text = read_file(file_path=artifact_path, auth=signer, **kwargs)
    config = json.loads(raw_text or UNKNOWN_JSON_STR)
    if config:
        return config
    raise AquaFileNotFoundError(
        f"Config file `{config_file_name}` is either empty or missing at {artifact_path}",
        500,
    )
255
+
256
+
257
def is_valid_ocid(ocid: str) -> bool:
    """Checks if the given ocid is valid.

    Parameters
    ----------
    ocid: str
        Oracle Cloud Identifier (OCID).

    Returns
    -------
    bool:
        Whether the given ocid is valid. Non-string input (e.g. ``None``)
        is reported as invalid rather than raising ``TypeError``.
    """
    if not isinstance(ocid, str):
        return False
    pattern = r"^ocid1\.([a-z0-9_]+)\.([a-z0-9]+)\.([a-z0-9]*)(\.[^.]+)?\.([a-z0-9_]+)$"
    # fullmatch: with re.match, "$" also matches just before a trailing "\n",
    # so "ocid1...\n" was accepted.
    return re.fullmatch(pattern, ocid) is not None
273
+
274
+
275
def get_resource_type(ocid: str) -> str:
    """Gets resource type based on the given ocid.

    Parameters
    ----------
    ocid: str
        Oracle Cloud Identifier (OCID).

    Returns
    -------
    str:
        The resource type indicated in the given ocid (its second
        dot-separated segment).

    Raises
    -------
    ValueError:
        When the given ocid is not a valid ocid.
    """
    if not is_valid_ocid(ocid):
        # Trailing space added: the two literals were previously glued together
        # as "...valid ocid.Check out...".
        raise ValueError(
            f"The given ocid {ocid} is not a valid ocid. "
            "Check out this page https://docs.oracle.com/en-us/iaas/Content/General/Concepts/identifiers.htm to see more details."
        )
    # OCID layout: ocid1.<resource type>.<realm>.[region][.future use].<unique id>
    return ocid.split(".")[1]
299
+
300
+
301
def query_resource(
    ocid, return_all: bool = True
) -> "oci.resource_search.models.ResourceSummary":
    """Use Search service to find a single resource within a tenancy.

    Parameters
    ----------
    ocid: str
        Oracle Cloud Identifier (OCID).
    return_all: bool
        Whether to return allAdditionalFields.

    Returns
    -------
    oci.resource_search.models.ResourceSummary:
        The retrieved resource (first search hit).

    Raises
    ------
    ValueError
        If ``ocid`` is not a valid OCID.
    AquaRuntimeError
        If the search returns no resource.
    """
    # Separate name so the bool parameter is not shadowed by a str.
    return_clause = " return allAdditionalFields " if return_all else " "
    resource_type = get_resource_type(ocid)
    # NOTE(review): `ocid` is interpolated into the query text. The OCID regex
    # allows any non-dot characters in its optional fourth segment (quotes
    # included), so confirm inputs here are trusted.
    query = f"query {resource_type} resources{return_clause}where (identifier = '{ocid}')"
    logger.debug(query)

    resources = OCIResource.search(
        query,
        type=SEARCH_TYPE.STRUCTURED,
        tenant_id=TENANCY_OCID,
    )
    if not resources:
        raise AquaRuntimeError(
            f"Failed to retrieve {resource_type}'s information.",
            service_payload={"query": query, "tenant_id": TENANCY_OCID},
        )
    return resources[0]
335
+
336
+
337
def query_resources(
    compartment_id,
    resource_type: str,
    return_all: bool = True,
    tag_list: list = None,
    status_list: list = None,
    connect_by_ampersands: bool = True,
    **kwargs,
) -> List["oci.resource_search.models.ResourceSummary"]:
    """Use Search service to find resources within a compartment.

    Parameters
    ----------
    compartment_id: str
        The compartment ocid.
    resource_type: str
        The type of the target resources.
    return_all: bool
        Whether to return allAdditionalFields.
    tag_list: list
        Freeform tag keys to filter by.
    status_list: list
        Lifecycle states to filter by (always combined with ``||``).
    connect_by_ampersands: bool
        Whether tag conditions are combined with ``&&`` (default) or ``||``.
    **kwargs:
        Forwarded to ``OCIResource.search``.

    Returns
    -------
    List[oci.resource_search.models.ResourceSummary]:
        The retrieved resources.
    """
    return_clause = " return allAdditionalFields " if return_all else " "
    # A resource has a single lifecycleState, so the status values are OR-ed.
    lifecycle_filter = _construct_condition(
        field_name="lifecycleState",
        allowed_values=status_list,
        connect_by_ampersands=False,
    )
    tag_filter = _construct_condition(
        field_name="freeformTags.key",
        allowed_values=tag_list,
        connect_by_ampersands=connect_by_ampersands,
    )
    query = (
        f"query {resource_type} resources{return_clause}"
        f"where (compartmentId = '{compartment_id}'{lifecycle_filter}{tag_filter})"
    )
    logger.debug(query)
    logger.debug(f"tenant_id=`{TENANCY_OCID}`")

    return OCIResource.search(
        query, type=SEARCH_TYPE.STRUCTURED, tenant_id=TENANCY_OCID, **kwargs
    )
389
+
390
+
391
+ def _construct_condition(
392
+ field_name: str, allowed_values: list = None, connect_by_ampersands: bool = True
393
+ ) -> str:
394
+ """Returns tag condition applied in query statement.
395
+
396
+ Parameters
397
+ ----------
398
+ field_name: str
399
+ The field_name keyword is the resource attribute against which the
400
+ operation and chosen value of that attribute are evaluated.
401
+ allowed_values: list
402
+ List of value will be applied for filtering.
403
+ connect_by_ampersands: bool
404
+ Whether to use `&&` to group multiple tag conditions.
405
+ if `connect_by_ampersands=False`, `||` will be used.
406
+
407
+ Returns
408
+ -------
409
+ str:
410
+ The tag condition.
411
+ """
412
+ if not allowed_values:
413
+ return ""
414
+
415
+ joint = "&&" if connect_by_ampersands else "||"
416
+ formatted_tags = [f"{field_name} = '{value}'" for value in allowed_values]
417
+ joined_tags = f" {joint} ".join(formatted_tags)
418
+ condition = f" && ({joined_tags})" if joined_tags else ""
419
+ return condition
420
+
421
+
422
def upload_local_to_os(
    src_uri: str, dst_uri: str, auth: dict = None, force_overwrite: bool = False
):
    """Validate a local dataset file and upload it to Object Storage.

    Parameters
    ----------
    src_uri: str
        Local file path; ``~`` is expanded.
    dst_uri: str
        Destination URI, passed to ``upload_to_os``.
    auth: dict
        Authentication passed to ``upload_to_os``.
    force_overwrite: bool
        Passed to ``upload_to_os``.

    Raises
    ------
    AquaFileNotFoundError
        If ``src_uri`` is not an existing file.
    AquaValueError
        If the extension is not in ``SUPPORTED_FILE_FORMATS``, or the file is
        empty or larger than ``MAXIMUM_ALLOWED_DATASET_IN_BYTE``.
    """
    local_path = os.path.expanduser(src_uri)
    if not os.path.isfile(local_path):
        raise AquaFileNotFoundError("Invalid input file path. Specify a valid one.")

    extension = Path(local_path).suffix.lstrip(".")
    if extension not in SUPPORTED_FILE_FORMATS:
        raise AquaValueError(
            f"Invalid input file. Only {', '.join(SUPPORTED_FILE_FORMATS)} files are supported."
        )

    size_in_bytes = os.path.getsize(local_path)
    if size_in_bytes == 0:
        raise AquaValueError("Empty input file. Specify a valid file path.")
    if size_in_bytes > MAXIMUM_ALLOWED_DATASET_IN_BYTE:
        raise AquaValueError(
            f"Local dataset file can't exceed {MAXIMUM_ALLOWED_DATASET_IN_BYTE} bytes."
        )

    upload_to_os(
        src_uri=local_path,
        dst_uri=dst_uri,
        auth=auth,
        force_overwrite=force_overwrite,
    )
445
+
446
+
447
def sanitize_response(oci_client, response: list):
    """Serialize an OCI client response into plain data for a JSON payload.

    Parameters
    ----------
    oci_client
        OCI client object; its ``base_client.sanitize_for_serialization``
        performs the conversion.
    response
        Result(s) returned by the OCI client.

    Returns
    -------
    The serialized form of ``response``.
    """
    serializer = oci_client.base_client.sanitize_for_serialization
    return serializer(response)
464
+
465
+
466
def _build_resource_identifier(
    id: str = None, name: str = None, region: str = None
) -> AquaResourceIdentifier:
    """Constructs an AquaResourceIdentifier (id, name, console URL) for an OCID.

    Any failure, e.g. an invalid or missing ``id``, is logged, and an empty
    ``AquaResourceIdentifier()`` is returned instead of raising.
    """
    try:
        console_resource = CONSOLE_LINK_RESOURCE_TYPE_MAPPING.get(
            get_resource_type(id)
        )
        console_url = get_console_link(
            resource=console_resource,
            ocid=id,
            region=region,
        )
        identifier = AquaResourceIdentifier(id=id, name=name, url=console_url)
    except Exception as e:
        logger.error(
            f"Failed to construct AquaResourceIdentifier from given id=`{id}`, and name=`{name}`, {str(e)}"
        )
        return AquaResourceIdentifier()
    return identifier
487
+
488
+
489
def _get_experiment_info(
    model: Union[oci.resource_search.models.ResourceSummary, DataScienceModel]
) -> tuple:
    """Returns ``(model_version_set_id, model_version_set_name)`` for ``model``.

    Resource Search summaries carry these values in ``additional_details``.
    Any other object is read through its ``model_version_set_id`` /
    ``model_version_set_name`` attributes.
    """
    if not isinstance(model, oci.resource_search.models.ResourceSummary):
        return model.model_version_set_id, model.model_version_set_name

    details = model.additional_details
    return (
        details.get(RqsAdditionalDetails.MODEL_VERSION_SET_ID),
        details.get(RqsAdditionalDetails.MODEL_VERSION_SET_NAME),
    )
501
+
502
+
503
def _build_job_identifier(
    job_run_details: Union[
        oci.data_science.models.JobRun, oci.resource_search.models.ResourceSummary
    ] = None,
    **kwargs,
) -> AquaResourceIdentifier:
    """Builds an AquaResourceIdentifier for a job run.

    ``JobRun`` objects expose their OCID as ``id``; anything else is read via
    ``identifier``. ``kwargs`` are forwarded to ``_build_resource_identifier``.
    Any failure is logged at debug level, and an empty
    ``AquaResourceIdentifier()`` is returned.
    """
    try:
        if isinstance(job_run_details, oci.data_science.models.JobRun):
            job_id = job_run_details.id
        else:
            job_id = job_run_details.identifier
        return _build_resource_identifier(
            id=job_id, name=job_run_details.display_name, **kwargs
        )
    except Exception as e:
        logger.debug(
            f"Failed to get job details from job_run_details: {job_run_details}"
            f"DEBUG INFO:{str(e)}"
        )
        return AquaResourceIdentifier()
525
+
526
+
527
def get_container_image(
    config_file_name: str = None, container_type: str = None
) -> str:
    """Gets the latest container image reference for the given container type.

    ``container_index.json`` is always read from the ``service_models/config``
    folder of the AQUA service-models bucket.

    Parameters
    ----------
    config_file_name: str
        Not used: the config location is always derived from
        ``AQUA_SERVICE_MODELS_BUCKET`` and ``CONDA_BUCKET_NS``.
    container_type: str
        Type of container, e.g. deployment-container, finetune-container or
        evaluation-container.

    Returns
    -------
    str
        ``"<name>:<version>"`` for the highest version listed under
        ``container_type``.

    Raises
    ------
    AquaValueError
        If ``container_type`` is not in the index, or no entry carries the
        highest version.
    AquaFileNotFoundError
        If the index is missing, unreadable or empty (raised by ``load_config``).
    """
    # NOTE(review): the `config_file_name` argument has always been overwritten
    # here, so callers cannot point at another location — confirm this is intended.
    config_path = (
        f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
    )

    config = load_config(
        file_path=config_path,
        config_file_name=CONTAINER_INDEX,
    )

    if container_type not in config:
        raise AquaValueError(
            f"{config_path} does not have config details for model: {container_type}"
        )

    mapping = config[container_type]
    # assumes numbered versions, update if `latest` is used
    latest = str(get_max_version([obj["version"] for obj in mapping]))
    container_image = next(
        (
            f"{obj['name']}:{obj['version']}"
            for obj in mapping
            if obj["version"] == latest
        ),
        None,
    )

    if not container_image:
        raise AquaValueError(
            f"{config_path} is missing name and/or version details."
        )

    return container_image
574
+
575
+
576
def get_max_version(versions):
    """Returns the highest of the given dotted numeric version strings.

    Segments are compared as integers (``"1.10" > "1.9"``). When one version
    is a prefix of another, the longer one wins (``"1.0.0" > "1.0"``). Among
    numerically equal versions, the one appearing last in ``versions`` is
    returned. A single-element sequence is returned as-is without parsing;
    an empty one yields ``UNKNOWN``.

    Raises
    ------
    ValueError
        If a segment is not an integer (only when there are 2+ versions).
    """
    if not versions:
        return UNKNOWN
    if len(versions) == 1:
        return versions[0]

    def numeric_key(version):
        return tuple(int(part) for part in version.split("."))

    # max() keeps the first maximal item it meets; scanning in reverse therefore
    # returns the last of several equal versions.
    return max(reversed(versions), key=numeric_key)
601
+
602
+
603
def fire_and_forget(func):
    """Decorator to push execution of methods to the background.

    Each call schedules ``func(*args, **kwargs)`` on the event loop's default
    executor (``run_in_executor(None, ...)``). It immediately returns the
    resulting ``asyncio.Future``, which callers may ignore.
    """

    @wraps(func)
    def wrapped(*args, **kwargs):
        # run_in_executor forwards only positional arguments, so keyword
        # arguments are bound in a closure. (The previous `*kwargs` passed the
        # dict's *keys* as extra positional arguments.)
        return asyncio.get_event_loop().run_in_executor(
            None, lambda: func(*args, **kwargs)
        )

    return wrapped
611
+
612
+
613
def get_base_model_from_tags(tags):
    """Extracts the base model OCID and name from a fine-tuned model's tags.

    The ``Tags.AQUA_FINE_TUNED_MODEL_TAG`` value must look like
    ``"<service model OCID>#<model name>"``.

    Parameters
    ----------
    tags: dict
        Tag mapping of the model (presumably its freeform tags — confirm).

    Returns
    -------
    tuple[str, str]
        ``(base_model_ocid, base_model_name)``.

    Raises
    ------
    AquaValueError
        If the tag is missing, has no ``#``, carries an invalid OCID or an
        empty name.
    """
    base_model_ocid = ""
    base_model_name = ""
    if Tags.AQUA_FINE_TUNED_MODEL_TAG.value in tags:
        tag = tags[Tags.AQUA_FINE_TUNED_MODEL_TAG.value]
        if "#" in tag:
            # Split at the first "#" only: OCIDs never contain "#", while a
            # model name might. Previously `split("#")` raised a plain
            # ValueError (too many values to unpack) in that case.
            base_model_ocid, _, base_model_name = tag.partition("#")

    if not (is_valid_ocid(base_model_ocid) and base_model_name):
        raise AquaValueError(
            f"{Tags.AQUA_FINE_TUNED_MODEL_TAG.value} tag should have the format `Service Model OCID#Model Name`."
        )

    return base_model_ocid, base_model_name
627
+
628
+
629
def get_resource_name(ocid: str) -> str:
    """Gets resource name based on the given ocid.

    Parameters
    ----------
    ocid: str
        Oracle Cloud Identifier (OCID).

    Returns
    -------
    str:
        The display name of the resource, or ``UNKNOWN`` if the lookup fails
        or returns nothing.

    Raises
    -------
    ValueError:
        When the given ocid is not a valid ocid.
    """
    if not is_valid_ocid(ocid):
        # Trailing space added: the two literals were previously glued together.
        raise ValueError(
            f"The given ocid {ocid} is not a valid ocid. "
            "Check out this page https://docs.oracle.com/en-us/iaas/Content/General/Concepts/identifiers.htm to see more details."
        )
    try:
        resource = query_resource(ocid, return_all=False)
        return resource.display_name if resource else UNKNOWN
    except Exception as e:
        # Best-effort lookup. `Exception` instead of a bare `except:` so that
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        logger.debug(f"Failed to get the display name of {ocid}. {e}")
        return UNKNOWN
658
+
659
+
660
def get_model_by_reference_paths(model_file_description: dict):
    """Split a by-reference model's file description into base and fine-tuned paths.

    Each entry of ``model_file_description["models"]`` (with ``namespace``,
    ``bucketName`` and ``prefix`` keys) becomes
    ``oci://<bucket>@<namespace>/<prefix>`` without a trailing ``/``. Entries
    in ``AQUA_SERVICE_MODELS_BUCKET`` are the base model; all others are the
    fine-tuned output. If several entries fall in one group, the last wins.

    Parameters
    ----------
    model_file_description: dict
        JSON dict describing the model's Object Storage locations.

    Returns
    -------
    tuple
        ``(base_model_path, fine_tune_output_path)``; the second is
        ``UNKNOWN`` when no fine-tuned entry exists.

    Raises
    ------
    AquaValueError
        If no entry points at ``AQUA_SERVICE_MODELS_BUCKET``.
    """
    base_model_path = UNKNOWN
    fine_tune_output_path = UNKNOWN

    for entry in model_file_description["models"]:
        namespace = entry["namespace"]
        bucket = entry["bucketName"]
        prefix = entry["prefix"]
        uri = f"oci://{bucket}@{namespace}/{prefix}".rstrip("/")
        if bucket == AQUA_SERVICE_MODELS_BUCKET:
            base_model_path = uri
        else:
            fine_tune_output_path = uri

    if not base_model_path:
        raise AquaValueError(
            f"Base Model should come from the bucket {AQUA_SERVICE_MODELS_BUCKET}. "
            f"Other paths are not supported by Aqua."
        )
    return base_model_path, fine_tune_output_path
695
+
696
+
697
def _is_valid_mvs(mvs: ModelVersionSet, target_tag: str) -> bool:
    """Returns whether the given model version set carries ``target_tag``.

    Parameters
    ----------
    mvs: ModelVersionSet
        The instance of `ads.model.ModelVersionSet`.
    target_tag: str
        Key looked up in ``mvs.freeform_tags``.

    Returns
    -------
    bool:
        ``False`` when the MVS has no freeform tags; otherwise whether
        ``target_tag`` is among them.
    """
    freeform_tags = mvs.freeform_tags
    return freeform_tags is not None and target_tag in freeform_tags