oracle-ads 2.10.0__py3-none-any.whl → 2.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +12 -0
- ads/aqua/base.py +324 -0
- ads/aqua/cli.py +19 -0
- ads/aqua/config/deployment_config_defaults.json +9 -0
- ads/aqua/config/resource_limit_names.json +7 -0
- ads/aqua/constants.py +45 -0
- ads/aqua/data.py +40 -0
- ads/aqua/decorator.py +101 -0
- ads/aqua/deployment.py +643 -0
- ads/aqua/dummy_data/icon.txt +1 -0
- ads/aqua/dummy_data/oci_model_deployments.json +56 -0
- ads/aqua/dummy_data/oci_models.json +1 -0
- ads/aqua/dummy_data/readme.md +26 -0
- ads/aqua/evaluation.py +1751 -0
- ads/aqua/exception.py +82 -0
- ads/aqua/extension/__init__.py +40 -0
- ads/aqua/extension/base_handler.py +138 -0
- ads/aqua/extension/common_handler.py +21 -0
- ads/aqua/extension/deployment_handler.py +202 -0
- ads/aqua/extension/evaluation_handler.py +135 -0
- ads/aqua/extension/finetune_handler.py +66 -0
- ads/aqua/extension/model_handler.py +59 -0
- ads/aqua/extension/ui_handler.py +201 -0
- ads/aqua/extension/utils.py +23 -0
- ads/aqua/finetune.py +579 -0
- ads/aqua/job.py +29 -0
- ads/aqua/model.py +819 -0
- ads/aqua/training/__init__.py +4 -0
- ads/aqua/training/exceptions.py +459 -0
- ads/aqua/ui.py +453 -0
- ads/aqua/utils.py +715 -0
- ads/cli.py +37 -6
- ads/common/auth.py +7 -0
- ads/common/decorator/__init__.py +7 -3
- ads/common/decorator/require_nonempty_arg.py +65 -0
- ads/common/object_storage_details.py +166 -7
- ads/common/oci_client.py +18 -1
- ads/common/oci_logging.py +2 -2
- ads/common/oci_mixin.py +4 -5
- ads/common/serializer.py +34 -5
- ads/common/utils.py +75 -10
- ads/config.py +40 -1
- ads/dataset/correlation_plot.py +10 -12
- ads/jobs/ads_job.py +43 -25
- ads/jobs/builders/infrastructure/base.py +4 -2
- ads/jobs/builders/infrastructure/dsc_job.py +49 -39
- ads/jobs/builders/runtimes/base.py +71 -1
- ads/jobs/builders/runtimes/container_runtime.py +4 -4
- ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
- ads/jobs/templates/driver_pytorch.py +27 -10
- ads/model/artifact_downloader.py +84 -14
- ads/model/artifact_uploader.py +25 -23
- ads/model/datascience_model.py +388 -38
- ads/model/deployment/model_deployment.py +10 -2
- ads/model/generic_model.py +8 -0
- ads/model/model_file_description_schema.json +68 -0
- ads/model/model_metadata.py +1 -1
- ads/model/service/oci_datascience_model.py +34 -5
- ads/opctl/config/merger.py +2 -2
- ads/opctl/operator/__init__.py +3 -1
- ads/opctl/operator/cli.py +7 -1
- ads/opctl/operator/cmd.py +3 -3
- ads/opctl/operator/common/errors.py +2 -1
- ads/opctl/operator/common/operator_config.py +22 -3
- ads/opctl/operator/common/utils.py +16 -0
- ads/opctl/operator/lowcode/anomaly/MLoperator +15 -0
- ads/opctl/operator/lowcode/anomaly/README.md +209 -0
- ads/opctl/operator/lowcode/anomaly/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/__main__.py +104 -0
- ads/opctl/operator/lowcode/anomaly/cmd.py +35 -0
- ads/opctl/operator/lowcode/anomaly/const.py +88 -0
- ads/opctl/operator/lowcode/anomaly/environment.yaml +12 -0
- ads/opctl/operator/lowcode/anomaly/model/__init__.py +5 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +147 -0
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +89 -0
- ads/opctl/operator/lowcode/anomaly/model/autots.py +103 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +354 -0
- ads/opctl/operator/lowcode/anomaly/model/factory.py +67 -0
- ads/opctl/operator/lowcode/anomaly/model/tods.py +119 -0
- ads/opctl/operator/lowcode/anomaly/operator_config.py +105 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +359 -0
- ads/opctl/operator/lowcode/anomaly/utils.py +81 -0
- ads/opctl/operator/lowcode/common/__init__.py +5 -0
- ads/opctl/operator/lowcode/common/const.py +10 -0
- ads/opctl/operator/lowcode/common/data.py +96 -0
- ads/opctl/operator/lowcode/common/errors.py +41 -0
- ads/opctl/operator/lowcode/common/transformations.py +191 -0
- ads/opctl/operator/lowcode/common/utils.py +250 -0
- ads/opctl/operator/lowcode/forecast/README.md +3 -2
- ads/opctl/operator/lowcode/forecast/__main__.py +18 -2
- ads/opctl/operator/lowcode/forecast/cmd.py +8 -7
- ads/opctl/operator/lowcode/forecast/const.py +17 -1
- ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
- ads/opctl/operator/lowcode/forecast/model/arima.py +106 -117
- ads/opctl/operator/lowcode/forecast/model/automlx.py +204 -180
- ads/opctl/operator/lowcode/forecast/model/autots.py +144 -253
- ads/opctl/operator/lowcode/forecast/model/base_model.py +326 -259
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +325 -176
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +293 -237
- ads/opctl/operator/lowcode/forecast/model/prophet.py +191 -208
- ads/opctl/operator/lowcode/forecast/operator_config.py +24 -33
- ads/opctl/operator/lowcode/forecast/schema.yaml +116 -29
- ads/opctl/operator/lowcode/forecast/utils.py +186 -356
- ads/opctl/operator/lowcode/pii/model/guardrails.py +18 -15
- ads/opctl/operator/lowcode/pii/model/report.py +7 -7
- ads/opctl/operator/lowcode/pii/operator_config.py +1 -8
- ads/opctl/operator/lowcode/pii/utils.py +0 -82
- ads/opctl/operator/runtime/runtime.py +3 -2
- ads/telemetry/base.py +62 -0
- ads/telemetry/client.py +105 -0
- ads/telemetry/telemetry.py +6 -3
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/METADATA +44 -7
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/RECORD +116 -59
- ads/opctl/operator/lowcode/forecast/model/transformations.py +0 -125
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/WHEEL +0 -0
- {oracle_ads-2.10.0.dist-info → oracle_ads-2.11.0.dist-info}/entry_points.txt +0 -0
@@ -1,16 +1,20 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8; -*-
|
3
|
+
|
4
|
+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
|
5
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
|
+
|
4
7
|
from ads.jobs.builders.runtimes.artifact import PythonArtifact, GitPythonArtifact
|
8
|
+
from ads.jobs.builders.runtimes.base import MultiNodeRuntime
|
5
9
|
from ads.jobs.builders.runtimes.python_runtime import (
|
6
10
|
PythonRuntime,
|
7
11
|
GitPythonRuntime,
|
8
12
|
)
|
9
13
|
|
10
14
|
|
11
|
-
class PyTorchDistributedRuntime(PythonRuntime):
|
15
|
+
class PyTorchDistributedRuntime(PythonRuntime, MultiNodeRuntime):
|
16
|
+
"""Represents runtime supporting PyTorch Distributed training."""
|
12
17
|
CONST_GIT = "git"
|
13
|
-
CONST_REPLICA = "replicas"
|
14
18
|
CONST_INPUT = "inputs"
|
15
19
|
CONST_DEP = "dependencies"
|
16
20
|
CONST_PIP_REQ = "pipRequirements"
|
@@ -79,26 +83,6 @@ class PyTorchDistributedRuntime(PythonRuntime):
|
|
79
83
|
"""The input files to be copied into the job run."""
|
80
84
|
return self.get_spec(self.CONST_INPUT)
|
81
85
|
|
82
|
-
def with_replica(self, count: int):
|
83
|
-
"""Specifies the number of nodes (job runs) for the job.
|
84
|
-
|
85
|
-
Parameters
|
86
|
-
----------
|
87
|
-
count : int
|
88
|
-
Number of nodes (job runs)
|
89
|
-
|
90
|
-
Returns
|
91
|
-
-------
|
92
|
-
self
|
93
|
-
The runtime instance.
|
94
|
-
"""
|
95
|
-
return self.set_spec(self.CONST_REPLICA, count)
|
96
|
-
|
97
|
-
@property
|
98
|
-
def replica(self) -> int:
|
99
|
-
"""The number of nodes (job runs)."""
|
100
|
-
return self.get_spec(self.CONST_REPLICA)
|
101
|
-
|
102
86
|
def with_dependency(self, pip_req=None, pip_pkg=None):
|
103
87
|
"""Specifies additional dependencies to be installed using pip.
|
104
88
|
|
@@ -185,7 +169,7 @@ class PyTorchDistributedRuntime(PythonRuntime):
|
|
185
169
|
def command(self):
|
186
170
|
"""The command for launching the workload."""
|
187
171
|
return self.get_spec(self.CONST_COMMAND)
|
188
|
-
|
172
|
+
|
189
173
|
@property
|
190
174
|
def use_deepspeed(self):
|
191
175
|
"""Indicate whether whether to configure deepspeed for multi-node workload"""
|
@@ -193,43 +177,6 @@ class PyTorchDistributedRuntime(PythonRuntime):
|
|
193
177
|
return True
|
194
178
|
return False
|
195
179
|
|
196
|
-
def run(self, dsc_job, **kwargs):
|
197
|
-
"""Starts the job runs"""
|
198
|
-
replicas = self.replica if self.replica else 1
|
199
|
-
main_run = None
|
200
|
-
job_runs = []
|
201
|
-
try:
|
202
|
-
for i in range(replicas):
|
203
|
-
replica_kwargs = kwargs.copy()
|
204
|
-
envs = replica_kwargs.get("environment_variables")
|
205
|
-
if not envs:
|
206
|
-
envs = {}
|
207
|
-
# Huggingface accelerate requires machine rank
|
208
|
-
# Here we use NODE_RANK to store the machine rank
|
209
|
-
envs["NODE_RANK"] = str(i)
|
210
|
-
envs["WORLD_SIZE"] = str(replicas)
|
211
|
-
if main_run:
|
212
|
-
envs["MAIN_JOB_RUN_OCID"] = main_run.id
|
213
|
-
name = replica_kwargs.get("display_name")
|
214
|
-
if not name:
|
215
|
-
name = dsc_job.display_name
|
216
|
-
|
217
|
-
replica_kwargs["display_name"] = f"{name}-{str(i)}"
|
218
|
-
replica_kwargs["environment_variables"] = envs
|
219
|
-
run = dsc_job.run(**replica_kwargs)
|
220
|
-
job_runs.append(run)
|
221
|
-
if i == 0:
|
222
|
-
main_run = run
|
223
|
-
except Exception:
|
224
|
-
traceback.print_exc()
|
225
|
-
# Wait a few second to avoid the job run being in a transient state.
|
226
|
-
time.sleep(2)
|
227
|
-
# If there is any error when creating the job runs
|
228
|
-
# cancel all the job runs.
|
229
|
-
for run in job_runs:
|
230
|
-
run.cancel()
|
231
|
-
return main_run
|
232
|
-
|
233
180
|
|
234
181
|
class PyTorchDistributedArtifact(PythonArtifact):
|
235
182
|
CONST_DRIVER_SCRIPT = "driver_pytorch.py"
|
@@ -5,6 +5,7 @@
|
|
5
5
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
6
|
"""This module requires oracle-ads>=2.6.8
|
7
7
|
"""
|
8
|
+
import getpass
|
8
9
|
import ipaddress
|
9
10
|
import logging
|
10
11
|
import multiprocessing
|
@@ -40,17 +41,32 @@ logger = logging.getLogger(__name__)
|
|
40
41
|
logger = driver_utils.set_log_level(logger)
|
41
42
|
|
42
43
|
|
44
|
+
# Envs provisioned by the service
|
43
45
|
CONST_ENV_HOST_JOB_RUN_OCID = "MAIN_JOB_RUN_OCID"
|
44
46
|
CONST_ENV_JOB_RUN_OCID = "JOB_RUN_OCID"
|
45
|
-
|
47
|
+
# Envs set by the ADS API
|
48
|
+
OCI__WORKER_COUNT = "OCI__WORKER_COUNT"
|
49
|
+
CONST_ENV_NODE_RANK = "NODE_RANK"
|
50
|
+
CONST_ENV_NODE_COUNT = "NODE_COUNT"
|
46
51
|
CONST_ENV_LAUNCH_CMD = "OCI__LAUNCH_CMD"
|
47
52
|
CONST_ENV_DEEPSPEED = "OCI__DEEPSPEED"
|
53
|
+
# Envs set by this module
|
54
|
+
CONST_ENV_WORLD_SIZE = "WORLD_SIZE"
|
55
|
+
CONST_ENV_LD_PRELOAD = "LD_PRELOAD"
|
56
|
+
# Envs for debugging only
|
57
|
+
# OCI_ODSC_SERVICE_ENDPOINT is used for all processes in the job run
|
58
|
+
CONST_ENV_ODSC_SERVICE_ENDPOINT = "OCI_ODSC_SERVICE_ENDPOINT"
|
59
|
+
# OCI_DS_SERVICE_ENDPOINT is used only by the training process
|
60
|
+
CONST_ENV_DS_SERVICE_ENDPOINT = "OCI_DS_SERVICE_ENDPOINT"
|
61
|
+
|
62
|
+
# Constants used in logs
|
48
63
|
LOG_PREFIX_HOST_IP = "Distributed Training HOST IP: "
|
49
64
|
LOG_PREFIX_NODE_IP = "Node IP: "
|
50
65
|
LOG_PREFIX_PUBLIC_KEY = "HOST PUBLIC KEY: "
|
51
|
-
|
52
|
-
#
|
53
|
-
|
66
|
+
# Other constants used within this script
|
67
|
+
# Other constants used within this script
|
68
|
+
USER_HOME = os.environ.get("HOME", f"/home/{getpass.getuser()}")
|
69
|
+
SSH_DIR = os.environ.get("OCI__SSH_DIR", os.path.join(USER_HOME, ".ssh"))
|
54
70
|
DEFAULT_LAUNCHER = "torchrun"
|
55
71
|
|
56
72
|
# Set authentication method to resource principal
|
@@ -131,8 +147,11 @@ class Runner(driver_utils.JobRunner):
|
|
131
147
|
|
132
148
|
self.host_job_run = DataScienceJobRun.from_ocid(self.host_ocid)
|
133
149
|
self.entrypoint_env = PythonRuntimeHandler.CONST_CODE_ENTRYPOINT
|
134
|
-
# The total number of
|
135
|
-
|
150
|
+
# The total number of nodes is OCI__WORKER_COUNT + 1
|
151
|
+
if CONST_ENV_NODE_COUNT in os.environ:
|
152
|
+
self.node_count = int(os.environ[CONST_ENV_NODE_COUNT])
|
153
|
+
else:
|
154
|
+
self.node_count = int(os.environ.get(OCI__WORKER_COUNT, 0)) + 1
|
136
155
|
logger.debug("Node count: %s", self.node_count)
|
137
156
|
self.gpu_count = torch.cuda.device_count()
|
138
157
|
logger.debug("GPU count on this node: %s", self.gpu_count)
|
@@ -343,9 +362,7 @@ class Runner(driver_utils.JobRunner):
|
|
343
362
|
if self.launch_cmd:
|
344
363
|
if self.LAUNCHER:
|
345
364
|
if not self.launch_cmd.startswith(self.LAUNCHER):
|
346
|
-
raise ValueError(
|
347
|
-
f"Command not supported: '{self.launch_cmd}'. "
|
348
|
-
)
|
365
|
+
raise ValueError(f"Command not supported: '{self.launch_cmd}'. ")
|
349
366
|
|
350
367
|
launch_args.append(self.launch_cmd[len(self.LAUNCHER) + 1 :])
|
351
368
|
else:
|
@@ -689,7 +706,7 @@ class GenericRunner(TorchRunner, DeepSpeedRunner):
|
|
689
706
|
def set_env_var(self):
|
690
707
|
"""Set default environment variables."""
|
691
708
|
defaults = {
|
692
|
-
"WORLD_SIZE": self.node_count,
|
709
|
+
"WORLD_SIZE": self.node_count * self.gpu_count,
|
693
710
|
"MASTER_ADDR": self.host_ip,
|
694
711
|
"MASTER_PORT": self.RDZV_PORT,
|
695
712
|
}
|
ads/model/artifact_downloader.py
CHANGED
@@ -1,12 +1,11 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
2
|
# -*- coding: utf-8; -*-
|
3
3
|
|
4
|
-
# Copyright (c) 2022,
|
4
|
+
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
|
5
5
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
6
|
|
7
7
|
import os
|
8
8
|
import shutil
|
9
|
-
import tempfile
|
10
9
|
import uuid
|
11
10
|
from abc import ABC, abstractmethod
|
12
11
|
from typing import Dict, Optional
|
@@ -15,6 +14,7 @@ from zipfile import ZipFile
|
|
15
14
|
from ads.common import utils
|
16
15
|
from ads.common.utils import extract_region
|
17
16
|
from ads.model.service.oci_datascience_model import OCIDataScienceModel
|
17
|
+
from ads.common.object_storage_details import ObjectStorageDetails
|
18
18
|
|
19
19
|
|
20
20
|
class ArtifactDownloader(ABC):
|
@@ -85,17 +85,32 @@ class SmallArtifactDownloader(ArtifactDownloader):
|
|
85
85
|
def _download(self):
|
86
86
|
"""Downloads model artifacts."""
|
87
87
|
self.progress.update("Importing model artifacts from catalog")
|
88
|
-
|
88
|
+
|
89
|
+
artifact_info = self.dsc_model.get_artifact_info()
|
90
|
+
artifact_name = artifact_info["Content-Disposition"].replace(
|
91
|
+
"attachment; filename=", ""
|
92
|
+
)
|
93
|
+
_, file_extension = os.path.splitext(artifact_name)
|
94
|
+
file_extension = file_extension.lower() if file_extension else ".zip"
|
95
|
+
|
96
|
+
file_content = self.dsc_model.get_model_artifact_content()
|
89
97
|
self.progress.update("Copying model artifacts to the artifact directory")
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
98
|
+
|
99
|
+
file_name = (
|
100
|
+
"model_description" if file_extension == ".json" else str(uuid.uuid4())
|
101
|
+
)
|
102
|
+
artifact_file_path = os.path.join(
|
103
|
+
self.target_dir, f"{file_name}{file_extension}"
|
104
|
+
)
|
105
|
+
with open(artifact_file_path, "wb") as _file:
|
106
|
+
_file.write(file_content)
|
107
|
+
|
108
|
+
if file_extension == ".zip":
|
109
|
+
self.progress.update("Extracting model artifacts")
|
110
|
+
with ZipFile(artifact_file_path) as _file:
|
111
|
+
_file.extractall(self.target_dir)
|
112
|
+
utils.remove_file(artifact_file_path)
|
113
|
+
|
99
114
|
|
100
115
|
class LargeArtifactDownloader(ArtifactDownloader):
|
101
116
|
PROGRESS_STEPS_COUNT = 4
|
@@ -110,6 +125,7 @@ class LargeArtifactDownloader(ArtifactDownloader):
|
|
110
125
|
bucket_uri: Optional[str] = None,
|
111
126
|
overwrite_existing_artifact: Optional[bool] = True,
|
112
127
|
remove_existing_artifact: Optional[bool] = True,
|
128
|
+
model_file_description: Optional[dict] = None,
|
113
129
|
):
|
114
130
|
"""Initializes `LargeArtifactDownloader` instance.
|
115
131
|
|
@@ -137,6 +153,9 @@ class LargeArtifactDownloader(ArtifactDownloader):
|
|
137
153
|
Overwrite target bucket artifact if exists.
|
138
154
|
remove_existing_artifact: (bool, optional). Defaults to `True`.
|
139
155
|
Wether artifacts uploaded to object storage bucket need to be removed or not.
|
156
|
+
model_file_description: (dict, optional). Defaults to None.
|
157
|
+
Contains object path details for models created by reference.
|
158
|
+
|
140
159
|
"""
|
141
160
|
super().__init__(
|
142
161
|
dsc_model=dsc_model, target_dir=target_dir, force_overwrite=force_overwrite
|
@@ -146,12 +165,19 @@ class LargeArtifactDownloader(ArtifactDownloader):
|
|
146
165
|
self.bucket_uri = bucket_uri
|
147
166
|
self.overwrite_existing_artifact = overwrite_existing_artifact
|
148
167
|
self.remove_existing_artifact = remove_existing_artifact
|
168
|
+
self.model_file_description = model_file_description
|
149
169
|
|
150
170
|
def _download(self):
|
151
171
|
"""Downloads model artifacts."""
|
152
172
|
self.progress.update(f"Importing model artifacts from catalog")
|
153
173
|
|
174
|
+
if self.dsc_model.is_model_by_reference() and self.model_file_description:
|
175
|
+
self.download_from_model_file_description()
|
176
|
+
self.progress.update()
|
177
|
+
return
|
178
|
+
|
154
179
|
bucket_uri = self.bucket_uri
|
180
|
+
|
155
181
|
if not os.path.basename(bucket_uri):
|
156
182
|
bucket_uri = os.path.join(bucket_uri, f"{self.dsc_model.id}.zip")
|
157
183
|
elif not bucket_uri.lower().endswith(".zip"):
|
@@ -159,7 +185,6 @@ class LargeArtifactDownloader(ArtifactDownloader):
|
|
159
185
|
|
160
186
|
self.dsc_model.import_model_artifact(bucket_uri=bucket_uri, region=self.region)
|
161
187
|
self.progress.update("Copying model artifacts to the artifact directory")
|
162
|
-
|
163
188
|
zip_file_path = os.path.join(self.target_dir, f"{str(uuid.uuid4())}.zip")
|
164
189
|
zip_file_path = utils.copy_file(
|
165
190
|
uri_src=bucket_uri,
|
@@ -172,7 +197,6 @@ class LargeArtifactDownloader(ArtifactDownloader):
|
|
172
197
|
zip_file.extractall(self.target_dir)
|
173
198
|
|
174
199
|
utils.remove_file(zip_file_path)
|
175
|
-
|
176
200
|
if self.remove_existing_artifact:
|
177
201
|
self.progress.update(
|
178
202
|
"Removing temporary artifacts from the Object Storage bucket"
|
@@ -180,3 +204,49 @@ class LargeArtifactDownloader(ArtifactDownloader):
|
|
180
204
|
utils.remove_file(bucket_uri)
|
181
205
|
else:
|
182
206
|
self.progress.update()
|
207
|
+
|
208
|
+
def download_from_model_file_description(self):
|
209
|
+
"""Helper function to download the objects using model file description content to the target directory."""
|
210
|
+
|
211
|
+
models = self.model_file_description["models"]
|
212
|
+
total_size = 0
|
213
|
+
|
214
|
+
for model in models:
|
215
|
+
namespace, bucket_name, prefix = (
|
216
|
+
model["namespace"],
|
217
|
+
model["bucketName"],
|
218
|
+
model["prefix"],
|
219
|
+
)
|
220
|
+
bucket_uri = f"oci://{bucket_name}@{namespace}/{prefix}"
|
221
|
+
|
222
|
+
message = f"Copying model artifacts by reference from {bucket_uri} to {self.target_dir}"
|
223
|
+
self.progress.update(message)
|
224
|
+
|
225
|
+
objects = model["objects"]
|
226
|
+
os_details_list = list()
|
227
|
+
|
228
|
+
for obj in objects:
|
229
|
+
name = obj["name"]
|
230
|
+
version = None if obj["version"] == "" else obj["version"]
|
231
|
+
size = obj["sizeInBytes"]
|
232
|
+
if size == 0:
|
233
|
+
continue
|
234
|
+
total_size += size
|
235
|
+
object_uri = f"oci://{bucket_name}@{namespace}/{name}"
|
236
|
+
|
237
|
+
os_details = ObjectStorageDetails.from_path(object_uri)
|
238
|
+
os_details.version = version
|
239
|
+
os_details_list.append(os_details)
|
240
|
+
try:
|
241
|
+
ObjectStorageDetails.from_path(
|
242
|
+
bucket_uri
|
243
|
+
).bulk_download_from_object_storage(
|
244
|
+
paths=os_details_list,
|
245
|
+
target_dir=self.target_dir,
|
246
|
+
progress_bar=self.progress,
|
247
|
+
)
|
248
|
+
except Exception as ex:
|
249
|
+
raise RuntimeError(
|
250
|
+
f"Failed to download model artifact by reference from the given Object Storage path `{bucket_uri}`."
|
251
|
+
f"See Exception: {ex}"
|
252
|
+
)
|
ads/model/artifact_uploader.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
2
|
# -*- coding: utf-8; -*-
|
3
|
+
import logging
|
3
4
|
|
4
|
-
# Copyright (c) 2022,
|
5
|
+
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
|
5
6
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
7
|
|
7
8
|
import os
|
@@ -38,7 +39,7 @@ class ArtifactUploader(ABC):
|
|
38
39
|
|
39
40
|
self.dsc_model = dsc_model
|
40
41
|
self.artifact_path = artifact_path
|
41
|
-
self.
|
42
|
+
self.artifact_file_path = None
|
42
43
|
self.progress = None
|
43
44
|
|
44
45
|
def upload(self):
|
@@ -48,8 +49,8 @@ class ArtifactUploader(ABC):
|
|
48
49
|
ArtifactUploader.PROGRESS_STEPS_COUNT + self.PROGRESS_STEPS_COUNT
|
49
50
|
) as progress:
|
50
51
|
self.progress = progress
|
51
|
-
self.progress.update("Preparing model artifacts
|
52
|
-
self.
|
52
|
+
self.progress.update("Preparing model artifacts file.")
|
53
|
+
self._prepare_artifact_tmp_file()
|
53
54
|
self.progress.update("Uploading model artifacts.")
|
54
55
|
self._upload()
|
55
56
|
self.progress.update(
|
@@ -59,35 +60,35 @@ class ArtifactUploader(ABC):
|
|
59
60
|
except Exception:
|
60
61
|
raise
|
61
62
|
finally:
|
62
|
-
self.
|
63
|
+
self._remove_artifact_tmp_file()
|
63
64
|
|
64
|
-
def
|
65
|
-
"""Prepares model artifacts
|
65
|
+
def _prepare_artifact_tmp_file(self) -> str:
|
66
|
+
"""Prepares model artifacts file.
|
66
67
|
|
67
68
|
Returns
|
68
69
|
-------
|
69
70
|
str
|
70
|
-
Path to the model artifact
|
71
|
+
Path to the model artifact file.
|
71
72
|
"""
|
72
73
|
if ObjectStorageDetails.is_oci_path(self.artifact_path):
|
73
|
-
self.
|
74
|
+
self.artifact_file_path = self.artifact_path
|
74
75
|
elif os.path.isfile(self.artifact_path) and self.artifact_path.lower().endswith(
|
75
|
-
".zip"
|
76
|
+
(".zip", ".json")
|
76
77
|
):
|
77
|
-
self.
|
78
|
+
self.artifact_file_path = self.artifact_path
|
78
79
|
else:
|
79
|
-
self.
|
80
|
+
self.artifact_file_path = model_utils.zip_artifact(
|
80
81
|
artifact_dir=self.artifact_path
|
81
82
|
)
|
82
|
-
return self.
|
83
|
+
return self.artifact_file_path
|
83
84
|
|
84
|
-
def
|
85
|
-
"""Removes temporary created artifact
|
85
|
+
def _remove_artifact_tmp_file(self):
|
86
|
+
"""Removes temporary created artifact file."""
|
86
87
|
if (
|
87
|
-
self.
|
88
|
-
and self.
|
88
|
+
self.artifact_file_path
|
89
|
+
and self.artifact_file_path.lower() != self.artifact_path.lower()
|
89
90
|
):
|
90
|
-
shutil.rmtree(self.
|
91
|
+
shutil.rmtree(self.artifact_file_path, ignore_errors=True)
|
91
92
|
|
92
93
|
@abstractmethod
|
93
94
|
def _upload(self):
|
@@ -101,9 +102,10 @@ class SmallArtifactUploader(ArtifactUploader):
|
|
101
102
|
|
102
103
|
def _upload(self):
|
103
104
|
"""Uploads model artifacts to the model catalog."""
|
105
|
+
_, ext = os.path.splitext(self.artifact_file_path)
|
104
106
|
self.progress.update("Uploading model artifacts to the catalog")
|
105
|
-
with open(self.
|
106
|
-
self.dsc_model.create_model_artifact(file_data)
|
107
|
+
with open(self.artifact_file_path, "rb") as file_data:
|
108
|
+
self.dsc_model.create_model_artifact(bytes_content=file_data, extension=ext)
|
107
109
|
|
108
110
|
|
109
111
|
class LargeArtifactUploader(ArtifactUploader):
|
@@ -117,7 +119,7 @@ class LargeArtifactUploader(ArtifactUploader):
|
|
117
119
|
- object storage path to zip archive. Example: `oci://<bucket_name>@<namespace>/prefix/mymodel.zip`.
|
118
120
|
- local path to zip archive. Example: `./mymodel.zip`.
|
119
121
|
- local path to folder with artifacts. Example: `./mymodel`.
|
120
|
-
|
122
|
+
artifact_file_path: str
|
121
123
|
The uri of the zip of model artifact.
|
122
124
|
auth: dict
|
123
125
|
The default authetication is set using `ads.set_auth` API.
|
@@ -222,7 +224,7 @@ class LargeArtifactUploader(ArtifactUploader):
|
|
222
224
|
"""Uploads model artifacts to the model catalog."""
|
223
225
|
bucket_uri = self.bucket_uri
|
224
226
|
self.progress.update("Copying model artifact to the Object Storage bucket")
|
225
|
-
if not bucket_uri == self.
|
227
|
+
if not bucket_uri == self.artifact_file_path:
|
226
228
|
bucket_uri_file_name = os.path.basename(bucket_uri)
|
227
229
|
|
228
230
|
if not bucket_uri_file_name:
|
@@ -240,7 +242,7 @@ class LargeArtifactUploader(ArtifactUploader):
|
|
240
242
|
|
241
243
|
try:
|
242
244
|
utils.upload_to_os(
|
243
|
-
src_uri=self.
|
245
|
+
src_uri=self.artifact_file_path,
|
244
246
|
dst_uri=bucket_uri,
|
245
247
|
auth=self.auth,
|
246
248
|
parallel_process_count=self._parallel_process_count,
|