oracle-ads 2.10.1__py3-none-any.whl → 2.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. ads/aqua/__init__.py +12 -0
  2. ads/aqua/base.py +324 -0
  3. ads/aqua/cli.py +19 -0
  4. ads/aqua/config/deployment_config_defaults.json +9 -0
  5. ads/aqua/config/resource_limit_names.json +7 -0
  6. ads/aqua/constants.py +45 -0
  7. ads/aqua/data.py +40 -0
  8. ads/aqua/decorator.py +101 -0
  9. ads/aqua/deployment.py +643 -0
  10. ads/aqua/dummy_data/icon.txt +1 -0
  11. ads/aqua/dummy_data/oci_model_deployments.json +56 -0
  12. ads/aqua/dummy_data/oci_models.json +1 -0
  13. ads/aqua/dummy_data/readme.md +26 -0
  14. ads/aqua/evaluation.py +1751 -0
  15. ads/aqua/exception.py +82 -0
  16. ads/aqua/extension/__init__.py +40 -0
  17. ads/aqua/extension/base_handler.py +138 -0
  18. ads/aqua/extension/common_handler.py +21 -0
  19. ads/aqua/extension/deployment_handler.py +202 -0
  20. ads/aqua/extension/evaluation_handler.py +135 -0
  21. ads/aqua/extension/finetune_handler.py +66 -0
  22. ads/aqua/extension/model_handler.py +59 -0
  23. ads/aqua/extension/ui_handler.py +201 -0
  24. ads/aqua/extension/utils.py +23 -0
  25. ads/aqua/finetune.py +579 -0
  26. ads/aqua/job.py +29 -0
  27. ads/aqua/model.py +819 -0
  28. ads/aqua/training/__init__.py +4 -0
  29. ads/aqua/training/exceptions.py +459 -0
  30. ads/aqua/ui.py +453 -0
  31. ads/aqua/utils.py +715 -0
  32. ads/cli.py +37 -6
  33. ads/common/decorator/__init__.py +7 -3
  34. ads/common/decorator/require_nonempty_arg.py +65 -0
  35. ads/common/object_storage_details.py +166 -7
  36. ads/common/oci_client.py +18 -1
  37. ads/common/oci_logging.py +2 -2
  38. ads/common/oci_mixin.py +4 -5
  39. ads/common/serializer.py +34 -5
  40. ads/common/utils.py +75 -10
  41. ads/config.py +40 -1
  42. ads/jobs/ads_job.py +43 -25
  43. ads/jobs/builders/infrastructure/base.py +4 -2
  44. ads/jobs/builders/infrastructure/dsc_job.py +49 -39
  45. ads/jobs/builders/runtimes/base.py +71 -1
  46. ads/jobs/builders/runtimes/container_runtime.py +4 -4
  47. ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
  48. ads/jobs/templates/driver_pytorch.py +27 -10
  49. ads/model/artifact_downloader.py +84 -14
  50. ads/model/artifact_uploader.py +25 -23
  51. ads/model/datascience_model.py +388 -38
  52. ads/model/deployment/model_deployment.py +10 -2
  53. ads/model/generic_model.py +8 -0
  54. ads/model/model_file_description_schema.json +68 -0
  55. ads/model/model_metadata.py +1 -1
  56. ads/model/service/oci_datascience_model.py +34 -5
  57. ads/opctl/operator/lowcode/anomaly/README.md +2 -1
  58. ads/opctl/operator/lowcode/anomaly/__main__.py +10 -4
  59. ads/opctl/operator/lowcode/anomaly/environment.yaml +2 -1
  60. ads/opctl/operator/lowcode/anomaly/model/automlx.py +12 -6
  61. ads/opctl/operator/lowcode/forecast/README.md +3 -2
  62. ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
  63. ads/opctl/operator/lowcode/forecast/model/automlx.py +12 -23
  64. ads/telemetry/base.py +62 -0
  65. ads/telemetry/client.py +105 -0
  66. ads/telemetry/telemetry.py +6 -3
  67. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.1.dist-info}/METADATA +37 -7
  68. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.1.dist-info}/RECORD +71 -36
  69. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.1.dist-info}/LICENSE.txt +0 -0
  70. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.1.dist-info}/WHEEL +0 -0
  71. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.1.dist-info}/entry_points.txt +0 -0
