oracle-ads 2.10.1__py3-none-any.whl → 2.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. ads/aqua/__init__.py +12 -0
  2. ads/aqua/base.py +324 -0
  3. ads/aqua/cli.py +19 -0
  4. ads/aqua/config/deployment_config_defaults.json +9 -0
  5. ads/aqua/config/resource_limit_names.json +7 -0
  6. ads/aqua/constants.py +45 -0
  7. ads/aqua/data.py +40 -0
  8. ads/aqua/decorator.py +101 -0
  9. ads/aqua/deployment.py +643 -0
  10. ads/aqua/dummy_data/icon.txt +1 -0
  11. ads/aqua/dummy_data/oci_model_deployments.json +56 -0
  12. ads/aqua/dummy_data/oci_models.json +1 -0
  13. ads/aqua/dummy_data/readme.md +26 -0
  14. ads/aqua/evaluation.py +1751 -0
  15. ads/aqua/exception.py +82 -0
  16. ads/aqua/extension/__init__.py +40 -0
  17. ads/aqua/extension/base_handler.py +138 -0
  18. ads/aqua/extension/common_handler.py +21 -0
  19. ads/aqua/extension/deployment_handler.py +202 -0
  20. ads/aqua/extension/evaluation_handler.py +135 -0
  21. ads/aqua/extension/finetune_handler.py +66 -0
  22. ads/aqua/extension/model_handler.py +59 -0
  23. ads/aqua/extension/ui_handler.py +201 -0
  24. ads/aqua/extension/utils.py +23 -0
  25. ads/aqua/finetune.py +579 -0
  26. ads/aqua/job.py +29 -0
  27. ads/aqua/model.py +819 -0
  28. ads/aqua/training/__init__.py +4 -0
  29. ads/aqua/training/exceptions.py +459 -0
  30. ads/aqua/ui.py +453 -0
  31. ads/aqua/utils.py +715 -0
  32. ads/cli.py +37 -6
  33. ads/common/decorator/__init__.py +7 -3
  34. ads/common/decorator/require_nonempty_arg.py +65 -0
  35. ads/common/object_storage_details.py +166 -7
  36. ads/common/oci_client.py +18 -1
  37. ads/common/oci_logging.py +2 -2
  38. ads/common/oci_mixin.py +4 -5
  39. ads/common/serializer.py +34 -5
  40. ads/common/utils.py +75 -10
  41. ads/config.py +40 -1
  42. ads/jobs/ads_job.py +43 -25
  43. ads/jobs/builders/infrastructure/base.py +4 -2
  44. ads/jobs/builders/infrastructure/dsc_job.py +49 -39
  45. ads/jobs/builders/runtimes/base.py +71 -1
  46. ads/jobs/builders/runtimes/container_runtime.py +4 -4
  47. ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
  48. ads/jobs/templates/driver_pytorch.py +27 -10
  49. ads/model/artifact_downloader.py +84 -14
  50. ads/model/artifact_uploader.py +25 -23
  51. ads/model/datascience_model.py +388 -38
  52. ads/model/deployment/model_deployment.py +10 -2
  53. ads/model/generic_model.py +8 -0
  54. ads/model/model_file_description_schema.json +68 -0
  55. ads/model/model_metadata.py +1 -1
  56. ads/model/service/oci_datascience_model.py +34 -5
  57. ads/opctl/operator/lowcode/anomaly/README.md +2 -1
  58. ads/opctl/operator/lowcode/anomaly/__main__.py +10 -4
  59. ads/opctl/operator/lowcode/anomaly/environment.yaml +2 -1
  60. ads/opctl/operator/lowcode/anomaly/model/automlx.py +12 -6
  61. ads/opctl/operator/lowcode/forecast/README.md +3 -2
  62. ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
  63. ads/opctl/operator/lowcode/forecast/model/automlx.py +12 -23
  64. ads/telemetry/base.py +62 -0
  65. ads/telemetry/client.py +105 -0
  66. ads/telemetry/telemetry.py +6 -3
  67. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/METADATA +37 -7
  68. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/RECORD +71 -36
  69. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/LICENSE.txt +0 -0
  70. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/WHEEL +0 -0
  71. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/entry_points.txt +0 -0
