oracle-ads 2.10.1__py3-none-any.whl → 2.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. ads/aqua/__init__.py +12 -0
  2. ads/aqua/base.py +324 -0
  3. ads/aqua/cli.py +19 -0
  4. ads/aqua/config/deployment_config_defaults.json +9 -0
  5. ads/aqua/config/resource_limit_names.json +7 -0
  6. ads/aqua/constants.py +45 -0
  7. ads/aqua/data.py +40 -0
  8. ads/aqua/decorator.py +101 -0
  9. ads/aqua/deployment.py +643 -0
  10. ads/aqua/dummy_data/icon.txt +1 -0
  11. ads/aqua/dummy_data/oci_model_deployments.json +56 -0
  12. ads/aqua/dummy_data/oci_models.json +1 -0
  13. ads/aqua/dummy_data/readme.md +26 -0
  14. ads/aqua/evaluation.py +1751 -0
  15. ads/aqua/exception.py +82 -0
  16. ads/aqua/extension/__init__.py +40 -0
  17. ads/aqua/extension/base_handler.py +138 -0
  18. ads/aqua/extension/common_handler.py +21 -0
  19. ads/aqua/extension/deployment_handler.py +202 -0
  20. ads/aqua/extension/evaluation_handler.py +135 -0
  21. ads/aqua/extension/finetune_handler.py +66 -0
  22. ads/aqua/extension/model_handler.py +59 -0
  23. ads/aqua/extension/ui_handler.py +201 -0
  24. ads/aqua/extension/utils.py +23 -0
  25. ads/aqua/finetune.py +579 -0
  26. ads/aqua/job.py +29 -0
  27. ads/aqua/model.py +819 -0
  28. ads/aqua/training/__init__.py +4 -0
  29. ads/aqua/training/exceptions.py +459 -0
  30. ads/aqua/ui.py +453 -0
  31. ads/aqua/utils.py +715 -0
  32. ads/cli.py +37 -6
  33. ads/common/decorator/__init__.py +7 -3
  34. ads/common/decorator/require_nonempty_arg.py +65 -0
  35. ads/common/object_storage_details.py +166 -7
  36. ads/common/oci_client.py +18 -1
  37. ads/common/oci_logging.py +2 -2
  38. ads/common/oci_mixin.py +4 -5
  39. ads/common/serializer.py +34 -5
  40. ads/common/utils.py +75 -10
  41. ads/config.py +40 -1
  42. ads/jobs/ads_job.py +43 -25
  43. ads/jobs/builders/infrastructure/base.py +4 -2
  44. ads/jobs/builders/infrastructure/dsc_job.py +49 -39
  45. ads/jobs/builders/runtimes/base.py +71 -1
  46. ads/jobs/builders/runtimes/container_runtime.py +4 -4
  47. ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
  48. ads/jobs/templates/driver_pytorch.py +27 -10
  49. ads/model/artifact_downloader.py +84 -14
  50. ads/model/artifact_uploader.py +25 -23
  51. ads/model/datascience_model.py +388 -38
  52. ads/model/deployment/model_deployment.py +10 -2
  53. ads/model/generic_model.py +8 -0
  54. ads/model/model_file_description_schema.json +68 -0
  55. ads/model/model_metadata.py +1 -1
  56. ads/model/service/oci_datascience_model.py +34 -5
  57. ads/opctl/operator/lowcode/anomaly/README.md +2 -1
  58. ads/opctl/operator/lowcode/anomaly/__main__.py +10 -4
  59. ads/opctl/operator/lowcode/anomaly/environment.yaml +2 -1
  60. ads/opctl/operator/lowcode/anomaly/model/automlx.py +12 -6
  61. ads/opctl/operator/lowcode/forecast/README.md +3 -2
  62. ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
  63. ads/opctl/operator/lowcode/forecast/model/automlx.py +12 -23
  64. ads/telemetry/base.py +62 -0
  65. ads/telemetry/client.py +105 -0
  66. ads/telemetry/telemetry.py +6 -3
  67. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/METADATA +37 -7
  68. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/RECORD +71 -36
  69. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/LICENSE.txt +0 -0
  70. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/WHEEL +0 -0
  71. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/entry_points.txt +0 -0
ads/jobs/builders/runtimes/pytorch_runtime.py

@@ -1,16 +1,20 @@
-import json
-import time
-import traceback
+#!/usr/bin/env python
+# -*- coding: utf-8; -*-
+
+# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
 from ads.jobs.builders.runtimes.artifact import PythonArtifact, GitPythonArtifact
+from ads.jobs.builders.runtimes.base import MultiNodeRuntime
 from ads.jobs.builders.runtimes.python_runtime import (
     PythonRuntime,
     GitPythonRuntime,
 )


-class PyTorchDistributedRuntime(PythonRuntime):
+class PyTorchDistributedRuntime(PythonRuntime, MultiNodeRuntime):
+    """Represents runtime supporting PyTorch Distributed training."""
     CONST_GIT = "git"
-    CONST_REPLICA = "replicas"
     CONST_INPUT = "inputs"
     CONST_DEP = "dependencies"
     CONST_PIP_REQ = "pipRequirements"
@@ -79,26 +83,6 @@ class PyTorchDistributedRuntime(PythonRuntime):
         """The input files to be copied into the job run."""
         return self.get_spec(self.CONST_INPUT)

-    def with_replica(self, count: int):
-        """Specifies the number of nodes (job runs) for the job.
-
-        Parameters
-        ----------
-        count : int
-            Number of nodes (job runs)
-
-        Returns
-        -------
-        self
-            The runtime instance.
-        """
-        return self.set_spec(self.CONST_REPLICA, count)
-
-    @property
-    def replica(self) -> int:
-        """The number of nodes (job runs)."""
-        return self.get_spec(self.CONST_REPLICA)
-
     def with_dependency(self, pip_req=None, pip_pkg=None):
         """Specifies additional dependencies to be installed using pip.

@@ -185,7 +169,7 @@ class PyTorchDistributedRuntime(PythonRuntime):
     def command(self):
         """The command for launching the workload."""
         return self.get_spec(self.CONST_COMMAND)
-
+
     @property
     def use_deepspeed(self):
         """Indicate whether whether to configure deepspeed for multi-node workload"""
@@ -193,43 +177,6 @@ class PyTorchDistributedRuntime(PythonRuntime):
             return True
         return False

-    def run(self, dsc_job, **kwargs):
-        """Starts the job runs"""
-        replicas = self.replica if self.replica else 1
-        main_run = None
-        job_runs = []
-        try:
-            for i in range(replicas):
-                replica_kwargs = kwargs.copy()
-                envs = replica_kwargs.get("environment_variables")
-                if not envs:
-                    envs = {}
-                # Huggingface accelerate requires machine rank
-                # Here we use NODE_RANK to store the machine rank
-                envs["NODE_RANK"] = str(i)
-                envs["WORLD_SIZE"] = str(replicas)
-                if main_run:
-                    envs["MAIN_JOB_RUN_OCID"] = main_run.id
-                name = replica_kwargs.get("display_name")
-                if not name:
-                    name = dsc_job.display_name
-
-                replica_kwargs["display_name"] = f"{name}-{str(i)}"
-                replica_kwargs["environment_variables"] = envs
-                run = dsc_job.run(**replica_kwargs)
-                job_runs.append(run)
-                if i == 0:
-                    main_run = run
-        except Exception:
-            traceback.print_exc()
-            # Wait a few second to avoid the job run being in a transient state.
-            time.sleep(2)
-            # If there is any error when creating the job runs
-            # cancel all the job runs.
-            for run in job_runs:
-                run.cancel()
-        return main_run
-

 class PyTorchDistributedArtifact(PythonArtifact):
     CONST_DRIVER_SCRIPT = "driver_pytorch.py"
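
Note: the replica-management API removed above (with_replica(), the replica property, and the run() launch loop) is not dropped; it moves into the new shared MultiNodeRuntime base class (ads/jobs/builders/runtimes/base.py, +71 lines in this release), which PyTorchDistributedRuntime now inherits. A minimal usage sketch, assuming the builder interface is unchanged for callers; the conda slug, entrypoint, and job name below are illustrative placeholders, not values from this release:

from ads.jobs import Job, DataScienceJob, PyTorchDistributedRuntime

# Sketch only: with_replica() now resolves via MultiNodeRuntime instead of
# PyTorchDistributedRuntime itself, but caller code looks the same.
runtime = (
    PyTorchDistributedRuntime()
    .with_service_conda("pytorch110_p38_gpu_v1")  # placeholder conda slug
    .with_replica(2)  # one main job run plus one worker node
    .with_command("torchrun train.py")  # torchrun is DEFAULT_LAUNCHER in the driver
)
job = (
    Job(name="pytorch-ddp-example")  # placeholder name
    .with_infrastructure(DataScienceJob())  # shape/network config omitted
    .with_runtime(runtime)
)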
ads/jobs/templates/driver_pytorch.py

@@ -5,6 +5,7 @@
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 """This module requires oracle-ads>=2.6.8
 """
+import getpass
 import ipaddress
 import logging
 import multiprocessing
@@ -40,17 +41,32 @@ logger = logging.getLogger(__name__)
 logger = driver_utils.set_log_level(logger)


+# Envs provisioned by the service
 CONST_ENV_HOST_JOB_RUN_OCID = "MAIN_JOB_RUN_OCID"
 CONST_ENV_JOB_RUN_OCID = "JOB_RUN_OCID"
-CONST_ENV_LD_PRELOAD = "LD_PRELOAD"
+# Envs set by the ADS API
+OCI__WORKER_COUNT = "OCI__WORKER_COUNT"
+CONST_ENV_NODE_RANK = "NODE_RANK"
+CONST_ENV_NODE_COUNT = "NODE_COUNT"
 CONST_ENV_LAUNCH_CMD = "OCI__LAUNCH_CMD"
 CONST_ENV_DEEPSPEED = "OCI__DEEPSPEED"
+# Envs set by this module
+CONST_ENV_WORLD_SIZE = "WORLD_SIZE"
+CONST_ENV_LD_PRELOAD = "LD_PRELOAD"
+# Envs for debugging only
+# OCI_ODSC_SERVICE_ENDPOINT is used for all processes in the job run
+CONST_ENV_ODSC_SERVICE_ENDPOINT = "OCI_ODSC_SERVICE_ENDPOINT"
+# OCI_DS_SERVICE_ENDPOINT is used only by the training process
+CONST_ENV_DS_SERVICE_ENDPOINT = "OCI_DS_SERVICE_ENDPOINT"
+
+# Constants used in logs
 LOG_PREFIX_HOST_IP = "Distributed Training HOST IP: "
 LOG_PREFIX_NODE_IP = "Node IP: "
 LOG_PREFIX_PUBLIC_KEY = "HOST PUBLIC KEY: "
-SSH_DIR = "/home/datascience/.ssh"
-# Working count is the number of node - 1
-OCI__WORKER_COUNT = "OCI__WORKER_COUNT"
+# Other constants used within this script
+# Other constants used within this script
+USER_HOME = os.environ.get("HOME", f"/home/{getpass.getuser()}")
+SSH_DIR = os.environ.get("OCI__SSH_DIR", os.path.join(USER_HOME, ".ssh"))
 DEFAULT_LAUNCHER = "torchrun"

 # Set authentication method to resource principal
@@ -131,8 +147,11 @@ class Runner(driver_utils.JobRunner):

         self.host_job_run = DataScienceJobRun.from_ocid(self.host_ocid)
         self.entrypoint_env = PythonRuntimeHandler.CONST_CODE_ENTRYPOINT
-        # The total number of node is OCI__WORKER_COUNT + 1
-        self.node_count = int(os.environ.get(OCI__WORKER_COUNT, 0)) + 1
+        # The total number of nodes is OCI__WORKER_COUNT + 1
+        if CONST_ENV_NODE_COUNT in os.environ:
+            self.node_count = int(os.environ[CONST_ENV_NODE_COUNT])
+        else:
+            self.node_count = int(os.environ.get(OCI__WORKER_COUNT, 0)) + 1
         logger.debug("Node count: %s", self.node_count)
         self.gpu_count = torch.cuda.device_count()
         logger.debug("GPU count on this node: %s", self.gpu_count)
@@ -343,9 +362,7 @@
         if self.launch_cmd:
             if self.LAUNCHER:
                 if not self.launch_cmd.startswith(self.LAUNCHER):
-                    raise ValueError(
-                        f"Command not supported: '{self.launch_cmd}'. "
-                    )
+                    raise ValueError(f"Command not supported: '{self.launch_cmd}'. ")

                 launch_args.append(self.launch_cmd[len(self.LAUNCHER) + 1 :])
             else:
@@ -689,7 +706,7 @@ class GenericRunner(TorchRunner, DeepSpeedRunner):
     def set_env_var(self):
         """Set default environment variables."""
         defaults = {
-            "WORLD_SIZE": self.node_count,
+            "WORLD_SIZE": self.node_count * self.gpu_count,
             "MASTER_ADDR": self.host_ip,
             "MASTER_PORT": self.RDZV_PORT,
         }
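
Two behavior changes in this driver are easy to miss in the hunks above. First, Runner.__init__ now prefers the NODE_COUNT environment variable (set by the ADS API) and only falls back to the legacy OCI__WORKER_COUNT + 1 arithmetic. Second, GenericRunner.set_env_var() now defaults WORLD_SIZE to node_count * gpu_count rather than node_count, matching the torch.distributed convention that world size counts processes (one per GPU), not machines. A standalone sketch of the combined logic, assuming one training process per GPU:

import os

def resolve_node_count(env=os.environ) -> int:
    # Mirrors the new precedence in Runner.__init__: NODE_COUNT wins;
    # otherwise OCI__WORKER_COUNT counts workers excluding the main run.
    if "NODE_COUNT" in env:
        return int(env["NODE_COUNT"])
    return int(env.get("OCI__WORKER_COUNT", 0)) + 1

# Example: 2 nodes x 4 GPUs. The old default WORLD_SIZE would have been 2;
# the new default is 8, one per training process.
node_count = resolve_node_count({"NODE_COUNT": "2"})
gpu_count = 4  # stand-in for torch.cuda.device_count()
assert node_count * gpu_count == 8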
ads/model/artifact_downloader.py

@@ -1,12 +1,11 @@
 #!/usr/bin/env python
 # -*- coding: utf-8; -*-

-# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
+# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

 import os
 import shutil
-import tempfile
 import uuid
 from abc import ABC, abstractmethod
 from typing import Dict, Optional
@@ -15,6 +14,7 @@ from zipfile import ZipFile
 from ads.common import utils
 from ads.common.utils import extract_region
 from ads.model.service.oci_datascience_model import OCIDataScienceModel
+from ads.common.object_storage_details import ObjectStorageDetails


 class ArtifactDownloader(ABC):
@@ -85,17 +85,32 @@ class SmallArtifactDownloader(ArtifactDownloader):
     def _download(self):
         """Downloads model artifacts."""
         self.progress.update("Importing model artifacts from catalog")
-        zip_content = self.dsc_model.get_model_artifact_content()
+
+        artifact_info = self.dsc_model.get_artifact_info()
+        artifact_name = artifact_info["Content-Disposition"].replace(
+            "attachment; filename=", ""
+        )
+        _, file_extension = os.path.splitext(artifact_name)
+        file_extension = file_extension.lower() if file_extension else ".zip"
+
+        file_content = self.dsc_model.get_model_artifact_content()
         self.progress.update("Copying model artifacts to the artifact directory")
-
-        zip_file_path = os.path.join(self.target_dir, f"{str(uuid.uuid4())}.zip")
-        with open(zip_file_path, "wb") as zip_file:
-            zip_file.write(zip_content)
-        self.progress.update("Extracting model artifacts")
-        with ZipFile(zip_file_path) as zip_file:
-            zip_file.extractall(self.target_dir)
-
-        utils.remove_file(zip_file_path)
+
+        file_name = (
+            "model_description" if file_extension == ".json" else str(uuid.uuid4())
+        )
+        artifact_file_path = os.path.join(
+            self.target_dir, f"{file_name}{file_extension}"
+        )
+        with open(artifact_file_path, "wb") as _file:
+            _file.write(file_content)
+
+        if file_extension == ".zip":
+            self.progress.update("Extracting model artifacts")
+            with ZipFile(artifact_file_path) as _file:
+                _file.extractall(self.target_dir)
+            utils.remove_file(artifact_file_path)
+

 class LargeArtifactDownloader(ArtifactDownloader):
     PROGRESS_STEPS_COUNT = 4
@@ -110,6 +125,7 @@ class LargeArtifactDownloader(ArtifactDownloader):
         bucket_uri: Optional[str] = None,
         overwrite_existing_artifact: Optional[bool] = True,
         remove_existing_artifact: Optional[bool] = True,
+        model_file_description: Optional[dict] = None,
     ):
         """Initializes `LargeArtifactDownloader` instance.

@@ -137,6 +153,9 @@
             Overwrite target bucket artifact if exists.
         remove_existing_artifact: (bool, optional). Defaults to `True`.
             Wether artifacts uploaded to object storage bucket need to be removed or not.
+        model_file_description: (dict, optional). Defaults to None.
+            Contains object path details for models created by reference.
+
         """
         super().__init__(
             dsc_model=dsc_model, target_dir=target_dir, force_overwrite=force_overwrite
@@ -146,12 +165,19 @@
         self.bucket_uri = bucket_uri
         self.overwrite_existing_artifact = overwrite_existing_artifact
         self.remove_existing_artifact = remove_existing_artifact
+        self.model_file_description = model_file_description

     def _download(self):
         """Downloads model artifacts."""
         self.progress.update(f"Importing model artifacts from catalog")

+        if self.dsc_model.is_model_by_reference() and self.model_file_description:
+            self.download_from_model_file_description()
+            self.progress.update()
+            return
+
         bucket_uri = self.bucket_uri
+
         if not os.path.basename(bucket_uri):
             bucket_uri = os.path.join(bucket_uri, f"{self.dsc_model.id}.zip")
         elif not bucket_uri.lower().endswith(".zip"):
@@ -159,7 +185,6 @@

         self.dsc_model.import_model_artifact(bucket_uri=bucket_uri, region=self.region)
         self.progress.update("Copying model artifacts to the artifact directory")
-
         zip_file_path = os.path.join(self.target_dir, f"{str(uuid.uuid4())}.zip")
         zip_file_path = utils.copy_file(
             uri_src=bucket_uri,
@@ -172,7 +197,6 @@
             zip_file.extractall(self.target_dir)

         utils.remove_file(zip_file_path)
-
         if self.remove_existing_artifact:
             self.progress.update(
                 "Removing temporary artifacts from the Object Storage bucket"
@@ -180,3 +204,49 @@
             utils.remove_file(bucket_uri)
         else:
             self.progress.update()
+
+    def download_from_model_file_description(self):
+        """Helper function to download the objects using model file description content to the target directory."""
+
+        models = self.model_file_description["models"]
+        total_size = 0
+
+        for model in models:
+            namespace, bucket_name, prefix = (
+                model["namespace"],
+                model["bucketName"],
+                model["prefix"],
+            )
+            bucket_uri = f"oci://{bucket_name}@{namespace}/{prefix}"
+
+            message = f"Copying model artifacts by reference from {bucket_uri} to {self.target_dir}"
+            self.progress.update(message)
+
+            objects = model["objects"]
+            os_details_list = list()
+
+            for obj in objects:
+                name = obj["name"]
+                version = None if obj["version"] == "" else obj["version"]
+                size = obj["sizeInBytes"]
+                if size == 0:
+                    continue
+                total_size += size
+                object_uri = f"oci://{bucket_name}@{namespace}/{name}"
+
+                os_details = ObjectStorageDetails.from_path(object_uri)
+                os_details.version = version
+                os_details_list.append(os_details)
+            try:
+                ObjectStorageDetails.from_path(
+                    bucket_uri
+                ).bulk_download_from_object_storage(
+                    paths=os_details_list,
+                    target_dir=self.target_dir,
+                    progress_bar=self.progress,
+                )
+            except Exception as ex:
+                raise RuntimeError(
+                    f"Failed to download model artifact by reference from the given Object Storage path `{bucket_uri}`."
+                    f"See Exception: {ex}"
+                )
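
The by-reference path above consumes a "model file description" document whose shape can be read off the key lookups in download_from_model_file_description() (models, namespace, bucketName, prefix, and per-object name/version/sizeInBytes); the full schema ships in the new ads/model/model_file_description_schema.json. An illustrative payload, with placeholder namespace, bucket, and object names:

model_file_description = {
    "models": [
        {
            "namespace": "my-namespace",  # Object Storage namespace (placeholder)
            "bucketName": "my-bucket",
            "prefix": "models/my-model",
            "objects": [
                # An empty version string is normalized to None, and
                # zero-byte objects (e.g. folder markers) are skipped.
                {"name": "models/my-model/config.json", "version": "", "sizeInBytes": 412},
                {"name": "models/my-model/weights.bin", "version": "v2", "sizeInBytes": 7000000000},
            ],
        }
    ]
}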
ads/model/artifact_uploader.py

@@ -1,7 +1,8 @@
 #!/usr/bin/env python
 # -*- coding: utf-8; -*-
+import logging

-# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
+# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

 import os
@@ -38,7 +39,7 @@ class ArtifactUploader(ABC):

         self.dsc_model = dsc_model
         self.artifact_path = artifact_path
-        self.artifact_zip_path = None
+        self.artifact_file_path = None
         self.progress = None

     def upload(self):
@@ -48,8 +49,8 @@
                 ArtifactUploader.PROGRESS_STEPS_COUNT + self.PROGRESS_STEPS_COUNT
             ) as progress:
                 self.progress = progress
-                self.progress.update("Preparing model artifacts ZIP archive.")
-                self._prepare_artifact_tmp_zip()
+                self.progress.update("Preparing model artifacts file.")
+                self._prepare_artifact_tmp_file()
                 self.progress.update("Uploading model artifacts.")
                 self._upload()
                 self.progress.update(
@@ -59,35 +60,35 @@
         except Exception:
             raise
         finally:
-            self._remove_artifact_tmp_zip()
+            self._remove_artifact_tmp_file()

-    def _prepare_artifact_tmp_zip(self) -> str:
-        """Prepares model artifacts ZIP archive.
+    def _prepare_artifact_tmp_file(self) -> str:
+        """Prepares model artifacts file.

         Returns
         -------
         str
-            Path to the model artifact ZIP archive.
+            Path to the model artifact file.
         """
         if ObjectStorageDetails.is_oci_path(self.artifact_path):
-            self.artifact_zip_path = self.artifact_path
+            self.artifact_file_path = self.artifact_path
         elif os.path.isfile(self.artifact_path) and self.artifact_path.lower().endswith(
-            ".zip"
+            (".zip", ".json")
         ):
-            self.artifact_zip_path = self.artifact_path
+            self.artifact_file_path = self.artifact_path
         else:
-            self.artifact_zip_path = model_utils.zip_artifact(
+            self.artifact_file_path = model_utils.zip_artifact(
                 artifact_dir=self.artifact_path
             )
-        return self.artifact_zip_path
+        return self.artifact_file_path

-    def _remove_artifact_tmp_zip(self):
-        """Removes temporary created artifact zip archive."""
+    def _remove_artifact_tmp_file(self):
+        """Removes temporary created artifact file."""
         if (
-            self.artifact_zip_path
-            and self.artifact_zip_path.lower() != self.artifact_path.lower()
+            self.artifact_file_path
+            and self.artifact_file_path.lower() != self.artifact_path.lower()
         ):
-            shutil.rmtree(self.artifact_zip_path, ignore_errors=True)
+            shutil.rmtree(self.artifact_file_path, ignore_errors=True)

     @abstractmethod
     def _upload(self):
@@ -101,9 +102,10 @@ class SmallArtifactUploader(ArtifactUploader):

     def _upload(self):
         """Uploads model artifacts to the model catalog."""
+        _, ext = os.path.splitext(self.artifact_file_path)
         self.progress.update("Uploading model artifacts to the catalog")
-        with open(self.artifact_zip_path, "rb") as file_data:
-            self.dsc_model.create_model_artifact(file_data)
+        with open(self.artifact_file_path, "rb") as file_data:
+            self.dsc_model.create_model_artifact(bytes_content=file_data, extension=ext)


 class LargeArtifactUploader(ArtifactUploader):
@@ -117,7 +119,7 @@ class LargeArtifactUploader(ArtifactUploader):
         - object storage path to zip archive. Example: `oci://<bucket_name>@<namespace>/prefix/mymodel.zip`.
         - local path to zip archive. Example: `./mymodel.zip`.
         - local path to folder with artifacts. Example: `./mymodel`.
-    artifact_zip_path: str
+    artifact_file_path: str
         The uri of the zip of model artifact.
     auth: dict
         The default authetication is set using `ads.set_auth` API.
@@ -222,7 +224,7 @@
         """Uploads model artifacts to the model catalog."""
         bucket_uri = self.bucket_uri
         self.progress.update("Copying model artifact to the Object Storage bucket")
-        if not bucket_uri == self.artifact_zip_path:
+        if not bucket_uri == self.artifact_file_path:
            bucket_uri_file_name = os.path.basename(bucket_uri)

            if not bucket_uri_file_name:
@@ -240,7 +242,7 @@

         try:
             utils.upload_to_os(
-                src_uri=self.artifact_zip_path,
+                src_uri=self.artifact_file_path,
                 dst_uri=bucket_uri,
                 auth=self.auth,
                 parallel_process_count=self._parallel_process_count,
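
The uploader changes generalize the artifact pipeline from "always a ZIP" to "a ZIP or a JSON model file description": OCI paths and existing .zip/.json files pass through untouched, everything else is treated as a directory and zipped, and SmallArtifactUploader now forwards the detected extension to create_model_artifact. A self-contained sketch of that dispatch, with a local zip helper standing in for model_utils.zip_artifact (the real helper's temp-file handling is omitted):

import os
import zipfile

def prepare_artifact(artifact_path: str) -> str:
    # Remote artifacts are uploaded as-is (cf. ObjectStorageDetails.is_oci_path).
    if artifact_path.startswith("oci://"):
        return artifact_path
    # .json now joins .zip as a pass-through format (model file descriptions).
    if os.path.isfile(artifact_path) and artifact_path.lower().endswith((".zip", ".json")):
        return artifact_path
    # Anything else is assumed to be a directory of artifacts and gets zipped.
    zip_path = f"{artifact_path.rstrip(os.sep)}.zip"
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, _, files in os.walk(artifact_path):
            for name in files:
                full = os.path.join(root, name)
                zf.write(full, os.path.relpath(full, artifact_path))
    return zip_path

# The extension of the prepared file is what SmallArtifactUploader now passes
# along, e.g. create_model_artifact(bytes_content=..., extension=ext).
path = prepare_artifact("./mymodel")  # placeholder directory -> ./mymodel.zip
_, ext = os.path.splitext(path)       # ".zip" or ".json"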