ads/aqua/model.py ADDED
@@ -0,0 +1,819 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+ import re
6
+ from dataclasses import InitVar, dataclass, field
7
+ from datetime import datetime, timedelta
8
+ from enum import Enum
9
+ import os
10
+ from threading import Lock
11
+ from typing import List, Union
12
+
13
+ import oci
14
+ from cachetools import TTLCache
15
+ from oci.data_science.models import JobRun, Model
16
+
17
+ from ads.aqua import logger, utils
18
+ from ads.aqua.base import AquaApp
19
+ from ads.aqua.constants import (
20
+ TRAINING_METRICS_FINAL,
21
+ TRINING_METRICS,
22
+ UNKNOWN_VALUE,
23
+ VALIDATION_METRICS,
24
+ VALIDATION_METRICS_FINAL,
25
+ FineTuningDefinedMetadata,
26
+ )
27
+ from ads.aqua.data import AquaResourceIdentifier, Tags
28
+
29
+ from ads.aqua.exception import AquaRuntimeError
30
+ from ads.aqua.utils import (
31
+ LICENSE_TXT,
32
+ README,
33
+ READY_TO_DEPLOY_STATUS,
34
+ UNKNOWN,
35
+ create_word_icon,
36
+ get_artifact_path,
37
+ read_file,
38
+ )
39
+ from ads.aqua.training.exceptions import exit_code_dict
40
+ from ads.common.auth import default_signer
41
+ from ads.common.object_storage_details import ObjectStorageDetails
42
+ from ads.common.oci_resource import SEARCH_TYPE, OCIResource
43
+ from ads.common.serializer import DataClassSerializable
44
+ from ads.common.utils import get_console_link, get_log_links
45
+ from ads.config import (
46
+ AQUA_SERVICE_MODELS_BUCKET,
47
+ COMPARTMENT_OCID,
48
+ ODSC_MODEL_COMPARTMENT_OCID,
49
+ PROJECT_OCID,
50
+ TENANCY_OCID,
51
+ CONDA_BUCKET_NS,
52
+ )
53
+ from ads.model import DataScienceModel
54
+ from ads.model.model_metadata import MetadataTaxonomyKeys, ModelCustomMetadata
55
+ from ads.telemetry import telemetry
56
+
57
+
58
class FineTuningMetricCategories(Enum):
    """Category labels attached to fine-tuning metrics.

    Used as the ``category`` of ``AquaFineTuningMetric`` entries to tell
    validation scores apart from training scores.
    """

    VALIDATION = "validation"  # metrics computed on the validation split
    TRAINING = "training"  # metrics computed on the training data
61
+
62
+
63
@dataclass(repr=False)
class FineTuningShapeInfo(DataClassSerializable):
    """Compute shape and replica count recorded for a fine-tuning job run."""

    # Shape name taken from the job run's infrastructure configuration.
    instance_shape: str = ""
    # Replica count; populated by AquaFineTuneModel from the NODE_COUNT env var.
    replica: int = 0
67
+
68
+
69
# TODO: give a better name
@dataclass(repr=False)
class AquaFineTuneValidation(DataClassSerializable):
    """Describes how the validation set of a fine-tuning run was produced.

    ``type`` defaults to ``"Automatic split"``; ``value`` holds the
    validation-set size read from the model hyperparameters.
    """

    type: str = "Automatic split"
    value: str = ""
74
+
75
+
76
@dataclass(repr=False)
class AquaFineTuningMetric(DataClassSerializable):
    """A named series of fine-tuning scores belonging to one category."""

    # Metric label, e.g. one of the *_METRICS constants.
    name: str = ""
    # A FineTuningMetricCategories value ("validation" / "training").
    category: str = ""
    # Collected metric values; each instance gets its own list.
    scores: list = field(default_factory=list)
81
+
82
+
83
@dataclass(repr=False)
class AquaModelLicense(DataClassSerializable):
    """Represents the response of Get Model License.

    Pairs a model OCID with the full license text loaded for it.
    """

    id: str = ""  # model OCID
    license: str = ""  # full license text
89
+
90
+
91
@dataclass(repr=False)
class AquaModelSummary(DataClassSerializable):
    """Represents a summary of Aqua model.

    This is the item type returned by ``AquaModelApp.list``. Every field
    defaults to ``None`` except ``ready_to_deploy``, which defaults to ``True``.
    """

    compartment_id: str = None
    icon: str = None
    id: str = None
    # True when the model carries the fine-tuned-model freeform tag.
    is_fine_tuned_model: bool = None
    license: str = None
    name: str = None
    organization: str = None
    project_id: str = None
    # Merged defined + freeform tags of the model.
    tags: dict = None
    task: str = None
    time_created: str = None
    # Link to the model in the OCI console.
    console_link: str = None
    # Text built from the description and tag values, used for searching.
    search_text: str = None
    ready_to_deploy: bool = True
109
+
110
+
111
@dataclass(repr=False)
class AquaModel(AquaModelSummary, DataClassSerializable):
    """Represents an Aqua model: all summary fields plus its model card."""

    # README content read from the model artifact location (see AquaModelApp.get).
    model_card: str = None
