oracle-ads 2.13.11__py3-none-any.whl → 2.13.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/app.py +73 -15
- ads/aqua/cli.py +17 -0
- ads/aqua/client/client.py +38 -21
- ads/aqua/client/openai_client.py +20 -10
- ads/aqua/common/entities.py +78 -12
- ads/aqua/common/utils.py +35 -0
- ads/aqua/constants.py +2 -0
- ads/aqua/evaluation/evaluation.py +5 -4
- ads/aqua/extension/common_handler.py +47 -2
- ads/aqua/extension/model_handler.py +51 -9
- ads/aqua/model/constants.py +1 -0
- ads/aqua/model/enums.py +19 -1
- ads/aqua/model/model.py +119 -51
- ads/aqua/model/utils.py +1 -2
- ads/aqua/modeldeployment/config_loader.py +815 -0
- ads/aqua/modeldeployment/constants.py +4 -1
- ads/aqua/modeldeployment/deployment.py +178 -129
- ads/aqua/modeldeployment/entities.py +150 -178
- ads/aqua/modeldeployment/model_group_config.py +233 -0
- ads/aqua/modeldeployment/utils.py +0 -539
- ads/aqua/verify_policies/__init__.py +8 -0
- ads/aqua/verify_policies/constants.py +13 -0
- ads/aqua/verify_policies/entities.py +29 -0
- ads/aqua/verify_policies/messages.py +101 -0
- ads/aqua/verify_policies/utils.py +432 -0
- ads/aqua/verify_policies/verify.py +345 -0
- ads/aqua/version.json +3 -0
- ads/common/oci_logging.py +4 -7
- ads/common/work_request.py +39 -38
- ads/jobs/builders/infrastructure/dsc_job.py +121 -24
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +71 -24
- ads/jobs/builders/runtimes/base.py +7 -5
- ads/jobs/builders/runtimes/pytorch_runtime.py +6 -8
- ads/jobs/templates/driver_pytorch.py +486 -172
- ads/jobs/templates/driver_utils.py +27 -11
- ads/model/deployment/model_deployment.py +51 -38
- ads/model/service/oci_datascience_model_deployment.py +6 -11
- ads/telemetry/client.py +4 -4
- {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/METADATA +2 -1
- {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/RECORD +43 -34
- {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/WHEEL +0 -0
- {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/entry_points.txt +0 -0
- {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,345 @@
|
|
1
|
+
import logging
|
2
|
+
|
3
|
+
import click
|
4
|
+
import oci.exceptions
|
5
|
+
|
6
|
+
from ads.aqua.verify_policies.constants import POLICY_HELP_LINK, TEST_JOB_NAME, TEST_JOB_RUN_NAME, TEST_LIMIT_NAME, \
|
7
|
+
TEST_MVS_NAME, TEST_MD_NAME, TEST_VM_SHAPE
|
8
|
+
from ads.aqua.verify_policies.messages import operation_messages
|
9
|
+
from ads.aqua.verify_policies.entities import OperationResultSuccess, OperationResultFailure, PolicyStatus
|
10
|
+
from ads.aqua.verify_policies.utils import PolicyValidationError, VerifyPoliciesUtils, RichStatusLog
|
11
|
+
from functools import wraps
|
12
|
+
|
13
|
+
# Dedicated logger for policy-verification output. AquaVerifyPoliciesApp.__init__
# turns off propagation on it and sets its level to INFO.
logger = logging.getLogger("aqua.policies")
|
14
|
+
|
15
|
+
|
16
|
+
def with_spinner(func):
    """Wrap a verification step so its outcome is reported, optionally under a spinner.

    The decorated method is invoked as ``func(self, function, **kwargs)`` and must
    return a ``(return_value, result_status)`` pair. A SUCCESS status is logged at
    INFO level. Any other status is logged as a WARNING, followed by the error, the
    policy hint and the policy help link. Passing ``ignore_spinner=True`` runs the
    step without the rich console status spinner; that flag is consumed here and is
    not forwarded to ``func``.
    """

    @wraps(func)
    def wrapper(self, function, **kwargs):
        # Looked up before anything runs, so an unknown operation fails immediately.
        message = operation_messages[function.__name__]
        use_spinner = not kwargs.pop("ignore_spinner", False)

        def report():
            value, outcome = func(self, function, **kwargs)
            summary = f"{self._rich_ui.get_status_emoji(outcome.status)} {outcome.operation}"
            if outcome.status == PolicyStatus.SUCCESS:
                logger.info(summary)
            else:
                logger.warning(summary)
                logger.info(outcome.error)
                logger.info(f"Policy hint: {outcome.policy_hint}")
                logger.info(f"Refer to: {POLICY_HELP_LINK}")
            return value, outcome

        if not use_spinner:
            return report()
        with self._rich_ui.console.status(f"Verifying {message['name']}"):
            return report()

    return wrapper
|
48
|
+
|
49
|
+
|
50
|
+
class AquaVerifyPoliciesApp:
    """
    AquaVerifyPoliciesApp provides methods to verify IAM policies required for
    various operations in OCI Data Science's AQUA (Accelerated Data Science) platform.

    This utility is intended to help users validate whether they have the necessary
    permissions to perform common AQUA workflows such as model registration,
    deployment, evaluation, and fine-tuning. Each check actually performs the
    operation, creating and then deleting temporary test resources where needed.
    It reports the check as SUCCESS or FAILURE. A check whose prerequisite step did
    not produce a resource is reported as UNVERIFIED.

    Methods
    -------
    `common_policies()`: Validates basic read-level policies across AQUA components.
    `model_register()`: Checks policies for object storage access and model registration.
    `model_deployment()`: Validates policies for registering and deploying models.
    `evaluation()`: Confirms ability to manage model version sets, jobs, and storage for evaluation.
    `finetune()`: Verifies access required to fine-tune models.
    """

    def __init__(self):
        super().__init__()
        self._util = VerifyPoliciesUtils()
        self._rich_ui = RichStatusLog()
        # OCID of the test model registered by `_test_model_register`. The deployment
        # check uses it, and `_test_delete_model` removes the model again.
        self.model_id = None
        # Keep verification output on the dedicated "aqua.policies" logger only.
        logger.propagate = False
        logger.setLevel(logging.INFO)

    def _get_operation_result(self, operation, status):
        """Maps an operation and a policy status to the corresponding result object.

        Parameters:
            operation (callable): The operation being verified. Its ``__name__`` is the
                key used to look up user-facing text in ``operation_messages``.
            status (PolicyStatus): The outcome of the policy verification.

        Returns:
            OperationResultSuccess for SUCCESS.
            OperationResultSuccess carrying the status for UNVERIFIED.
            OperationResultFailure (with error text and policy hint) for FAILURE.

        Raises:
            ValueError: If ``status`` is not SUCCESS, UNVERIFIED or FAILURE.
        """
        operation_message = operation_messages[operation.__name__]
        if status == PolicyStatus.SUCCESS:
            return OperationResultSuccess(operation=operation_message["name"])
        if status == PolicyStatus.UNVERIFIED:
            return OperationResultSuccess(operation=operation_message["name"], status=status)
        if status == PolicyStatus.FAILURE:
            return OperationResultFailure(
                operation=operation_message["name"],
                error=operation_message["error"],
                policy_hint=f"{operation_message['policy_hint']}",
            )
        # Fail loudly here. Returning None would only surface later as an opaque
        # AttributeError on `.status` / `.to_dict()`.
        raise ValueError(f"Unsupported policy status: {status!r}")

    @with_spinner
    def _execute(self, function, **kwargs):
        """Executes an operation and classifies the outcome as a policy result.

        Parameters:
            function (callable): The function to execute.
            kwargs (dict): Keyword arguments to pass to the function.

        Returns:
            Tuple: (result, OperationResult). ``result`` is the function's return value
            on success and ``None`` on a policy failure.

        Raises:
            oci.exceptions.ServiceError: For any service error other than HTTP 404.
            Exception: Any other unexpected error raised by ``function``.
        """
        result = None
        try:
            result = function(**kwargs)
            status = PolicyStatus.SUCCESS
        except oci.exceptions.ServiceError as oci_error:
            # A 404 is treated as a missing-policy failure (presumably OCI's
            # NotAuthorizedOrNotFound). Any other service error is unexpected.
            if oci_error.status != 404:
                logger.error(oci_error)
                raise
            logger.debug(oci_error)
            status = PolicyStatus.FAILURE
        except PolicyValidationError:
            status = PolicyStatus.FAILURE
        except Exception as e:
            logger.error(e)
            raise
        return result, self._get_operation_result(function, status)

    def _test_model_register(self, **kwargs):
        """Verifies policies required to manage an object storage bucket and register a model.

        Parameters:
            bucket (str): Required. Bucket used for the model artifacts.
            kwargs (dict): Extra keyword arguments forwarded to the bucket check.

        Returns:
            List of result dicts for bucket management and, if that succeeded,
            model registration.
        """
        result = []
        bucket = kwargs.pop("bucket")
        # Forget any model from an earlier run. Otherwise a failed bucket check could
        # leave a stale OCID for `_test_model_deployment` / `_test_delete_model`.
        self.model_id = None
        _, test_manage_obs_policy = self._execute(self._util.manage_bucket, bucket=bucket, **kwargs)
        result.append(test_manage_obs_policy.to_dict())

        if test_manage_obs_policy.status == PolicyStatus.SUCCESS:
            self.model_id, test_model_register = self._execute(self._util.register_model)
            result.append(test_model_register.to_dict())
        return result

    def _test_delete_model(self, **kwargs):
        """Attempts to delete the test model created during model registration.

        Returns:
            List containing the result of model deletion. The result is UNVERIFIED
            when no test model was registered.
        """
        delete_model = self._util.aqua_model.ds_client.delete_model
        if self.model_id is None:
            return [self._get_operation_result(delete_model, PolicyStatus.UNVERIFIED).to_dict()]
        _, test_delete_model_test = self._execute(delete_model, model_id=self.model_id, **kwargs)
        return [test_delete_model_test.to_dict()]

    def _test_model_deployment(self, **kwargs):
        """Verifies policies required to create and delete a model deployment.

        Both steps are reported as UNVERIFIED when no test model is available. The
        deletion step is reported as UNVERIFIED when the deployment was not created.

        Returns:
            List of result dicts for deployment creation and deletion.
        """
        create_md = self._util.create_model_deployment
        delete_md = self._util.aqua_model.ds_client.delete_model_deployment

        if self.model_id is None:
            # Nothing to deploy: model registration did not produce a model.
            return [
                self._get_operation_result(create_md, PolicyStatus.UNVERIFIED).to_dict(),
                self._get_operation_result(delete_md, PolicyStatus.UNVERIFIED).to_dict(),
            ]

        logger.info(f"Creating Model Deployment with name {TEST_MD_NAME}")
        md_ocid, test_model_deployment = self._execute(
            create_md, model_id=self.model_id, instance_shape=TEST_VM_SHAPE
        )
        if md_ocid:
            _, test_delete_md = self._execute(delete_md, model_deployment_id=md_ocid)
        else:
            # Creation failed, so there is no deployment to delete.
            test_delete_md = self._get_operation_result(delete_md, PolicyStatus.UNVERIFIED)
        return [test_model_deployment.to_dict(), test_delete_md.to_dict()]

    def _test_manage_mvs(self, **kwargs):
        """Verifies policies required to create and delete a model version set (MVS).

        Returns:
            List of result dicts for MVS creation and deletion. Deletion is UNVERIFIED
            when the MVS was not created.
        """
        logger.info(f"Creating ModelVersionSet with name {TEST_MVS_NAME}")

        model_mvs, test_create_mvs = self._execute(self._util.create_model_version_set, name=TEST_MVS_NAME)
        # On a policy failure `_execute` returns None as the result, so guard before indexing.
        model_mvs_id = model_mvs[0] if model_mvs else None
        if model_mvs_id:
            logger.info(f"Deleting ModelVersionSet {TEST_MVS_NAME}")
            _, delete_mvs = self._execute(
                self._util.aqua_model.ds_client.delete_model_version_set,
                model_version_set_id=model_mvs_id,
            )
        else:
            delete_mvs = self._get_operation_result(
                self._util.aqua_model.ds_client.delete_model_version_set,
                PolicyStatus.UNVERIFIED,
            )
        return [test_create_mvs.to_dict(), delete_mvs.to_dict()]

    def _test_manage_job(self, **kwargs):
        """Verifies policies required to create a job, create a job run, and delete the job.

        Job run creation and job deletion are reported as UNVERIFIED when the job
        could not be created.

        Parameters:
            kwargs (dict): Forwarded to both job and job-run creation (e.g. ``subnet_id``).

        Returns:
            List of result dicts for job creation, job run creation, and job deletion.
        """
        create_job_run = self._util.create_job_run
        delete_job = self._util.aqua_model.ds_client.delete_job

        logger.info(f"Creating Job with name {TEST_JOB_NAME}")
        job_id, test_create_job = self._execute(self._util.create_job, display_name=TEST_JOB_NAME, **kwargs)

        if job_id:
            logger.info(f"Creating JobRun with name {TEST_JOB_RUN_NAME}")
            _, test_create_job_run = self._execute(
                create_job_run, display_name=TEST_JOB_RUN_NAME, job_id=job_id, **kwargs
            )
            # delete_related_job_runs=True also cleans up the job run created above.
            _, test_delete_job = self._execute(delete_job, job_id=job_id, delete_related_job_runs=True)
        else:
            # Without a job there is nothing to run or delete.
            test_create_job_run = self._get_operation_result(create_job_run, PolicyStatus.UNVERIFIED)
            test_delete_job = self._get_operation_result(delete_job, PolicyStatus.UNVERIFIED)

        return [test_create_job.to_dict(), test_create_job_run.to_dict(), test_delete_job.to_dict()]

    def _prompt(self, message, bool=False):
        """Wrapper for Click prompt or confirmation.

        Parameters:
            message (str): The prompt message.
            bool (bool): Whether to ask for a yes/no confirmation (default answer "no")
                instead of free-text input.

        Returns:
            User confirmation (bool) or input (str).
        """
        if bool:
            return click.confirm(message, default=False)
        return click.prompt(message, type=str)

    def _consent(self):
        """
        Prompts the user for confirmation before performing actions.
        Exits the process with status 0 if the user does not consent.
        """
        import sys  # `exit` is a `site` convenience builtin and is not guaranteed to exist.

        if not self._prompt("Do you want to continue?", bool=True):
            sys.exit(0)

    def _ensure_consent(self, kwargs):
        """Pops ``consent`` from ``kwargs`` and asks for confirmation unless it equals ``True``.

        Parameters:
            kwargs (dict): Keyword arguments of a public check. Modified in place.
        """
        # Compare with `!= True` rather than truthiness to keep the original acceptance rule.
        if kwargs.pop("consent", None) != True:  # noqa: E712
            self._consent()

    def common_policies(self, **kwargs):
        """Verifies basic read-level policies across various AQUA components
        (e.g. compartments, models, jobs, buckets, logs) and resource availability.

        Returns:
            List of result dicts for each verified operation.
        """
        logger.info("[magenta]Verifying Common Policies")
        basic_operations = [
            self._util.list_compartments,
            self._util.list_models,
            self._util.list_model_version_sets,
            self._util.list_project,
            self._util.list_jobs,
            self._util.list_job_runs,
            self._util.list_buckets,
            self._util.list_log_groups,
        ]
        result = []
        for op in basic_operations:
            _, status = self._execute(op, **kwargs)
            result.append(status.to_dict())

        _, get_resource_availability_status = self._execute(
            self._util.get_resource_availability, limit_name=TEST_LIMIT_NAME
        )
        result.append(get_resource_availability_status.to_dict())
        return result

    def model_register(self, **kwargs):
        """Verifies policies required to register a model, including object storage access.

        Parameters:
            consent (bool, optional): Skip the confirmation prompt when ``True``.
            bucket (str, optional): Bucket for model artifacts. Prompted for if missing.

        Returns:
            List of result dicts for registration and cleanup.
        """
        logger.info("[magenta]Verifying Model Register")
        logger.info("Object and Model will be created.")
        self._ensure_consent(kwargs)

        model_save_bucket = kwargs.pop("bucket", None) or self._prompt(
            "Provide bucket name where model artifacts will be saved")
        register_model_result = self._test_model_register(bucket=model_save_bucket)
        delete_model_result = self._test_delete_model(**kwargs)
        return [*register_model_result, *delete_model_result]

    def model_deployment(self, **kwargs):
        """Verifies policies required to register and deploy a model, and perform cleanup.

        Parameters:
            consent (bool, optional): Skip the confirmation prompt when ``True``.
            bucket (str, optional): Bucket for model artifacts. Prompted for if missing.

        Returns:
            List of result dicts for registration, deployment, and cleanup.
        """
        logger.info("[magenta]Verifying Model Deployment")
        logger.info("Object, Model, Model deployment will be created.")
        self._ensure_consent(kwargs)
        model_save_bucket = kwargs.pop("bucket", None) or self._prompt(
            "Provide bucket name where model artifacts will be saved")
        model_register = self._test_model_register(bucket=model_save_bucket)
        model_deployment = self._test_model_deployment()
        delete_model_result = self._test_delete_model(**kwargs)

        return [*model_register, *model_deployment, *delete_model_result]

    def evaluation(self, **kwargs):
        """Verifies policies for evaluation workloads, including model version sets,
        model registration, jobs and job runs, and object storage access.

        Parameters:
            consent (bool, optional): Skip the confirmation prompt when ``True``.
            bucket (str, optional): Bucket for model artifacts. Prompted for if missing.

        Returns:
            List of result dicts for all evaluation steps.
        """
        logger.info("[magenta]Verifying Evaluation")
        logger.info("Model Version Set, Model, Object, Job and JobRun will be created.")
        self._ensure_consent(kwargs)

        # Create & delete a model version set.
        test_manage_mvs = self._test_manage_mvs(**kwargs)

        # Register & delete a test model.
        model_save_bucket = kwargs.pop("bucket", None) or self._prompt(
            "Provide bucket name where model artifacts will be saved")
        register_model_result = self._test_model_register(bucket=model_save_bucket)
        delete_model_result = self._test_delete_model(**kwargs)

        # Create a job and job run, then delete them.
        test_job_and_job_run = self._test_manage_job(**kwargs)

        return [*test_manage_mvs, *register_model_result, *delete_model_result, *test_job_and_job_run]

    def finetune(self, **kwargs):
        """Verifies policies for fine-tuning, including object storage access,
        model version sets, and jobs/job runs (optionally in a custom subnet).

        Parameters:
            consent (bool, optional): Skip the confirmation prompt when ``True``.
            bucket (str, optional): Bucket for datasets, scripts and outputs. Prompted
                for if missing.
            subnet_id (str, optional): Subnet for the test job.
            ignore_subnet (bool, optional): Do not ask for a custom subnet.

        Returns:
            List of result dicts for each fine-tuning operation.
        """
        logger.info("[magenta]Verifying Finetuning")
        logger.info("Object, Model Version Set, Job and JobRun will be created. VCN will be used.")
        self._ensure_consent(kwargs)

        # Manage bucket
        bucket = kwargs.pop("bucket", None) or self._prompt(
            "Provide bucket name required to save training datasets, scripts, and fine-tuned model outputs")

        subnet_id = kwargs.pop("subnet_id", None)
        ignore_subnet = kwargs.pop("ignore_subnet", False)

        if subnet_id is None and not ignore_subnet and self._prompt("Do you want to use custom subnet", bool=True):
            subnet_id = self._prompt("Provide subnet id")

        _, test_manage_obs_policy = self._execute(self._util.manage_bucket, bucket=bucket, **kwargs)

        # Create & delete a model version set.
        test_manage_mvs = self._test_manage_mvs(**kwargs)

        # Create a job and job run, then delete them.
        test_job_and_job_run = self._test_manage_job(subnet_id=subnet_id, **kwargs)

        return [*test_manage_mvs, *test_job_and_job_run, test_manage_obs_policy.to_dict()]
|
345
|
+
|
ads/aqua/version.json
ADDED
ads/common/oci_logging.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8; -*-
|
3
2
|
|
4
3
|
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
@@ -7,16 +6,16 @@
|
|
7
6
|
import datetime
|
8
7
|
import logging
|
9
8
|
import time
|
10
|
-
from typing import Dict,
|
9
|
+
from typing import Dict, List, Union
|
11
10
|
|
11
|
+
import oci.exceptions
|
12
12
|
import oci.logging
|
13
13
|
import oci.loggingsearch
|
14
|
-
|
14
|
+
|
15
15
|
from ads.common.decorator.utils import class_or_instance_method
|
16
16
|
from ads.common.oci_mixin import OCIModelMixin, OCIWorkRequestMixin
|
17
17
|
from ads.common.oci_resource import OCIResource, ResourceNotFoundError
|
18
18
|
|
19
|
-
|
20
19
|
logger = logging.getLogger(__name__)
|
21
20
|
|
22
21
|
# Maximum number of log records to be returned by default.
|
@@ -862,9 +861,7 @@ class ConsolidatedLog:
|
|
862
861
|
time_start=time_start,
|
863
862
|
log_filter=log_filter,
|
864
863
|
)
|
865
|
-
self._print(
|
866
|
-
sorted(tail_logs, key=lambda log: log["time"])
|
867
|
-
)
|
864
|
+
self._print(sorted(tail_logs, key=lambda log: log["time"]))
|
868
865
|
|
869
866
|
def head(
|
870
867
|
self,
|
ads/common/work_request.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
#
|
3
|
-
|
4
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
5
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
4
|
|
7
5
|
import logging
|
@@ -12,6 +10,7 @@ from typing import Callable
|
|
12
10
|
import oci
|
13
11
|
from oci import Signer
|
14
12
|
from tqdm.auto import tqdm
|
13
|
+
|
15
14
|
from ads.common.oci_datascience import OCIDataScienceMixin
|
16
15
|
|
17
16
|
logger = logging.getLogger(__name__)
|
@@ -20,10 +19,10 @@ WORK_REQUEST_STOP_STATE = ("SUCCEEDED", "FAILED", "CANCELED")
|
|
20
19
|
DEFAULT_WAIT_TIME = 1200
|
21
20
|
DEFAULT_POLL_INTERVAL = 10
|
22
21
|
WORK_REQUEST_PERCENTAGE = 100
|
23
|
-
# default tqdm progress bar format:
|
22
|
+
# default tqdm progress bar format:
|
24
23
|
# {l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]
|
25
24
|
# customize the bar format to remove the {n_fmt}/{total_fmt} from the right side
|
26
|
-
DEFAULT_BAR_FORMAT =
|
25
|
+
DEFAULT_BAR_FORMAT = "{l_bar}{bar}| [{elapsed}<{remaining}, " "{rate_fmt}{postfix}]"
|
27
26
|
|
28
27
|
|
29
28
|
class DataScienceWorkRequest(OCIDataScienceMixin):
|
@@ -32,13 +31,13 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
|
|
32
31
|
"""
|
33
32
|
|
34
33
|
def __init__(
|
35
|
-
self,
|
36
|
-
id: str,
|
34
|
+
self,
|
35
|
+
id: str,
|
37
36
|
description: str = "Processing",
|
38
|
-
config: dict = None,
|
39
|
-
signer: Signer = None,
|
40
|
-
client_kwargs: dict = None,
|
41
|
-
**kwargs
|
37
|
+
config: dict = None,
|
38
|
+
signer: Signer = None,
|
39
|
+
client_kwargs: dict = None,
|
40
|
+
**kwargs,
|
42
41
|
) -> None:
|
43
42
|
"""Initializes ADSWorkRequest object.
|
44
43
|
|
@@ -49,41 +48,43 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
|
|
49
48
|
description: str
|
50
49
|
Progress bar initial step description (Defaults to `Processing`).
|
51
50
|
config : dict, optional
|
52
|
-
OCI API key config dictionary to initialize
|
51
|
+
OCI API key config dictionary to initialize
|
53
52
|
oci.data_science.DataScienceClient (Defaults to None).
|
54
53
|
signer : oci.signer.Signer, optional
|
55
|
-
OCI authentication signer to initialize
|
54
|
+
OCI authentication signer to initialize
|
56
55
|
oci.data_science.DataScienceClient (Defaults to None).
|
57
56
|
client_kwargs : dict, optional
|
58
|
-
Additional client keyword arguments to initialize
|
57
|
+
Additional client keyword arguments to initialize
|
59
58
|
oci.data_science.DataScienceClient (Defaults to None).
|
60
59
|
kwargs:
|
61
|
-
Additional keyword arguments to initialize
|
60
|
+
Additional keyword arguments to initialize
|
62
61
|
oci.data_science.DataScienceClient.
|
63
62
|
"""
|
64
63
|
self.id = id
|
65
64
|
self._description = description
|
66
65
|
self._percentage = 0
|
67
66
|
self._status = None
|
67
|
+
self._error_message = ""
|
68
68
|
super().__init__(config, signer, client_kwargs, **kwargs)
|
69
|
-
|
70
69
|
|
71
70
|
def _sync(self):
|
72
71
|
"""Fetches the latest work request information to ADSWorkRequest object."""
|
73
72
|
work_request = self.client.get_work_request(self.id).data
|
74
|
-
work_request_logs = self.client.list_work_request_logs(
|
75
|
-
self.id
|
76
|
-
).data
|
73
|
+
work_request_logs = self.client.list_work_request_logs(self.id).data
|
77
74
|
|
78
|
-
self._percentage= work_request.percent_complete
|
75
|
+
self._percentage = work_request.percent_complete
|
79
76
|
self._status = work_request.status
|
80
|
-
self._description =
|
77
|
+
self._description = (
|
78
|
+
work_request_logs[-1].message if work_request_logs else "Processing"
|
79
|
+
)
|
80
|
+
if work_request.status == "FAILED":
|
81
|
+
self._error_message = self.client.list_work_request_errors(self.id).data
|
81
82
|
|
82
83
|
def watch(
|
83
|
-
self,
|
84
|
+
self,
|
84
85
|
progress_callback: Callable,
|
85
|
-
max_wait_time: int=DEFAULT_WAIT_TIME,
|
86
|
-
poll_interval: int=DEFAULT_POLL_INTERVAL,
|
86
|
+
max_wait_time: int = DEFAULT_WAIT_TIME,
|
87
|
+
poll_interval: int = DEFAULT_POLL_INTERVAL,
|
87
88
|
):
|
88
89
|
"""Updates the progress bar with realtime message and percentage until the process is completed.
|
89
90
|
|
@@ -92,10 +93,10 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
|
|
92
93
|
progress_callback: Callable
|
93
94
|
Progress bar callback function.
|
94
95
|
It must accept `(percent_change, description)` where `percent_change` is the
|
95
|
-
work request percent complete and `description` is the latest work request log message.
|
96
|
+
work request percent complete and `description` is the latest work request log message.
|
96
97
|
max_wait_time: int
|
97
98
|
Maximum amount of time to wait in seconds (Defaults to 1200).
|
98
|
-
Negative implies infinite wait time.
|
99
|
+
Negative implies infinite wait time.
|
99
100
|
poll_interval: int
|
100
101
|
Poll interval in seconds (Defaults to 10).
|
101
102
|
|
@@ -107,7 +108,6 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
|
|
107
108
|
|
108
109
|
start_time = time.time()
|
109
110
|
while self._percentage < 100:
|
110
|
-
|
111
111
|
seconds_since = time.time() - start_time
|
112
112
|
if max_wait_time > 0 and seconds_since >= max_wait_time:
|
113
113
|
logger.error(f"Exceeded max wait time of {max_wait_time} seconds.")
|
@@ -124,12 +124,14 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
|
|
124
124
|
percent_change = self._percentage - previous_percent_complete
|
125
125
|
previous_percent_complete = self._percentage
|
126
126
|
progress_callback(
|
127
|
-
percent_change=percent_change,
|
128
|
-
description=self._description
|
127
|
+
percent_change=percent_change, description=self._description
|
129
128
|
)
|
130
129
|
|
131
130
|
if self._status in WORK_REQUEST_STOP_STATE:
|
132
|
-
if
|
131
|
+
if (
|
132
|
+
self._status
|
133
|
+
!= oci.work_requests.models.WorkRequest.STATUS_SUCCEEDED
|
134
|
+
):
|
133
135
|
if self._description:
|
134
136
|
raise Exception(self._description)
|
135
137
|
else:
|
@@ -145,12 +147,12 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
|
|
145
147
|
|
146
148
|
def wait_work_request(
|
147
149
|
self,
|
148
|
-
progress_bar_description: str="Processing",
|
149
|
-
max_wait_time: int=DEFAULT_WAIT_TIME,
|
150
|
-
poll_interval: int=DEFAULT_POLL_INTERVAL
|
150
|
+
progress_bar_description: str = "Processing",
|
151
|
+
max_wait_time: int = DEFAULT_WAIT_TIME,
|
152
|
+
poll_interval: int = DEFAULT_POLL_INTERVAL,
|
151
153
|
):
|
152
154
|
"""Waits for the work request progress bar to be completed.
|
153
|
-
|
155
|
+
|
154
156
|
Parameters
|
155
157
|
----------
|
156
158
|
progress_bar_description: str
|
@@ -160,7 +162,7 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
|
|
160
162
|
Negative implies infinite wait time.
|
161
163
|
poll_interval: int
|
162
164
|
Poll interval in seconds (Defaults to 10).
|
163
|
-
|
165
|
+
|
164
166
|
Returns
|
165
167
|
-------
|
166
168
|
None
|
@@ -172,7 +174,7 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
|
|
172
174
|
mininterval=0,
|
173
175
|
file=sys.stdout,
|
174
176
|
desc=progress_bar_description,
|
175
|
-
bar_format=DEFAULT_BAR_FORMAT
|
177
|
+
bar_format=DEFAULT_BAR_FORMAT,
|
176
178
|
) as pbar:
|
177
179
|
|
178
180
|
def progress_callback(percent_change, description):
|
@@ -184,6 +186,5 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
|
|
184
186
|
self.watch(
|
185
187
|
progress_callback=progress_callback,
|
186
188
|
max_wait_time=max_wait_time,
|
187
|
-
poll_interval=poll_interval
|
189
|
+
poll_interval=poll_interval,
|
188
190
|
)
|
189
|
-
|