ads/aqua/ui.py ADDED
@@ -0,0 +1,453 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+ import concurrent.futures
6
+ from datetime import datetime, timedelta
7
+ from threading import Lock
8
+
9
+ from cachetools import TTLCache
10
+ from oci.exceptions import ServiceError
11
+ from oci.identity.models import Compartment
12
+
13
+ from ads.aqua import logger
14
+ from ads.aqua.base import AquaApp
15
+ from ads.aqua.data import Tags
16
+ from ads.aqua.exception import AquaValueError, AquaResourceAccessError
17
+ from ads.aqua.utils import load_config, sanitize_response
18
+ from ads.common import oci_client as oc
19
+ from ads.common.auth import default_signer
20
+ from ads.common.object_storage_details import ObjectStorageDetails
21
+ from ads.config import (
22
+ AQUA_CONFIG_FOLDER,
23
+ AQUA_RESOURCE_LIMIT_NAMES_CONFIG,
24
+ COMPARTMENT_OCID,
25
+ DATA_SCIENCE_SERVICE_NAME,
26
+ TENANCY_OCID,
27
+ )
28
+ from ads.telemetry import telemetry
29
+
30
+
31
class AquaUIApp(AquaApp):
    """Contains APIs for supporting Aqua UI.

    Most methods are thin pass-throughs to OCI SDK list/get calls whose results are
    passed through ``sanitize_response`` before being returned to the caller.

    Attributes
    ----------
    _compartments_cache: TTLCache
        Class-level cache, shared by all instances, holding the sanitized
        ``list_compartments`` result keyed by TENANCY_OCID. It holds at most 10
        entries, each expiring 2 hours after insertion.
    _cache_lock: threading.Lock
        Guards every access to ``_compartments_cache``. cachetools caches are not
        thread-safe on their own.

    Methods
    -------
    list_log_groups(self, **kwargs) -> str
        Lists all log groups for the specified compartment or tenancy.
    list_logs(self, **kwargs) -> str
        Lists the specified log group's log objects.
    list_compartments(self) -> str
        Lists the compartments in the tenancy given by TENANCY_OCID (cached).
    get_default_compartment(self) -> dict
        Returns the compartment OCID taken from environment variables.
    clear_compartments_list_cache(self) -> dict
        Removes the cached compartments list for TENANCY_OCID.
    list_model_version_sets(self, target_tag=None, **kwargs) -> str
        Lists model version sets, optionally filtered by a freeform tag.
    list_buckets(self, **kwargs)
        Lists Object Storage buckets, optionally only versioned ones.
    list_job_shapes(self, **kwargs)
        Lists available Data Science job shapes.
    list_vcn(self, **kwargs)
        Lists VCNs in a compartment.
    list_subnets(self, **kwargs)
        Lists subnets of a VCN in a compartment.
    get_shape_availability(self, **kwargs) -> dict
        Checks service-limit availability for an instance shape.
    is_bucket_versioned(self, bucket_uri) -> dict
        Reports whether an Object Storage bucket is versioned.
    """

    _compartments_cache = TTLCache(
        maxsize=10, ttl=timedelta(hours=2), timer=datetime.now
    )
    _cache_lock = Lock()

    @telemetry(entry_point="plugin=ui&action=list_log_groups", name="aqua")
    def list_log_groups(self, **kwargs) -> str:
        """Lists all log groups for the specified compartment or tenancy.

        This is a pass-through to the OCI ``list_log_groups`` API.

        Parameters
        ----------
        kwargs
            Keyword arguments, such as compartment_id (defaults to COMPARTMENT_OCID),
            for `list_log_groups <https://docs.oracle.com/en-us/iaas/tools/python/2.119.1/api/logging/client/oci.logging.LoggingManagementClient.html#oci.logging.LoggingManagementClient.list_log_groups>`_

        Returns
        -------
        str has json representation of oci.logging.models.log_group.LogGroup
        """
        compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID)

        res = self.logging_client.list_log_groups(
            compartment_id=compartment_id, **kwargs
        ).data
        return sanitize_response(oci_client=self.logging_client, response=res)

    @telemetry(entry_point="plugin=ui&action=list_logs", name="aqua")
    def list_logs(self, **kwargs) -> str:
        """Lists the specified log group's log objects.

        This is a pass-through to the OCI ``list_logs`` API.

        Parameters
        ----------
        kwargs
            Keyword arguments, such as log_group_id (required) and log_type,
            for `list_logs <https://docs.oracle.com/en-us/iaas/tools/python/2.119.1/api/logging/client/oci.logging.LoggingManagementClient.html#oci.logging.LoggingManagementClient.list_logs>`_

        Returns
        -------
        str:
            str has json representation of oci.logging.models.log_summary.LogSummary

        Raises
        ------
        KeyError
            If ``log_group_id`` is not provided.
        """
        log_group_id = kwargs.pop("log_group_id")

        res = self.logging_client.list_logs(log_group_id=log_group_id, **kwargs).data
        return sanitize_response(oci_client=self.logging_client, response=res)

    @telemetry(entry_point="plugin=ui&action=list_compartments", name="aqua")
    def list_compartments(self) -> str:
        """Lists the compartments in the tenancy specified by the TENANCY_OCID env variable.

        This is a pass-through to the OCI ``list_compartments`` API with fallbacks:

        1. List all compartments in the subtree (``access_level="ANY"``).
        2. If that fails, list only the direct child compartments.
        3. Always prepend the root compartment. If its details cannot be fetched,
           a placeholder ``Compartment`` carrying only the tenancy OCID is used.

        The sanitized result is cached per TENANCY_OCID in ``_compartments_cache``.

        Returns
        -------
        str:
            str has json representation of oci.identity.models.Compartment

        Raises
        ------
        AquaValueError
            If TENANCY_OCID is not set.
        """
        if not TENANCY_OCID:
            raise AquaValueError(
                "TENANCY_OCID must be available in environment"
                " variables to list the sub compartments."
            )

        # The cache is shared across instances and threads; cachetools is not
        # thread-safe, so every access goes through the lock.
        with self._cache_lock:
            cached = self._compartments_cache.get(TENANCY_OCID)
        if cached is not None:
            logger.info(f"Returning compartments list in {TENANCY_OCID} from cache.")
            return cached

        compartments = []
        # User may not have permissions to list compartments.
        try:
            compartments.extend(
                self.list_resource(
                    list_func_ref=self.identity_client.list_compartments,
                    compartment_id=TENANCY_OCID,
                    compartment_id_in_subtree=True,
                    access_level="ANY",
                )
            )
        except ServiceError:
            logger.error(
                f"ERROR: Unable to list all sub compartment in tenancy {TENANCY_OCID}."
            )
            try:
                # Use extend, not append. list_resource yields the compartments
                # themselves (it is extended above); appending would nest the whole
                # list as a single element.
                compartments.extend(
                    self.list_resource(
                        list_func_ref=self.identity_client.list_compartments,
                        compartment_id=TENANCY_OCID,
                    )
                )
            except ServiceError:
                logger.error(
                    f"ERROR: Unable to list all child compartment in tenancy {TENANCY_OCID}."
                )

        try:
            root_compartment = self.identity_client.get_compartment(TENANCY_OCID).data
        except ServiceError:
            logger.error(
                f"ERROR: Unable to get details of the root compartment {TENANCY_OCID}."
            )
            root_compartment = Compartment(
                id=TENANCY_OCID, name=" ** Root - Name N/A **"
            )
        # The root compartment is always listed first.
        compartments.insert(0, root_compartment)

        # convert the string of the results flattened as a dict
        res = sanitize_response(oci_client=self.identity_client, response=compartments)

        # cache compartment results
        with self._cache_lock:
            self._compartments_cache[TENANCY_OCID] = res

        return res

    def get_default_compartment(self) -> dict:
        """Returns user compartment OCID fetched from environment variables.

        An error is logged (but not raised) when COMPARTMENT_OCID is not set.

        Returns
        -------
        dict:
            The compartment ocid under the ``compartment_id`` key.
        """
        if not COMPARTMENT_OCID:
            logger.error("No compartment id found from environment variables.")
        return dict(compartment_id=COMPARTMENT_OCID)

    def clear_compartments_list_cache(self) -> dict:
        """Allows caller to clear compartments list cache.

        Returns
        -------
        dict
            ``{"key": {"tenancy_ocid": ...}, "cache_deleted": True}`` if a cached
            entry for TENANCY_OCID was removed, otherwise an empty dict.
        """
        logger.info("Clearing list_compartments cache")
        with self._cache_lock:
            # pop with a default avoids a KeyError if the entry expires between a
            # membership check and the removal. Cached values are never None.
            removed = self._compartments_cache.pop(TENANCY_OCID, None)
        if removed is None:
            return {}
        return {
            "key": {
                "tenancy_ocid": TENANCY_OCID,
            },
            "cache_deleted": True,
        }

    @telemetry(entry_point="plugin=ui&action=list_model_version_sets", name="aqua")
    def list_model_version_sets(self, target_tag: str = None, **kwargs) -> str:
        """Lists all model version sets for the specified compartment or tenancy.

        Parameters
        ----------
        target_tag: str
            Optional freeform tag key. When given, only model version sets carrying
            this tag are returned.
        **kwargs
            Additional arguments, such as `compartment_id` (defaults to COMPARTMENT_OCID),
            for `list_model_version_sets <https://docs.oracle.com/en-us/iaas/tools/python/2.121.0/api/data_science/client/oci.data_science.DataScienceClient.html#oci.data_science.DataScienceClient.list_model_version_sets>`_

        Returns
        -------
        str has json representation of `oci.data_science.models.ModelVersionSetSummary`.
        """
        compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID)
        # Only affects the log message below; the listing call is the same either way.
        target_resource = (
            "experiments"
            if target_tag == Tags.AQUA_EVALUATION.value
            else "modelversionsets"
        )
        logger.info(f"Loading {target_resource} from compartment: {compartment_id}")

        items = self.list_resource(
            self.ds_client.list_model_version_sets,
            compartment_id=compartment_id,
            **kwargs,
        )

        if target_tag is None:
            res = items
        else:
            # Guard against a missing (None) freeform_tags mapping on a summary.
            res = [
                item for item in items if target_tag in (item.freeform_tags or {})
            ]

        return sanitize_response(oci_client=self.ds_client, response=res)

    @telemetry(entry_point="plugin=ui&action=list_buckets", name="aqua")
    def list_buckets(self, **kwargs) -> list:
        """Lists all buckets for the specified compartment.

        Parameters
        ----------
        **kwargs
            Additional arguments, such as `compartment_id` (defaults to COMPARTMENT_OCID),
            for `list_buckets <https://docs.oracle.com/en-us/iaas/tools/python/2.122.0/api/object_storage/client/oci.object_storage.ObjectStorageClient.html?highlight=list%20bucket#oci.object_storage.ObjectStorageClient.list_buckets>`_.
            The extra ``versioned`` flag (default False) is consumed here and not
            forwarded. When True, only versioned buckets are returned.

        Returns
        -------
        str has json representation of `oci.object_storage.models.BucketSummary`."""
        compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID)
        logger.info(f"Loading buckets summary from compartment: {compartment_id}")

        versioned = kwargs.pop("versioned", False)

        os_client = oc.OCIClientFactory(**default_signer()).object_storage
        namespace_name = os_client.get_namespace(compartment_id=compartment_id).data
        logger.info(f"Object Storage namespace is `{namespace_name}`.")

        response = os_client.list_buckets(
            namespace_name=namespace_name,
            compartment_id=compartment_id,
            **kwargs,
        ).data

        if response and versioned:
            # One versioning check per bucket; run them concurrently and drop the
            # buckets for which the helper returned None.
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                result = list(
                    filter(None, executor.map(self._is_bucket_versioned, response))
                )
        else:
            result = response

        return sanitize_response(oci_client=os_client, response=result)

    @staticmethod
    def _is_bucket_versioned(response):
        """Return the bucket summary if its bucket is versioned, otherwise None."""
        bucket_name = response.name
        namespace = response.namespace
        bucket_uri = f"oci://{bucket_name}@{namespace}"
        if ObjectStorageDetails.from_path(bucket_uri).is_bucket_versioned():
            return response
        return None

    @telemetry(entry_point="plugin=ui&action=list_job_shapes", name="aqua")
    def list_job_shapes(self, **kwargs) -> list:
        """Lists all available job shapes for the specified compartment.

        Parameters
        ----------
        **kwargs
            Additional arguments, such as `compartment_id` (defaults to COMPARTMENT_OCID),
            for `list_job_shapes <https://docs.oracle.com/en-us/iaas/tools/python/2.122.0/api/data_science/client/oci.data_science.DataScienceClient.html#oci.data_science.DataScienceClient.list_job_shapes>`_

        Returns
        -------
        str has json representation of `oci.data_science.models.JobShapeSummary`."""
        compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID)
        logger.info(f"Loading job shape summary from compartment: {compartment_id}")

        res = self.ds_client.list_job_shapes(
            compartment_id=compartment_id, **kwargs
        ).data
        return sanitize_response(oci_client=self.ds_client, response=res)

    @telemetry(entry_point="plugin=ui&action=list_vcn", name="aqua")
    def list_vcn(self, **kwargs) -> list:
        """Lists the virtual cloud networks (VCNs) in the specified compartment.

        Parameters
        ----------
        **kwargs
            Additional arguments, such as `compartment_id` (defaults to COMPARTMENT_OCID),
            for `list_vcns <https://docs.oracle.com/iaas/api/#/en/iaas/20160918/Vcn/ListVcns>`_.
            Remaining arguments are forwarded to ``list_vcns``.

        Returns
        -------
        json representation of `oci.core.models.Vcn`."""
        compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID)
        logger.info(f"Loading VCN list from compartment: {compartment_id}")

        # todo: add _vcn_client in init in AquaApp, then add a property vcn_client which does lazy init
        # of _vcn_client. Do this for all clients in AquaApp
        vcn_client = oc.OCIClientFactory(**default_signer()).virtual_network
        # Forward the documented extra arguments instead of silently dropping them.
        res = vcn_client.list_vcns(compartment_id=compartment_id, **kwargs).data
        return sanitize_response(oci_client=vcn_client, response=res)

    @telemetry(entry_point="plugin=ui&action=list_subnets", name="aqua")
    def list_subnets(self, **kwargs) -> list:
        """Lists the subnets in the specified VCN and the specified compartment.

        Parameters
        ----------
        **kwargs
            Additional arguments, such as `compartment_id` (defaults to COMPARTMENT_OCID)
            and `vcn_id`, for `list_subnets <https://docs.oracle.com/iaas/api/#/en/iaas/20160918/Subnet/ListSubnets>`_.
            Remaining arguments are forwarded to ``list_subnets``.

        Returns
        -------
        json representation of `oci.core.models.Subnet`."""
        compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID)
        vcn_id = kwargs.pop("vcn_id", None)
        logger.info(
            f"Loading subnet list from compartment: {compartment_id} for VCN: {vcn_id}"
        )

        vcn_client = oc.OCIClientFactory(**default_signer()).virtual_network
        # Forward the documented extra arguments instead of silently dropping them.
        res = vcn_client.list_subnets(
            compartment_id=compartment_id, vcn_id=vcn_id, **kwargs
        ).data

        return sanitize_response(oci_client=vcn_client, response=res)

    @telemetry(entry_point="plugin=ui&action=get_shape_availability", name="aqua")
    def get_shape_availability(self, **kwargs):
        """Checks whether enough resources are available for the given shape.

        For a given compartment, the shape's resource limit name (looked up in the
        AQUA_RESOURCE_LIMIT_NAMES_CONFIG file) and scope, this returns the number of
        available resources associated with that limit.

        Parameters
        ----------
        **kwargs
            instance_shape: (str)
                The shape of the instance used for deployment. Required.
            compartment_id: (str)
                Defaults to COMPARTMENT_OCID.
            Other arguments are forwarded to
            `get_resource_availability <https://docs.oracle.com/iaas/api/#/en/limits/20181025/ResourceAvailability/GetResourceAvailability>`_

        Returns
        -------
        dict:
            ``{"available_count": <int>}``, or ``{}`` when the shape has no limit-name
            mapping in the config.

        Raises
        ------
        AquaValueError
            If instance_shape is missing, or fewer resources are available than the
            shape requires.
        AquaResourceAccessError
            If the limits service call fails.
        """
        compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID)
        instance_shape = kwargs.pop("instance_shape", None)

        if not instance_shape:
            raise AquaValueError("instance_shape argument is required.")

        limits_client = oc.OCIClientFactory(**default_signer()).limits

        artifact_path = AQUA_CONFIG_FOLDER
        config = load_config(
            artifact_path,
            config_file_name=AQUA_RESOURCE_LIMIT_NAMES_CONFIG,
        )

        if instance_shape not in config:
            logger.error(
                f"{instance_shape} does not have mapping details in {AQUA_RESOURCE_LIMIT_NAMES_CONFIG}"
            )
            return {}

        limit_name = config[instance_shape]
        try:
            res = limits_client.get_resource_availability(
                DATA_SCIENCE_SERVICE_NAME, limit_name, compartment_id, **kwargs
            ).data
        except ServiceError as se:
            raise AquaResourceAccessError(
                f"Could not check limits availability for the shape {instance_shape}.",
                service_payload=se.args[0] if se.args else None,
            ) from se

        available = res.available

        try:
            # Presumably the trailing segment is the accelerator count (e.g. a shape
            # like "X.Y.Z.2"); shapes without a numeric suffix require 1 unit.
            cards = int(instance_shape.split(".")[-1])
        except ValueError:
            cards = 1

        response = dict(available_count=available)

        if available < cards:
            raise AquaValueError(
                f"Inadequate resource is available to create the {instance_shape} resource. The number of available "
                f"resource associated with the limit name {limit_name} is {available}.",
                service_payload=response,
            )

        return response

    @telemetry(entry_point="plugin=ui&action=is_bucket_versioned", name="aqua")
    def is_bucket_versioned(self, bucket_uri: str):
        """Check if the given bucket is versioned.

        This is a required check for the fine-tuned model creation process, where
        the model weights are stored.

        Parameters
        ----------
        bucket_uri: str
            Object Storage URI of the bucket, in the form ``oci://bucket@namespace``
            as built by ``_is_bucket_versioned``.

        Returns
        -------
        dict:
            ``is_versioned`` flag that informs whether it is versioned or not, plus
            a human-readable ``message``.
        """
        if ObjectStorageDetails.from_path(bucket_uri).is_bucket_versioned():
            is_versioned = True
            message = f"Model artifact bucket {bucket_uri} is versioned."
        else:
            is_versioned = False
            message = f"Model artifact bucket {bucket_uri} is not versioned. Check if the path exists and enable versioning on the bucket to proceed with model creation."

        return dict(is_versioned=is_versioned, message=message)