116
+
117
+
118
@dataclass(repr=False)
class AquaEvalFTCommon(DataClassSerializable):
    """Represents common fields for evaluation and fine-tuning.

    ``model``, ``region`` and ``jobrun`` are init-only variables. They are
    consumed by ``__post_init__`` to populate ``log``, ``log_group``,
    ``experiment``, ``job`` and ``lifecycle_details``, and are not stored as
    dataclass fields.
    """

    lifecycle_state: str = None
    lifecycle_details: str = None
    job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
    source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
    experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
    log_group: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
    log: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)

    model: InitVar = None
    region: InitVar = None
    jobrun: InitVar = None

    def __post_init__(
        self, model, region: str, jobrun: oci.data_science.models.JobRun = None
    ):
        """Populates log, experiment and job identifiers.

        Parameters
        ----------
        model:
            The model passed to ``utils._get_experiment_info`` to resolve
            the experiment id and name.
        region: str
            Region used to build console/log links.
        jobrun: oci.data_science.models.JobRun, optional
            The job run associated with the model. When it is ``None``, the
            log URL is left empty and ``lifecycle_details`` is set to
            ``utils.LIFECYCLE_DETAILS_MISSING_JOBRUN``.
        """
        try:
            log_id = jobrun.log_details.log_id
        except Exception as e:
            logger.debug(f"No associated log found. {str(e)}")
            log_id = ""

        try:
            loggroup_id = jobrun.log_details.log_group_id
        except Exception as e:
            logger.debug(f"No associated loggroup found. {str(e)}")
            loggroup_id = ""

        loggroup_url = get_log_links(region=region, log_group_id=loggroup_id)
        log_url = (
            get_log_links(
                region=region,
                log_group_id=loggroup_id,
                log_id=log_id,
                compartment_id=jobrun.compartment_id,
                source_id=jobrun.id,
            )
            if jobrun
            else ""
        )

        # Names stay None when there is no id, or when the lookup fails.
        log_name = self._query_display_name(log_id) if log_id else None
        loggroup_name = (
            self._query_display_name(loggroup_id) if loggroup_id else None
        )

        experiment_id, experiment_name = utils._get_experiment_info(model)

        self.log_group = AquaResourceIdentifier(
            loggroup_id, loggroup_name, loggroup_url
        )
        self.log = AquaResourceIdentifier(log_id, log_name, log_url)
        self.experiment = utils._build_resource_identifier(
            id=experiment_id, name=experiment_name, region=region
        )
        self.job = utils._build_job_identifier(job_run_details=jobrun, region=region)
        self.lifecycle_details = (
            jobrun.lifecycle_details
            if jobrun
            else utils.LIFECYCLE_DETAILS_MISSING_JOBRUN
        )

    @staticmethod
    def _query_display_name(resource_id: str):
        """Best-effort lookup of a resource's display name.

        Returns the display name, ``""`` when ``utils.query_resource`` returns
        nothing, or ``None`` when the query fails. Failures are logged at debug
        level instead of being silently discarded.
        """
        try:
            resource = utils.query_resource(resource_id, return_all=False)
            return resource.display_name if resource else ""
        except Exception as e:
            logger.debug(
                f"Failed to query display name for resource {resource_id}: {str(e)}"
            )
            return None
