oracle-ads 2.13.11__py3-none-any.whl → 2.13.13__py3-none-any.whl

This diff shows the changes between publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (43)
  1. ads/aqua/app.py +73 -15
  2. ads/aqua/cli.py +17 -0
  3. ads/aqua/client/client.py +38 -21
  4. ads/aqua/client/openai_client.py +20 -10
  5. ads/aqua/common/entities.py +78 -12
  6. ads/aqua/common/utils.py +35 -0
  7. ads/aqua/constants.py +2 -0
  8. ads/aqua/evaluation/evaluation.py +5 -4
  9. ads/aqua/extension/common_handler.py +47 -2
  10. ads/aqua/extension/model_handler.py +51 -9
  11. ads/aqua/model/constants.py +1 -0
  12. ads/aqua/model/enums.py +19 -1
  13. ads/aqua/model/model.py +119 -51
  14. ads/aqua/model/utils.py +1 -2
  15. ads/aqua/modeldeployment/config_loader.py +815 -0
  16. ads/aqua/modeldeployment/constants.py +4 -1
  17. ads/aqua/modeldeployment/deployment.py +178 -129
  18. ads/aqua/modeldeployment/entities.py +150 -178
  19. ads/aqua/modeldeployment/model_group_config.py +233 -0
  20. ads/aqua/modeldeployment/utils.py +0 -539
  21. ads/aqua/verify_policies/__init__.py +8 -0
  22. ads/aqua/verify_policies/constants.py +13 -0
  23. ads/aqua/verify_policies/entities.py +29 -0
  24. ads/aqua/verify_policies/messages.py +101 -0
  25. ads/aqua/verify_policies/utils.py +432 -0
  26. ads/aqua/verify_policies/verify.py +345 -0
  27. ads/aqua/version.json +3 -0
  28. ads/common/oci_logging.py +4 -7
  29. ads/common/work_request.py +39 -38
  30. ads/jobs/builders/infrastructure/dsc_job.py +121 -24
  31. ads/jobs/builders/infrastructure/dsc_job_runtime.py +71 -24
  32. ads/jobs/builders/runtimes/base.py +7 -5
  33. ads/jobs/builders/runtimes/pytorch_runtime.py +6 -8
  34. ads/jobs/templates/driver_pytorch.py +486 -172
  35. ads/jobs/templates/driver_utils.py +27 -11
  36. ads/model/deployment/model_deployment.py +51 -38
  37. ads/model/service/oci_datascience_model_deployment.py +6 -11
  38. ads/telemetry/client.py +4 -4
  39. {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/METADATA +2 -1
  40. {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/RECORD +43 -34
  41. {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/WHEEL +0 -0
  42. {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/entry_points.txt +0 -0
  43. {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/licenses/LICENSE.txt +0 -0
ads/aqua/verify_policies/verify.py ADDED
@@ -0,0 +1,345 @@
+ import logging
+
+ import click
+ import oci.exceptions
+
+ from ads.aqua.verify_policies.constants import POLICY_HELP_LINK, TEST_JOB_NAME, TEST_JOB_RUN_NAME, TEST_LIMIT_NAME, \
+     TEST_MVS_NAME, TEST_MD_NAME, TEST_VM_SHAPE
+ from ads.aqua.verify_policies.messages import operation_messages
+ from ads.aqua.verify_policies.entities import OperationResultSuccess, OperationResultFailure, PolicyStatus
+ from ads.aqua.verify_policies.utils import PolicyValidationError, VerifyPoliciesUtils, RichStatusLog
+ from functools import wraps
+
+ logger = logging.getLogger("aqua.policies")
+
+
+ def with_spinner(func):
+     """Decorator to wrap execution of a function with a rich UI spinner.
+
+     Displays status while the operation runs and logs success or failure messages
+     based on the policy verification result.
+     """
+
+     @wraps(func)
+     def wrapper(self, function, **kwargs):
+         operation_message = operation_messages[function.__name__]
+         ignore_spinner = kwargs.pop("ignore_spinner", False)
+
+         def run_func():
+             return_value, result_status = func(self, function, **kwargs)
+             result_message = f"{self._rich_ui.get_status_emoji(result_status.status)} {result_status.operation}"
+             if result_status.status == PolicyStatus.SUCCESS:
+                 logger.info(result_message)
+             else:
+                 logger.warning(result_message)
+                 logger.info(result_status.error)
+                 logger.info(f"Policy hint: {result_status.policy_hint}")
+                 logger.info(f"Refer to: {POLICY_HELP_LINK}")
+
+             return return_value, result_status
+
+         if ignore_spinner:
+             return run_func()
+         else:
+             with self._rich_ui.console.status(f"Verifying {operation_message['name']}") as status:
+                 return run_func()
+
+     return wrapper
+
+
+ class AquaVerifyPoliciesApp:
+     """
+     AquaVerifyPoliciesApp provides methods to verify IAM policies required for
+     various operations in OCI Data Science's AQUA (Accelerated Data Science) platform.
+
+     This utility is intended to help users validate whether they have the necessary
+     permissions to perform common AQUA workflows such as model registration,
+     deployment, evaluation, and fine-tuning.
+
+     Methods
+     -------
+     `common_policies()`: Validates basic read-level policies across AQUA components.
+     `model_register()`: Checks policies for object storage access and model registration.
+     `model_deployment()`: Validates policies for registering and deploying models.
+     `evaluation()`: Confirms ability to manage model version sets, jobs, and storage for evaluation.
+     `finetune()`: Verifies access required to fine-tune models.
+     """
+
+     def __init__(self):
+         super().__init__()
+         self._util = VerifyPoliciesUtils()
+         self._rich_ui = RichStatusLog()
+         self.model_id = None
+         logger.propagate = False
+         logger.setLevel(logging.INFO)
+
+     def _get_operation_result(self, operation, status):
+         """Maps a function and policy status to a corresponding result object.
+
+         Parameters:
+             operation (function): The operation being verified.
+             status (PolicyStatus): The outcome of the policy verification.
+
+         Returns:
+             OperationResultSuccess or OperationResultFailure based on status.
+         """
+         operation_message = operation_messages[operation.__name__]
+         if status == PolicyStatus.SUCCESS:
+             return OperationResultSuccess(operation=operation_message["name"])
+         if status == PolicyStatus.UNVERIFIED:
+             return OperationResultSuccess(operation=operation_message["name"], status=status)
+         if status == PolicyStatus.FAILURE:
+             return OperationResultFailure(operation=operation_message["name"], error=operation_message["error"],
+                                           policy_hint=f"{operation_message['policy_hint']}")
+
+     @with_spinner
+     def _execute(self, function, **kwargs):
+         """Executes a given operation function with policy validation and error handling.
+         Parameters:
+             function (callable): The function to execute.
+             kwargs (dict): Keyword arguments to pass to the function.
+
+         Returns:
+             Tuple: (result, OperationResult)
+         """
+         result = None
+         try:
+             result = function(**kwargs)
+             status = PolicyStatus.SUCCESS
+         except oci.exceptions.ServiceError as oci_error:
+             if oci_error.status == 404:
+                 logger.debug(oci_error)
+                 status = PolicyStatus.FAILURE
+             else:
+                 logger.error(oci_error)
+                 raise oci_error
+         except PolicyValidationError as policy_error:
+             status = PolicyStatus.FAILURE
+         except Exception as e:
+             logger.error(e)
+             raise e
+         return result, self._get_operation_result(function, status)
+
+     def _test_model_register(self, **kwargs):
+         """Verifies policies required to manage an object storage bucket and register a model.
+
+         Returns:
+             List of result dicts for bucket management and model registration.
+         """
+         result = []
+         bucket = kwargs.pop("bucket")
+         _, test_manage_obs_policy = self._execute(self._util.manage_bucket, bucket=bucket, **kwargs)
+         result.append(test_manage_obs_policy.to_dict())
+
+         if test_manage_obs_policy.status == PolicyStatus.SUCCESS:
+             self.model_id, test_model_register = self._execute(self._util.register_model)
+             result.append(test_model_register.to_dict())
+         return result
+
+     def _test_delete_model(self, **kwargs):
+         """Attempts to delete the test model created during model registration.
+
+         Returns:
+             List containing the result of model deletion.
+         """
+         if self.model_id is not None:
+             _, test_delete_model_test = self._execute(self._util.aqua_model.ds_client.delete_model,
+                                                       model_id=self.model_id, **kwargs)
+             return [test_delete_model_test.to_dict()]
+         else:
+             return [self._get_operation_result(self._util.aqua_model.ds_client.delete_model,
+                                                PolicyStatus.UNVERIFIED).to_dict()]
+
+     def _test_model_deployment(self, **kwargs):
+         """Verifies policies required to create and delete a model deployment.
+
+         Returns:
+             List of result dicts for deployment creation and deletion.
+         """
+         logger.info(f"Creating Model Deployment with name {TEST_MD_NAME}")
+         md_ocid, test_model_deployment = self._execute(self._util.create_model_deployment, model_id=self.model_id,
+                                                        instance_shape=TEST_VM_SHAPE)
+         _, test_delete_md = self._execute(self._util.aqua_model.ds_client.delete_model_deployment, model_deployment_id=md_ocid)
+         return [test_model_deployment.to_dict(), test_delete_md.to_dict()]
+
+     def _test_manage_mvs(self, **kwargs):
+         """Verifies policies required to create and delete a model version set (MVS).
+
+         Returns:
+             List of result dicts for MVS creation and deletion.
+         """
+         logger.info(f"Creating ModelVersionSet with name {TEST_MVS_NAME}")
+
+         model_mvs, test_create_mvs = self._execute(self._util.create_model_version_set, name=TEST_MVS_NAME)
+         model_mvs_id = model_mvs[0]
+         if model_mvs_id:
+             logger.info(f"Deleting ModelVersionSet {TEST_MVS_NAME}")
+             _, delete_mvs = self._execute(self._util.aqua_model.ds_client.delete_model_version_set,
+                                           model_version_set_id=model_mvs_id)
+         else:
+             delete_mvs = self._get_operation_result(self._util.aqua_model.ds_client.delete_model_version_set,
+                                                     PolicyStatus.UNVERIFIED)
+         return [test_create_mvs.to_dict(), delete_mvs.to_dict()]
+
+     def _test_manage_job(self, **kwargs):
+         """Verifies policies required to create a job, create a job run, and delete the job.
+
+         Returns:
+             List of result dicts for job creation, job run creation, and job deletion.
+         """
+
+         logger.info(f"Creating Job with name {TEST_JOB_NAME}")
+
+         # Create Job & JobRun.
+         job_id, test_create_job = self._execute(self._util.create_job, display_name=TEST_JOB_NAME,
+                                                 **kwargs)
+
+         logger.info(f"Creating JobRun with name {TEST_JOB_RUN_NAME}")
+
+         _, test_create_job_run = self._execute(self._util.create_job_run, display_name=TEST_JOB_RUN_NAME,
+                                                job_id=job_id, **kwargs)
+
+         # Delete Job Run
+         if job_id:
+             _, delete_job = self._execute(self._util.aqua_model.ds_client.delete_job, job_id=job_id, delete_related_job_runs=True)
+         else:
+             delete_job = self._get_operation_result(self._util.aqua_model.ds_client.delete_job, PolicyStatus.UNVERIFIED)
+
+         return [test_create_job.to_dict(), test_create_job_run.to_dict(), delete_job.to_dict()]
+
+     def _prompt(self, message, bool=False):
+         """Wrapper for Click prompt or confirmation.
+
+         Parameters:
+             message (str): The prompt message.
+             bool (bool): Whether to ask for confirmation instead of input.
+
+         Returns:
+             User input or confirmation (bool/str).
+         """
+         if bool:
+             return click.confirm(message, default=False)
+         else:
+             return click.prompt(message, type=str)
+
+     def _consent(self):
+         """
+         Prompts the user for confirmation before performing actions.
+         Exits if the user does not consent.
+         """
+         answer = self._prompt("Do you want to continue?", bool=True)
+         if not answer:
+             exit(0)
+
+     def common_policies(self, **kwargs):
+         """Verifies basic read-level policies across various AQUA components
+         (e.g. compartments, models, jobs, buckets, logs).
+
+         Returns:
+             List of result dicts for each verified operation.
+         """
+         logger.info("[magenta]Verifying Common Policies")
+         basic_operations = [self._util.list_compartments, self._util.list_models, self._util.list_model_version_sets,
+                             self._util.list_project, self._util.list_jobs, self._util.list_job_runs,
+                             self._util.list_buckets,
+                             self._util.list_log_groups
+                             ]
+         result = []
+         for op in basic_operations:
+             _, status = self._execute(op, **kwargs)
+             result.append(status.to_dict())
+
+         _, get_resource_availability_status = self._execute(self._util.get_resource_availability,
+                                                             limit_name=TEST_LIMIT_NAME)
+         result.append(get_resource_availability_status.to_dict())
+         return result
+
+     def model_register(self, **kwargs):
+         """Verifies policies required to register a model, including object storage access.
+
+         Returns:
+             List of result dicts for registration and cleanup.
+         """
+         logger.info("[magenta]Verifying Model Register")
+         logger.info("Object and Model will be created.")
+         kwargs.pop("consent", None) == True or self._consent()
+
+         model_save_bucket = kwargs.pop("bucket", None) or self._prompt(
+             "Provide bucket name where model artifacts will be saved")
+         register_model_result = self._test_model_register(bucket=model_save_bucket)
+         delete_model_result = self._test_delete_model(**kwargs)
+         return [*register_model_result, *delete_model_result]
+
+     def model_deployment(self, **kwargs):
+         """Verifies policies required to register and deploy a model, and perform cleanup.
+
+         Returns:
+             List of result dicts for registration, deployment, and cleanup.
+         """
+         logger.info("[magenta]Verifying Model Deployment")
+         logger.info("Object, Model, Model deployment will be created.")
+         kwargs.pop("consent", None) == True or self._consent()
+         model_save_bucket = kwargs.pop("bucket", None) or self._prompt(
+             "Provide bucket name where model artifacts will be saved")
+         model_register = self._test_model_register(bucket=model_save_bucket)
+         model_deployment = self._test_model_deployment()
+         delete_model_result = self._test_delete_model(**kwargs)
+
+         return [*model_register, *model_deployment, *delete_model_result]
+
+     def evaluation(self, **kwargs):
+         """Verifies policies for evaluation workloads including model version set,
+         job and job runs, and object storage access.
+
+         Returns:
+             List of result dicts for all evaluation steps.
+         """
+         logger.info("[magenta]Verifying Evaluation")
+         logger.info("Model Version Set, Model, Object, Job and JobRun will be created.")
+         kwargs.pop("consent", None) == True or self._consent()
+
+         # Create & Delete MVS
+         test_manage_mvs = self._test_manage_mvs(**kwargs)
+
+         # Create & Model
+         model_save_bucket = kwargs.pop("bucket", None) or self._prompt(
+             "Provide bucket name where model artifacts will be saved")
+         register_model_result = self._test_model_register(bucket=model_save_bucket)
+         delete_model_result = self._test_delete_model(**kwargs)
+
+         # Manage Jobs & Job Runs
+         test_job_and_job_run = self._test_manage_job(**kwargs)
+
+         return [*test_manage_mvs, *register_model_result, *delete_model_result, *test_job_and_job_run]
+
+     def finetune(self, **kwargs):
+         """Verifies policies for fine-tuning jobs, including managing object storage,
+         MVS.
+
+         Returns:
+             List of result dicts for each fine-tuning operation.
+         """
+         logger.info("[magenta]Verifying Finetuning")
+         logger.info("Object, Model Version Set, Job and JobRun will be created. VCN will be used.")
+         kwargs.pop("consent", None) == True or self._consent()
+
+         # Manage bucket
+         bucket = kwargs.pop("bucket", None) or self._prompt(
+             "Provide bucket name required to save training datasets, scripts, and fine-tuned model outputs")
+
+         subnet_id = kwargs.pop("subnet_id", None)
+         ignore_subnet = kwargs.pop("ignore_subnet", False)
+
+         if subnet_id is None and not ignore_subnet and self._prompt("Do you want to use custom subnet", bool=True):
+             subnet_id = self._prompt("Provide subnet id")
+
+         _, test_manage_obs_policy = self._execute(self._util.manage_bucket, bucket=bucket, **kwargs)
+
+         # Create & Delete MVS
+         test_manage_mvs = self._test_manage_mvs(**kwargs)
+
+         # Manage Jobs & Job Runs
+         test_job_and_job_run = self._test_manage_job(subnet_id = subnet_id, **kwargs)
+
+         return [*test_manage_mvs, *test_job_and_job_run, test_manage_obs_policy.to_dict()]
+
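The new AquaVerifyPoliciesApp above can also be driven programmatically. The following is a minimal sketch, assuming the class is importable from ads.aqua.verify_policies.verify (consistent with the file list), that OCI authentication is already configured for ADS, and that "my-bucket" is a placeholder bucket name:

    from ads.aqua.verify_policies.verify import AquaVerifyPoliciesApp

    app = AquaVerifyPoliciesApp()

    # Read-only checks: list compartments, models, model version sets, projects,
    # jobs, job runs, buckets, log groups, and a shape availability limit.
    results = app.common_policies()

    # Write-path checks: creates a test model and then deletes it. Passing
    # bucket=... and consent=True skips the interactive prompts.
    results += app.model_register(bucket="my-bucket", consent=True)

    # Each entry is a result dict produced by to_dict() on the operation result.
    for item in results:
        print(item)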
ads/aqua/version.json ADDED
@@ -0,0 +1,3 @@
+ {
+   "aqua": "1.0.7"
+ }
ads/common/oci_logging.py CHANGED
@@ -1,5 +1,4 @@
  #!/usr/bin/env python
- # -*- coding: utf-8; -*-
 
  # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -7,16 +6,16 @@
  import datetime
  import logging
  import time
- from typing import Dict, Union, List
+ from typing import Dict, List, Union
 
+ import oci.exceptions
  import oci.logging
  import oci.loggingsearch
- import oci.exceptions
+
  from ads.common.decorator.utils import class_or_instance_method
  from ads.common.oci_mixin import OCIModelMixin, OCIWorkRequestMixin
  from ads.common.oci_resource import OCIResource, ResourceNotFoundError
 
-
  logger = logging.getLogger(__name__)
 
  # Maximum number of log records to be returned by default.
@@ -862,9 +861,7 @@ class ConsolidatedLog:
              time_start=time_start,
              log_filter=log_filter,
          )
-         self._print(
-             sorted(tail_logs, key=lambda log: log["time"])
-         )
+         self._print(sorted(tail_logs, key=lambda log: log["time"]))
 
      def head(
          self,
ads/common/work_request.py CHANGED
@@ -1,7 +1,5 @@
  #!/usr/bin/env python
- # -*- coding: utf-8; -*-
-
- # Copyright (c) 2024 Oracle and/or its affiliates.
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
  import logging
@@ -12,6 +10,7 @@ from typing import Callable
  import oci
  from oci import Signer
  from tqdm.auto import tqdm
+
  from ads.common.oci_datascience import OCIDataScienceMixin
 
  logger = logging.getLogger(__name__)
@@ -20,10 +19,10 @@ WORK_REQUEST_STOP_STATE = ("SUCCEEDED", "FAILED", "CANCELED")
  DEFAULT_WAIT_TIME = 1200
  DEFAULT_POLL_INTERVAL = 10
  WORK_REQUEST_PERCENTAGE = 100
- # default tqdm progress bar format:
+ # default tqdm progress bar format:
  # {l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]
  # customize the bar format to remove the {n_fmt}/{total_fmt} from the right side
- DEFAULT_BAR_FORMAT = '{l_bar}{bar}| [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]'
+ DEFAULT_BAR_FORMAT = "{l_bar}{bar}| [{elapsed}<{remaining}, " "{rate_fmt}{postfix}]"
 
 
  class DataScienceWorkRequest(OCIDataScienceMixin):
@@ -32,13 +31,13 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
      """
 
      def __init__(
-         self,
-         id: str,
+         self,
+         id: str,
          description: str = "Processing",
-         config: dict = None,
-         signer: Signer = None,
-         client_kwargs: dict = None,
-         **kwargs
+         config: dict = None,
+         signer: Signer = None,
+         client_kwargs: dict = None,
+         **kwargs,
      ) -> None:
          """Initializes ADSWorkRequest object.
 
@@ -49,41 +48,43 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
          description: str
              Progress bar initial step description (Defaults to `Processing`).
          config : dict, optional
-             OCI API key config dictionary to initialize
+             OCI API key config dictionary to initialize
              oci.data_science.DataScienceClient (Defaults to None).
          signer : oci.signer.Signer, optional
-             OCI authentication signer to initialize
+             OCI authentication signer to initialize
              oci.data_science.DataScienceClient (Defaults to None).
          client_kwargs : dict, optional
-             Additional client keyword arguments to initialize
+             Additional client keyword arguments to initialize
              oci.data_science.DataScienceClient (Defaults to None).
          kwargs:
-             Additional keyword arguments to initialize
+             Additional keyword arguments to initialize
              oci.data_science.DataScienceClient.
          """
          self.id = id
          self._description = description
         self._percentage = 0
          self._status = None
+         self._error_message = ""
          super().__init__(config, signer, client_kwargs, **kwargs)
-
 
      def _sync(self):
          """Fetches the latest work request information to ADSWorkRequest object."""
          work_request = self.client.get_work_request(self.id).data
-         work_request_logs = self.client.list_work_request_logs(
-             self.id
-         ).data
+         work_request_logs = self.client.list_work_request_logs(self.id).data
 
-         self._percentage= work_request.percent_complete
+         self._percentage = work_request.percent_complete
          self._status = work_request.status
-         self._description = work_request_logs[-1].message if work_request_logs else "Processing"
+         self._description = (
+             work_request_logs[-1].message if work_request_logs else "Processing"
+         )
+         if work_request.status == "FAILED":
+             self._error_message = self.client.list_work_request_errors(self.id).data
 
      def watch(
-         self,
+         self,
          progress_callback: Callable,
-         max_wait_time: int=DEFAULT_WAIT_TIME,
-         poll_interval: int=DEFAULT_POLL_INTERVAL,
+         max_wait_time: int = DEFAULT_WAIT_TIME,
+         poll_interval: int = DEFAULT_POLL_INTERVAL,
      ):
          """Updates the progress bar with realtime message and percentage until the process is completed.
 
@@ -92,10 +93,10 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
          progress_callback: Callable
              Progress bar callback function.
              It must accept `(percent_change, description)` where `percent_change` is the
-             work request percent complete and `description` is the latest work request log message.
+             work request percent complete and `description` is the latest work request log message.
          max_wait_time: int
              Maximum amount of time to wait in seconds (Defaults to 1200).
-             Negative implies infinite wait time.
+             Negative implies infinite wait time.
          poll_interval: int
              Poll interval in seconds (Defaults to 10).
 
@@ -107,7 +108,6 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
 
          start_time = time.time()
          while self._percentage < 100:
-
              seconds_since = time.time() - start_time
              if max_wait_time > 0 and seconds_since >= max_wait_time:
                  logger.error(f"Exceeded max wait time of {max_wait_time} seconds.")
@@ -124,12 +124,14 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
              percent_change = self._percentage - previous_percent_complete
              previous_percent_complete = self._percentage
              progress_callback(
-                 percent_change=percent_change,
-                 description=self._description
+                 percent_change=percent_change, description=self._description
              )
 
              if self._status in WORK_REQUEST_STOP_STATE:
-                 if self._status != oci.work_requests.models.WorkRequest.STATUS_SUCCEEDED:
+                 if (
+                     self._status
+                     != oci.work_requests.models.WorkRequest.STATUS_SUCCEEDED
+                 ):
                      if self._description:
                          raise Exception(self._description)
                      else:
@@ -145,12 +147,12 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
 
      def wait_work_request(
          self,
-         progress_bar_description: str="Processing",
-         max_wait_time: int=DEFAULT_WAIT_TIME,
-         poll_interval: int=DEFAULT_POLL_INTERVAL
+         progress_bar_description: str = "Processing",
+         max_wait_time: int = DEFAULT_WAIT_TIME,
+         poll_interval: int = DEFAULT_POLL_INTERVAL,
      ):
          """Waits for the work request progress bar to be completed.
-
+
          Parameters
          ----------
          progress_bar_description: str
@@ -160,7 +162,7 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
              Negative implies infinite wait time.
          poll_interval: int
              Poll interval in seconds (Defaults to 10).
-
+
          Returns
          -------
          None
@@ -172,7 +174,7 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
              mininterval=0,
              file=sys.stdout,
              desc=progress_bar_description,
-             bar_format=DEFAULT_BAR_FORMAT
+             bar_format=DEFAULT_BAR_FORMAT,
          ) as pbar:
 
              def progress_callback(percent_change, description):
@@ -184,6 +186,5 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
              self.watch(
                  progress_callback=progress_callback,
                  max_wait_time=max_wait_time,
-                 poll_interval=poll_interval
+                 poll_interval=poll_interval,
              )
-
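For context on how the reformatted signatures above are typically used, here is a minimal sketch of waiting on a Data Science work request, assuming standard ADS/OCI authentication is already configured; the work request OCID is a placeholder:

    from ads.common.work_request import DataScienceWorkRequest

    # Placeholder OCID; a real value comes from the response of the
    # long-running operation that created the work request.
    wr = DataScienceWorkRequest(
        id="ocid1.datascienceworkrequest.oc1..example",
        description="Deploying model",
    )

    # Polls the work request under a tqdm progress bar until it reaches a stop
    # state; per watch(), a non-SUCCEEDED terminal state raises an exception.
    wr.wait_work_request(
        progress_bar_description="Deploying model",
        max_wait_time=1200,  # DEFAULT_WAIT_TIME
        poll_interval=10,    # DEFAULT_POLL_INTERVAL
    )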