194
+
195
+
196
@dataclass(repr=False)
class AquaFineTuneModel(AquaModel, AquaEvalFTCommon, DataClassSerializable):
    """Represents an Aqua Fine Tuned Model.

    In addition to the model summary/detail fields and the common
    evaluation/fine-tuning fields, this class carries the training dataset,
    the validation settings, the shape used by the fine-tuning job run and
    the collected metrics.
    """

    dataset: str = field(default_factory=str)
    validation: AquaFineTuneValidation = field(default_factory=AquaFineTuneValidation)
    shape_info: FineTuningShapeInfo = field(default_factory=FineTuningShapeInfo)
    metrics: List[AquaFineTuningMetric] = field(default_factory=list)

    def __post_init__(
        self,
        model: DataScienceModel,
        region: str,
        jobrun: oci.data_science.models.JobRun = None,
    ):
        """Populates fine-tuning specific fields from the model and job run.

        Parameters
        ----------
        model: DataScienceModel
            The fine-tuned model; its HYPERPARAMETERS defined metadata supplies
            ``dataset`` and the validation-set size.
        region: str
            Region forwarded to ``AquaEvalFTCommon.__post_init__``.
        jobrun: oci.data_science.models.JobRun, optional
            The fine-tuning job run. When present, ``shape_info`` is built from
            its infrastructure configuration and ``NODE_COUNT`` env var.
        """
        super().__post_init__(model=model, region=region, jobrun=jobrun)

        if jobrun is not None:
            # The override details object may be absent; fall back to no env vars.
            override_details = jobrun.job_configuration_override_details
            jobrun_env_vars = (
                getattr(override_details, "environment_variables", None) or {}
            )
            self.shape_info = FineTuningShapeInfo(
                instance_shape=jobrun.job_infrastructure_configuration_details.shape_name,
                # TODO: use variable for `NODE_COUNT` in ads/jobs/builders/runtimes/base.py
                replica=jobrun_env_vars.get("NODE_COUNT", UNKNOWN_VALUE),
            )

        model_hyperparameters = self._get_model_hyperparameters(model)

        self.dataset = model_hyperparameters.get(
            FineTuningDefinedMetadata.TRAINING_DATA.value
        )
        if not self.dataset:
            logger.debug(
                f"Key={FineTuningDefinedMetadata.TRAINING_DATA.value} not found in model hyperparameters."
            )

        val_set_size = model_hyperparameters.get(
            FineTuningDefinedMetadata.VAL_SET_SIZE.value
        )
        self.validation = AquaFineTuneValidation(value=val_set_size)
        # Check the looked-up value: the AquaFineTuneValidation instance itself
        # is always truthy, so testing it would never detect a missing key.
        if val_set_size is None:
            logger.debug(
                f"Key={FineTuningDefinedMetadata.VAL_SET_SIZE.value} not found in model hyperparameters."
            )

        if self.lifecycle_details:
            self.lifecycle_details = self._extract_job_lifecycle_details(
                self.lifecycle_details
            )

    @staticmethod
    def _get_model_hyperparameters(model) -> dict:
        """Returns the model's HYPERPARAMETERS defined metadata as a dict.

        Returns ``{}`` when the metadata cannot be read, or when its value is
        not a dict (e.g. ``None``), so callers can safely use ``.get``.
        """
        try:
            hyperparameters = model.defined_metadata_list.get(
                MetadataTaxonomyKeys.HYPERPARAMETERS
            ).value
        except Exception as e:
            logger.debug(
                f"Failed to extract model hyperparameters from {getattr(model, 'id', None)}:"
                f"{str(e)}"
            )
            return {}

        if not isinstance(hyperparameters, dict):
            logger.debug(
                f"Model hyperparameters of {getattr(model, 'id', None)} are not a dict; ignoring them."
            )
            return {}
        return hyperparameters

    def _extract_job_lifecycle_details(self, lifecycle_details):
        """Translates an ``exit code N`` in the job run details into a readable reason.

        The details are returned unchanged when:
        - no exit code is present;
        - the exit code is 1;
        - the code is not in ``exit_code_dict()``;
        - the lookup fails.
        """
        if not isinstance(lifecycle_details, str):
            return lifecycle_details

        match = re.search(r"exit code (\d+)", lifecycle_details)
        if not match:
            return lifecycle_details

        exit_code = int(match.group(1))
        if exit_code == 1:
            return lifecycle_details

        try:
            exception = exit_code_dict().get(exit_code)
            reason = getattr(exception, "reason", None) if exception else None
        except Exception as e:
            logger.debug(f"Failed to map exit code {exit_code} to a reason: {str(e)}")
            return lifecycle_details

        if reason is None:
            # Unknown exit code: keep the original service message.
            return lifecycle_details
        return f"{reason} (exit code {exit_code})"
275
+
276
+
277
# TODO: merge metadata key used in create FT
class FineTuningCustomMetadata(Enum):
    """Custom-metadata keys/descriptions used for fine-tuned models.

    The ``*_METRICS_*`` values are matched against custom metadata
    descriptions when metrics are collected; the remaining values are
    metadata keys.
    """

    # Provenance of the fine-tuned model.
    FT_SOURCE = "fine_tune_source"
    FT_SOURCE_NAME = "fine_tune_source_name"
    FT_OUTPUT_PATH = "fine_tune_output_path"
    FT_JOB_ID = "fine_tune_job_id"
    FT_JOB_RUN_ID = "fine_tune_jobrun_id"
    # Metric groups.
    TRAINING_METRICS_FINAL = "train_metrics_final"
    VALIDATION_METRICS_FINAL = "val_metrics_final"
    TRAINING_METRICS_EPOCH = "train_metrics_epoch"
    VALIDATION_METRICS_EPOCH = "val_metrics_epoch"
290
+
291
+
292
class AquaModelApp(AquaApp):
    """Provides a suite of APIs to interact with Aqua models within the Oracle
    Cloud Infrastructure Data Science service, serving as an interface for
    managing machine learning models.

    Methods
    -------
    create(model_id: str, project_id: str, compartment_id: str = None, **kwargs) -> DataScienceModel
        Creates custom aqua model from service model.
    get(model_id: str) -> AquaModel:
        Retrieves details of an Aqua model by its unique identifier.
    list(compartment_id: str = None, project_id: str = None, **kwargs) -> List[AquaModelSummary]:
        Lists all Aqua models within a specified compartment and/or project.
    clear_model_list_cache()
        Allows clear list model cache items from the service models compartment.
    load_license(model_id: str) -> AquaModelLicense
        Loads the license full text for the given model.

    Note:
        This class is designed to work within the Oracle Cloud Infrastructure
        and requires proper configuration and authentication set up to interact
        with OCI services.
    """

    # Service model list, keyed by ODSC_MODEL_COMPARTMENT_OCID; entries expire after 5 hours.
    _service_models_cache = TTLCache(
        maxsize=10, ttl=timedelta(hours=5), timer=datetime.now
    )
    # Used for saving service model details (non fine-tuned models returned by `get`).
    _service_model_details_cache = TTLCache(
        maxsize=10, ttl=timedelta(hours=5), timer=datetime.now
    )
    # Serializes the cache writes/removals performed by this class.
    _cache_lock = Lock()

    @telemetry(entry_point="plugin=model&action=create", name="aqua")
    def create(
        self, model_id: str, project_id: str, compartment_id: str = None, **kwargs
    ) -> DataScienceModel:
        """Creates custom aqua model from service model.

        If the given model is not in the service compartment
        (``ODSC_MODEL_COMPARTMENT_OCID``), it is returned as-is and nothing is copied.

        Parameters
        ----------
        model_id: str
            The service model id.
        project_id: str
            The project id for custom model. Falls back to ``PROJECT_OCID`` when empty.
        compartment_id: str
            The compartment id for custom model. Defaults to None.
            If not provided, ``COMPARTMENT_OCID`` from ``ads.config`` is used.
        kwargs:
            Forwarded to ``DataScienceModel.create``.

        Returns
        -------
        DataScienceModel:
            The instance of DataScienceModel.
        """
        service_model = DataScienceModel.from_id(model_id)

        if service_model.compartment_id != ODSC_MODEL_COMPARTMENT_OCID:
            logger.debug(
                f"Aqua Model {model_id} already exists in user's compartment. "
                "Skipped copying."
            )
            return service_model

        target_project = project_id or PROJECT_OCID
        target_compartment = compartment_id or COMPARTMENT_OCID

        custom_model = (
            DataScienceModel()
            .with_compartment_id(target_compartment)
            .with_project_id(target_project)
            .with_model_file_description(json_dict=service_model.model_file_description)
            .with_display_name(service_model.display_name)
            .with_description(service_model.description)
            .with_freeform_tags(**(service_model.freeform_tags or {}))
            .with_defined_tags(**(service_model.defined_tags or {}))
            .with_custom_metadata_list(service_model.custom_metadata_list)
            .with_defined_metadata_list(service_model.defined_metadata_list)
            .with_provenance_metadata(service_model.provenance_metadata)
            # TODO: decide what kwargs will be needed.
            .create(model_by_reference=True, **kwargs)
        )
        logger.debug(
            f"Aqua Model {custom_model.id} created with the service model {model_id}"
        )

        # tracks unique models that were created in the user compartment
        self.telemetry.record_event_async(
            category="aqua/service/model",
            action="create",
            detail=service_model.display_name,
        )

        return custom_model

    @telemetry(entry_point="plugin=model&action=get", name="aqua")
    def get(self, model_id) -> "AquaModel":
        """Gets the information of an Aqua model.

        Non fine-tuned models are cached in ``_service_model_details_cache``.
        Fine-tuned models are rebuilt on every call, together with their job
        run status, source and metrics.

        Parameters
        ----------
        model_id: str
            The model OCID.

        Returns
        -------
        AquaModel:
            The instance of AquaModel (``AquaFineTuneModel`` for fine-tuned models).

        Raises
        ------
        AquaRuntimeError
            If the target model is not tagged as an Aqua model.
        """
        cached_item = self._service_model_details_cache.get(model_id)
        if cached_item:
            return cached_item

        ds_model = DataScienceModel.from_id(model_id)
        if not self._if_show(ds_model):
            raise AquaRuntimeError(f"Target model `{ds_model.id}` is not Aqua model.")

        is_fine_tuned_model = bool(
            ds_model.freeform_tags
            and ds_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG.value)
        )

        # todo: consolidate this logic in utils for model and deployment use
        try:
            artifact_path = ds_model.custom_metadata_list.get(
                utils.MODEL_BY_REFERENCE_OSS_PATH_KEY
            ).value.rstrip("/")
            if not ObjectStorageDetails.is_oci_path(artifact_path):
                artifact_path = ObjectStorageDetails(
                    AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, artifact_path
                ).path
        except (ValueError, AttributeError):
            # ValueError as originally handled for the lookup; AttributeError
            # when the stored value is None.
            artifact_path = utils.UNKNOWN

        if artifact_path:
            model_card = str(
                read_file(
                    file_path=f"{artifact_path}/{README}",
                    auth=self._auth,
                )
            )
        else:
            # Without a location, "{artifact_path}/README" would become the
            # local path "/README.md"; skip the read instead.
            logger.debug("Failed to get artifact path from custom metadata.")
            model_card = str(UNKNOWN)

        aqua_model_attributes = dict(
            **self._process_model(ds_model, self.region),
            project_id=ds_model.project_id,
            model_card=model_card,
        )

        if not is_fine_tuned_model:
            model_details = AquaModel(**aqua_model_attributes)
            with self._cache_lock:
                self._service_model_details_cache[model_id] = model_details
            return model_details

        try:
            jobrun_ocid = ds_model.provenance_metadata.training_id
            jobrun = self.ds_client.get_job_run(jobrun_ocid).data
        except Exception as e:
            logger.debug(
                f"Missing jobrun information in the provenance metadata of the given model {model_id}. {str(e)}"
            )
            jobrun = None

        source_identifier = utils._build_resource_identifier(
            id=self._custom_metadata_value_or_unknown(
                ds_model.custom_metadata_list,
                FineTuningCustomMetadata.FT_SOURCE.value,
            ),
            name=self._custom_metadata_value_or_unknown(
                ds_model.custom_metadata_list,
                FineTuningCustomMetadata.FT_SOURCE_NAME.value,
            ),
            region=self.region,
        )

        ft_metrics = self._build_ft_metrics(ds_model.custom_metadata_list)

        # Without a usable job run, infer the outcome from the artifact's presence.
        if jobrun and jobrun.lifecycle_state != JobRun.LIFECYCLE_STATE_DELETED:
            job_run_status = jobrun.lifecycle_state
        elif self.if_artifact_exist(ds_model.id):
            job_run_status = JobRun.LIFECYCLE_STATE_SUCCEEDED
        else:
            job_run_status = JobRun.LIFECYCLE_STATE_FAILED

        # TODO: change the argument's name.
        lifecycle_state = utils.LifecycleStatus.get_status(
            evaluation_status=ds_model.lifecycle_state,
            job_run_status=job_run_status,
        )

        return AquaFineTuneModel(
            **aqua_model_attributes,
            source=source_identifier,
            lifecycle_state=(
                Model.LIFECYCLE_STATE_ACTIVE
                if lifecycle_state == JobRun.LIFECYCLE_STATE_SUCCEEDED
                else lifecycle_state
            ),
            metrics=ft_metrics,
            model=ds_model,
            jobrun=jobrun,
            region=self.region,
        )

    @staticmethod
    def _custom_metadata_value_or_unknown(custom_metadata_list, key: str):
        """Returns the value stored under ``key``, or ``UNKNOWN`` if the lookup raises ValueError."""
        try:
            return custom_metadata_list.get(key).value
        except ValueError as e:
            logger.debug(str(e))
            return UNKNOWN

    def _fetch_metric_from_metadata(
        self,
        custom_metadata_list: ModelCustomMetadata,
        target: str,
        category: str,
        metric_name: str,
    ) -> AquaFineTuningMetric:
        """Gets target metric from `ads.model.model_metadata.ModelCustomMetadata`.

        Metadata items are grouped by their ``description``. Every item whose
        description equals ``target`` contributes its ``value`` to ``scores``.
        For metric names ending in ``final``, only the first match is kept.
        If reading the metadata fails, ``scores`` is empty.
        """
        scores = []
        try:
            for custom_metadata in custom_metadata_list._items:
                # We use description to group metrics
                if custom_metadata.description == target:
                    scores.append(custom_metadata.value)
                    if metric_name.endswith("final"):
                        break
        except Exception as e:
            logger.debug(
                f"Failed to fetch metric {metric_name} from custom metadata: {str(e)}"
            )
            scores = []

        return AquaFineTuningMetric(name=metric_name, category=category, scores=scores)

    def _build_ft_metrics(
        self, custom_metadata_list: ModelCustomMetadata
    ) -> List[AquaFineTuningMetric]:
        """Builds Fine Tuning metrics.

        Returns, in order: the validation per-epoch metrics, the training
        per-epoch metrics, the validation final metric and the training final
        metric.
        """
        metric_specs = (
            (
                FineTuningCustomMetadata.VALIDATION_METRICS_EPOCH,
                FineTuningMetricCategories.VALIDATION,
                VALIDATION_METRICS,
            ),
            (
                FineTuningCustomMetadata.TRAINING_METRICS_EPOCH,
                FineTuningMetricCategories.TRAINING,
                TRINING_METRICS,
            ),
            (
                FineTuningCustomMetadata.VALIDATION_METRICS_FINAL,
                FineTuningMetricCategories.VALIDATION,
                VALIDATION_METRICS_FINAL,
            ),
            (
                FineTuningCustomMetadata.TRAINING_METRICS_FINAL,
                FineTuningMetricCategories.TRAINING,
                TRAINING_METRICS_FINAL,
            ),
        )
        return [
            self._fetch_metric_from_metadata(
                custom_metadata_list=custom_metadata_list,
                target=target.value,
                category=category.value,
                metric_name=metric_name,
            )
            for target, category, metric_name in metric_specs
        ]

    def _process_model(
        self,
        model: Union[
            DataScienceModel,
            oci.data_science.models.model.Model,
            oci.data_science.models.ModelSummary,
            oci.resource_search.models.ResourceSummary,
        ],
        region: str,
    ) -> dict:
        """Constructs required fields for AquaModelSummary.

        Parameters
        ----------
        model:
            A model object from ADS, the Data Science API, or Resource Search.
        region: str
            Region used to build the console link.

        Returns
        -------
        dict
            Keyword arguments for ``AquaModelSummary`` (every field except ``project_id``).
        """
        # todo: revisit icon generation code
        # icon = self._load_icon(model.display_name)
        icon = ""

        # Freeform tags override defined tags that share a key.
        tags = {**(model.defined_tags or {}), **(model.freeform_tags or {})}

        model_id = (
            model.identifier
            if isinstance(model, oci.resource_search.models.ResourceSummary)
            else model.id
        )

        # A plain string, matching AquaModelSummary.console_link.
        console_link = get_console_link(
            resource="models",
            ocid=model_id,
            region=region,
        )

        description = ""
        if isinstance(model, (DataScienceModel, oci.data_science.models.model.Model)):
            description = model.description
        elif isinstance(model, oci.resource_search.models.ResourceSummary):
            # Guard against a missing additional_details payload.
            description = (model.additional_details or {}).get("description")

        search_text = (
            self._build_search_text(tags=tags, description=description)
            if tags
            else UNKNOWN
        )

        freeform_tags = model.freeform_tags or {}
        is_fine_tuned_model = Tags.AQUA_FINE_TUNED_MODEL_TAG.value in freeform_tags
        # Only fine-tuned models carry a readiness status; others are always deployable.
        ready_to_deploy = (
            freeform_tags.get(Tags.AQUA_TAG.value, "").upper() == READY_TO_DEPLOY_STATUS
            if is_fine_tuned_model
            else True
        )

        return dict(
            compartment_id=model.compartment_id,
            icon=icon or UNKNOWN,
            id=model_id,
            license=freeform_tags.get(Tags.LICENSE.value, UNKNOWN),
            name=model.display_name,
            organization=freeform_tags.get(Tags.ORGANIZATION.value, UNKNOWN),
            task=freeform_tags.get(Tags.TASK.value, UNKNOWN),
            time_created=model.time_created,
            is_fine_tuned_model=is_fine_tuned_model,
            tags=tags,
            console_link=console_link,
            search_text=search_text,
            ready_to_deploy=ready_to_deploy,
        )

    @telemetry(entry_point="plugin=model&action=list", name="aqua")
    def list(
        self, compartment_id: str = None, project_id: str = None, **kwargs
    ) -> List["AquaModelSummary"]:
        """Lists all Aqua models within a specified compartment and/or project.
        If `compartment_id` is not specified, the method defaults to returning
        the service models within the pre-configured default compartment. By default, the list
        of models in the service compartment are cached. Use clear_model_list_cache() to invalidate
        the cache.

        Parameters
        ----------
        compartment_id: (str, optional). Defaults to `None`.
            The compartment OCID.
        project_id: (str, optional). Defaults to `None`.
            The project OCID.
        **kwargs:
            Additional keyword arguments that can be used to filter the results.
            They are only forwarded when listing service models.

        Returns
        -------
        List[AquaModelSummary]:
            The list of the `ads.aqua.model.AquaModelSummary`.
        """
        if compartment_id:
            # tracks number of times custom model listing was called
            self.telemetry.record_event_async(
                category="aqua/custom/model", action="list"
            )

            logger.info(f"Fetching custom models from compartment_id={compartment_id}.")
            models = self._rqs(compartment_id)
        else:
            # tracks number of times service model listing was called
            self.telemetry.record_event_async(
                category="aqua/service/model", action="list"
            )

            # Single lookup: a separate membership test followed by .get() could
            # return None if the entry expired in between.
            cached_models = self._service_models_cache.get(ODSC_MODEL_COMPARTMENT_OCID)
            if cached_models is not None:
                logger.info(
                    f"Returning service models list in {ODSC_MODEL_COMPARTMENT_OCID} from cache."
                )
                return cached_models

            logger.info(
                f"Fetching service models from compartment_id={ODSC_MODEL_COMPARTMENT_OCID}"
            )
            lifecycle_state = kwargs.pop(
                "lifecycle_state", Model.LIFECYCLE_STATE_ACTIVE
            )

            models = self.list_resource(
                self.ds_client.list_models,
                compartment_id=ODSC_MODEL_COMPARTMENT_OCID,
                lifecycle_state=lifecycle_state,
                **kwargs,
            )

        logger.info(
            f"Fetch {len(models)} model in compartment_id={compartment_id or ODSC_MODEL_COMPARTMENT_OCID}."
        )

        aqua_models = [
            AquaModelSummary(
                **self._process_model(model=model, region=self.region),
                project_id=project_id or UNKNOWN,
            )
            for model in models
        ]

        if not compartment_id:
            with self._cache_lock:
                self._service_models_cache[ODSC_MODEL_COMPARTMENT_OCID] = aqua_models

        return aqua_models

    def clear_model_list_cache(
        self,
    ):
        """
        Allows user to clear list model cache items from the service models compartment.

        Returns
        -------
        dict
            ``{"key": {"compartment_id": ...}, "cache_deleted": True}`` when a
            cached list was removed, otherwise an empty dict.
        """
        logger.info("Clearing _service_models_cache")
        not_found = object()
        with self._cache_lock:
            # pop() with a default avoids a KeyError if the entry is missing or expired.
            removed = self._service_models_cache.pop(
                ODSC_MODEL_COMPARTMENT_OCID, not_found
            )

        if removed is not_found:
            return {}
        return {
            "key": {
                "compartment_id": ODSC_MODEL_COMPARTMENT_OCID,
            },
            "cache_deleted": True,
        }

    def _if_show(self, model: DataScienceModel) -> bool:
        """Determine if the given model should be return by `list`.

        True when the model has the Aqua freeform tag key, in either its
        canonical or lower-case spelling.
        """
        if model.freeform_tags is None:
            return False

        TARGET_TAGS = model.freeform_tags.keys()
        return (
            Tags.AQUA_TAG.value in TARGET_TAGS
            or Tags.AQUA_TAG.value.lower() in TARGET_TAGS
        )

    def _load_icon(self, model_name: str) -> str:
        """Loads icon.

        Returns a data URI produced by ``create_word_icon``, or None if the
        icon could not be generated.
        """

        # TODO: switch to the official logo
        try:
            return create_word_icon(model_name, return_as_datauri=True)
        except Exception as e:
            logger.debug(f"Failed to load icon for the model={model_name}: {str(e)}.")
            return None

    def _rqs(self, compartment_id: str, **kwargs):
        """Use RQS to fetch models in the user tenancy.

        Selects ACTIVE data science models in ``compartment_id`` that carry
        both the Aqua tag key and the fine-tuned-model tag key.
        """

        condition_tags = f"&& (freeformTags.key = '{Tags.AQUA_TAG.value}' && freeformTags.key = '{Tags.AQUA_FINE_TUNED_MODEL_TAG.value}')"
        condition_lifecycle = "&& lifecycleState = 'ACTIVE'"
        query = f"query datasciencemodel resources where (compartmentId = '{compartment_id}' {condition_lifecycle} {condition_tags})"
        logger.info(query)
        logger.info(f"tenant_id={TENANCY_OCID}")
        return OCIResource.search(
            query, type=SEARCH_TYPE.STRUCTURED, tenant_id=TENANCY_OCID, **kwargs
        )

    def _build_search_text(self, tags: dict, description: str = None) -> str:
        """Constructs search_text field in response.

        Returns the description followed by a space and the comma-joined tag
        values (no space when the description is empty).
        """
        description = description or ""
        tags_text = (
            ",".join(str(v) for v in tags.values()) if isinstance(tags, dict) else ""
        )
        separator = " " if description else ""
        return f"{description}{separator}{tags_text}"

    @telemetry(entry_point="plugin=model&action=load_license", name="aqua")
    def load_license(self, model_id: str) -> AquaModelLicense:
        """Loads the license full text for the given model.

        Parameters
        ----------
        model_id: str
            The model id.

        Returns
        -------
        AquaModelLicense:
            The instance of AquaModelLicense.

        Raises
        ------
        AquaRuntimeError
            If the artifact path cannot be determined from custom metadata.
        """
        oci_model = self.ds_client.get_model(model_id).data
        artifact_path = get_artifact_path(oci_model.custom_metadata_list)
        if not artifact_path:
            raise AquaRuntimeError("Failed to get artifact path from custom metadata.")

        content = str(
            read_file(
                file_path=f"{os.path.dirname(artifact_path)}/{LICENSE_TXT}",
                auth=default_signer(),
            )
        )

        return AquaModelLicense(id=model_id, license=content)