oracle-ads 2.12.10rc0__py3-none-any.whl → 2.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. ads/aqua/__init__.py +2 -1
  2. ads/aqua/app.py +46 -19
  3. ads/aqua/client/__init__.py +3 -0
  4. ads/aqua/client/client.py +799 -0
  5. ads/aqua/common/enums.py +19 -14
  6. ads/aqua/common/errors.py +3 -4
  7. ads/aqua/common/utils.py +2 -2
  8. ads/aqua/constants.py +1 -0
  9. ads/aqua/evaluation/constants.py +7 -7
  10. ads/aqua/evaluation/errors.py +3 -4
  11. ads/aqua/evaluation/evaluation.py +20 -12
  12. ads/aqua/extension/aqua_ws_msg_handler.py +14 -7
  13. ads/aqua/extension/base_handler.py +12 -9
  14. ads/aqua/extension/model_handler.py +29 -1
  15. ads/aqua/extension/models/ws_models.py +5 -6
  16. ads/aqua/finetuning/constants.py +3 -3
  17. ads/aqua/finetuning/entities.py +3 -0
  18. ads/aqua/finetuning/finetuning.py +32 -1
  19. ads/aqua/model/constants.py +7 -7
  20. ads/aqua/model/entities.py +2 -1
  21. ads/aqua/model/enums.py +4 -5
  22. ads/aqua/model/model.py +158 -76
  23. ads/aqua/modeldeployment/deployment.py +22 -10
  24. ads/aqua/modeldeployment/entities.py +3 -1
  25. ads/cli.py +16 -8
  26. ads/common/auth.py +33 -20
  27. ads/common/extended_enum.py +52 -44
  28. ads/llm/__init__.py +11 -8
  29. ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
  30. ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
  31. ads/model/artifact_downloader.py +3 -4
  32. ads/model/datascience_model.py +84 -64
  33. ads/model/generic_model.py +3 -3
  34. ads/model/model_metadata.py +17 -11
  35. ads/model/service/oci_datascience_model.py +12 -14
  36. ads/opctl/backend/marketplace/helm_helper.py +13 -14
  37. ads/opctl/cli.py +4 -5
  38. ads/opctl/cmds.py +28 -32
  39. ads/opctl/config/merger.py +8 -11
  40. ads/opctl/config/resolver.py +25 -30
  41. ads/opctl/operator/cli.py +9 -9
  42. ads/opctl/operator/common/backend_factory.py +56 -60
  43. ads/opctl/operator/common/const.py +5 -5
  44. ads/opctl/operator/lowcode/anomaly/const.py +8 -9
  45. ads/opctl/operator/lowcode/common/transformations.py +38 -3
  46. ads/opctl/operator/lowcode/common/utils.py +11 -1
  47. ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +43 -48
  48. ads/opctl/operator/lowcode/forecast/__main__.py +10 -0
  49. ads/opctl/operator/lowcode/forecast/const.py +6 -6
  50. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +1 -1
  51. ads/opctl/operator/lowcode/forecast/operator_config.py +31 -0
  52. ads/opctl/operator/lowcode/forecast/schema.yaml +63 -0
  53. ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
  54. ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +233 -0
  55. ads/opctl/operator/lowcode/forecast/whatifserve/score.py +238 -0
  56. ads/opctl/operator/lowcode/pii/constant.py +6 -7
  57. ads/opctl/operator/lowcode/recommender/constant.py +12 -7
  58. ads/opctl/operator/runtime/marketplace_runtime.py +4 -10
  59. ads/opctl/operator/runtime/runtime.py +4 -6
  60. ads/pipeline/ads_pipeline_run.py +13 -25
  61. ads/pipeline/visualizer/graph_renderer.py +3 -4
  62. {oracle_ads-2.12.10rc0.dist-info → oracle_ads-2.13.0.dist-info}/METADATA +4 -2
  63. {oracle_ads-2.12.10rc0.dist-info → oracle_ads-2.13.0.dist-info}/RECORD +66 -59
  64. {oracle_ads-2.12.10rc0.dist-info → oracle_ads-2.13.0.dist-info}/LICENSE.txt +0 -0
  65. {oracle_ads-2.12.10rc0.dist-info → oracle_ads-2.13.0.dist-info}/WHEEL +0 -0
  66. {oracle_ads-2.12.10rc0.dist-info → oracle_ads-2.13.0.dist-info}/entry_points.txt +0 -0
ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py
@@ -0,0 +1,184 @@
+ #!/usr/bin/env python
+
+ # Copyright (c) 2025 Oracle and/or its affiliates.
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+ from typing import Any, Callable, Dict, List, Mapping, Optional
+
+ import requests
+ from langchain_core.embeddings import Embeddings
+ from langchain_core.language_models.llms import create_base_retry_decorator
+ from pydantic import BaseModel, Field
+
+ DEFAULT_HEADER = {
+     "Content-Type": "application/json",
+ }
+
+
+ class TokenExpiredError(Exception):
+     pass
+
+
+ def _create_retry_decorator(llm) -> Callable[[Any], Any]:
+     """Creates a retry decorator."""
+     errors = [requests.exceptions.ConnectTimeout, TokenExpiredError]
+     decorator = create_base_retry_decorator(
+         error_types=errors, max_retries=llm.max_retries
+     )
+     return decorator
+
+
+ class OCIDataScienceEmbedding(BaseModel, Embeddings):
+     """Embedding model deployed on OCI Data Science Model Deployment.
+
+     Example:
+
+         .. code-block:: python
+
+             from ads.llm import OCIDataScienceEmbedding
+
+             embeddings = OCIDataScienceEmbedding(
+                 endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<md_ocid>/predict",
+             )
+     """ # noqa: E501
+
+     auth: dict = Field(default_factory=dict, exclude=True)
+     """ADS auth dictionary for OCI authentication:
+     https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html.
+     This can be generated by calling `ads.common.auth.api_keys()`
+     or `ads.common.auth.resource_principal()`. If this is not
+     provided then the `ads.common.default_signer()` will be used."""
+
+     endpoint: str = ""
+     """The uri of the endpoint from the deployed Model Deployment model."""
+
+     model_kwargs: Optional[Dict] = None
+     """Keyword arguments to pass to the model."""
+
+     endpoint_kwargs: Optional[Dict] = None
+     """Optional attributes (except for headers) passed to the request.post
+     function.
+     """
+
+     max_retries: int = 1
+     """The maximum number of retries to make when generating."""
+
+     @property
+     def _identifying_params(self) -> Mapping[str, Any]:
+         """Get the identifying parameters."""
+         _model_kwargs = self.model_kwargs or {}
+         return {
+             **{"endpoint": self.endpoint},
+             **{"model_kwargs": _model_kwargs},
+         }
+
+     def _embed_with_retry(self, **kwargs) -> Any:
+         """Use tenacity to retry the call."""
+         retry_decorator = _create_retry_decorator(self)
+
+         @retry_decorator
+         def _completion_with_retry(**kwargs: Any) -> Any:
+             try:
+                 response = requests.post(self.endpoint, **kwargs)
+                 response.raise_for_status()
+                 return response
+             except requests.exceptions.HTTPError as http_err:
+                 if response.status_code == 401 and self._refresh_signer():
+                     raise TokenExpiredError() from http_err
+                 else:
+                     raise ValueError(
+                         f"Server error: {str(http_err)}. Message: {response.text}"
+                     ) from http_err
+             except Exception as e:
+                 raise ValueError(f"Error occurs by inference endpoint: {str(e)}") from e
+
+         return _completion_with_retry(**kwargs)
+
+     def _embedding(self, texts: List[str]) -> List[List[float]]:
+         """Call out to OCI Data Science Model Deployment Endpoint.
+
+         Args:
+             texts: A list of texts to embed.
+
+         Returns:
+             A list of list of floats representing the embeddings, or None if an
+             error occurs.
+         """
+         _model_kwargs = self.model_kwargs or {}
+         body = self._construct_request_body(texts, _model_kwargs)
+         request_kwargs = self._construct_request_kwargs(body)
+         response = self._embed_with_retry(**request_kwargs)
+         return self._proceses_response(response)
+
+     def _construct_request_kwargs(self, body: Any) -> dict:
+         """Constructs the request kwargs as a dictionary."""
+         from ads.model.common.utils import _is_json_serializable
+
+         _endpoint_kwargs = self.endpoint_kwargs or {}
+         headers = _endpoint_kwargs.pop("headers", DEFAULT_HEADER)
+         return (
+             dict(
+                 headers=headers,
+                 json=body,
+                 auth=self.auth.get("signer"),
+                 **_endpoint_kwargs,
+             )
+             if _is_json_serializable(body)
+             else dict(
+                 headers=headers,
+                 data=body,
+                 auth=self.auth.get("signer"),
+                 **_endpoint_kwargs,
+             )
+         )
+
+     def _construct_request_body(self, texts: List[str], params: dict) -> Any:
+         """Constructs the request body."""
+         return {"input": texts}
+
+     def _proceses_response(self, response: requests.Response) -> List[List[float]]:
+         """Extracts results from requests.Response."""
+         try:
+             res_json = response.json()
+             embeddings = res_json["data"][0]["embedding"]
+         except Exception as e:
+             raise ValueError(
+                 f"Error raised by inference API: {e}.\nResponse: {response.text}"
+             ) from e
+         return embeddings
+
+     def embed_documents(
+         self,
+         texts: List[str],
+         chunk_size: Optional[int] = None,
+     ) -> List[List[float]]:
+         """Compute doc embeddings using OCI Data Science Model Deployment Endpoint.
+
+         Args:
+             texts: The list of texts to embed.
+             chunk_size: The chunk size defines how many input texts will
+                 be grouped together as request. If None, will use the
+                 chunk size specified by the class.
+
+         Returns:
+             List of embeddings, one for each text.
+         """
+         results = []
+         _chunk_size = (
+             len(texts) if (not chunk_size or chunk_size > len(texts)) else chunk_size
+         )
+         for i in range(0, len(texts), _chunk_size):
+             response = self._embedding(texts[i : i + _chunk_size])
+             results.extend(response)
+         return results
+
+     def embed_query(self, text: str) -> List[float]:
+         """Compute query embeddings using OCI Data Science Model Deployment Endpoint.
+
+         Args:
+             text: The text to embed.
+
+         Returns:
+             Embeddings for the text.
+         """
+         return self._embedding([text])[0]
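The plugin above is exported through ads.llm (see ads/llm/__init__.py in this diff). A minimal usage sketch, assuming a reachable model deployment endpoint; the endpoint URL is a placeholder and resource_principal() is just one of the auth options named in the class docstring:

    from ads.common import auth as authutil
    from ads.llm import OCIDataScienceEmbedding

    embeddings = OCIDataScienceEmbedding(
        endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<md_ocid>/predict",
        auth=authutil.resource_principal(),  # or authutil.api_keys()
    )

    # Single query embedding and batched document embeddings.
    query_vector = embeddings.embed_query("What is Oracle ADS?")
    doc_vectors = embeddings.embed_documents(["doc one", "doc two"], chunk_size=2)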
ads/model/artifact_downloader.py
@@ -1,5 +1,4 @@
  #!/usr/bin/env python
- # -*- coding: utf-8; -*-

  # Copyright (c) 2022, 2024 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -12,9 +11,9 @@ from typing import Dict, Optional
  from zipfile import ZipFile

  from ads.common import utils
+ from ads.common.object_storage_details import ObjectStorageDetails
  from ads.common.utils import extract_region
  from ads.model.service.oci_datascience_model import OCIDataScienceModel
- from ads.common.object_storage_details import ObjectStorageDetails


  class ArtifactDownloader(ABC):
@@ -169,9 +168,9 @@ class LargeArtifactDownloader(ArtifactDownloader):

      def _download(self):
          """Downloads model artifacts."""
-         self.progress.update(f"Importing model artifacts from catalog")
+         self.progress.update("Importing model artifacts from catalog")

-         if self.dsc_model.is_model_by_reference() and self.model_file_description:
+         if self.dsc_model._is_model_by_reference() and self.model_file_description:
              self.download_from_model_file_description()
              self.progress.update()
              return
ads/model/datascience_model.py
@@ -1,7 +1,6 @@
  #!/usr/bin/env python
- # -*- coding: utf-8; -*-

- # Copyright (c) 2022, 2024 Oracle and/or its affiliates.
+ # Copyright (c) 2022, 2025 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

  import cgi
@@ -19,12 +18,14 @@ from jsonschema import ValidationError, validate
  from ads.common import oci_client as oc
  from ads.common import utils
- from ads.common.extended_enum import ExtendedEnumMeta
+ from ads.common.extended_enum import ExtendedEnum
  from ads.common.object_storage_details import ObjectStorageDetails
+ from ads.config import (
+     AQUA_SERVICE_MODELS_BUCKET as SERVICE_MODELS_BUCKET,
+ )
  from ads.config import (
      COMPARTMENT_OCID,
      PROJECT_OCID,
-     AQUA_SERVICE_MODELS_BUCKET as SERVICE_MODELS_BUCKET,
  )
  from ads.feature_engineering.schema import Schema
  from ads.jobs.builders.base import Builder
@@ -80,14 +81,14 @@ class InvalidArtifactType(Exception): # pragma: no cover
      pass


- class CustomerNotificationType(str, metaclass=ExtendedEnumMeta):
+ class CustomerNotificationType(ExtendedEnum):
      NONE = "NONE"
      ALL = "ALL"
      ON_FAILURE = "ON_FAILURE"
      ON_SUCCESS = "ON_SUCCESS"


- class SettingStatus(str, metaclass=ExtendedEnumMeta):
+ class SettingStatus(ExtendedEnum):
      """Enum to represent the status of retention settings."""

      PENDING = "PENDING"
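For context on the ExtendedEnumMeta → ExtendedEnum change that recurs throughout this release: enum-like classes now subclass ExtendedEnum directly instead of combining str with the metaclass, while .values()-based membership checks (as used by the validate() methods in the hunks below) keep working. A minimal sketch of the new pattern; the class and member names are illustrative only:

    from ads.common.extended_enum import ExtendedEnum


    # Was: class ExampleStatus(str, metaclass=ExtendedEnumMeta)
    class ExampleStatus(ExtendedEnum):
        PENDING = "PENDING"
        DONE = "DONE"  # illustrative member


    assert "PENDING" in ExampleStatus.values()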
@@ -116,17 +117,17 @@ class ModelBackupSetting:
      """

      def __init__(
-         self,
-         is_backup_enabled: Optional[bool] = None,
-         backup_region: Optional[str] = None,
-         customer_notification_type: Optional[CustomerNotificationType] = None,
+         self,
+         is_backup_enabled: Optional[bool] = None,
+         backup_region: Optional[str] = None,
+         customer_notification_type: Optional[CustomerNotificationType] = None,
      ):
          self.is_backup_enabled = (
              is_backup_enabled if is_backup_enabled is not None else False
          )
          self.backup_region = backup_region
          self.customer_notification_type = (
-             customer_notification_type or CustomerNotificationType.NONE
+             customer_notification_type or CustomerNotificationType.NONE
          )

      def to_dict(self) -> Dict:
@@ -143,10 +144,7 @@ class ModelBackupSetting:
          return cls(
              is_backup_enabled=data.get("is_backup_enabled"),
              backup_region=data.get("backup_region"),
-             customer_notification_type=CustomerNotificationType(
-                 data.get("customer_notification_type")
-             )
-             or None,
+             customer_notification_type=data.get("customer_notification_type") or None,
          )

      def to_json(self) -> str:
@@ -166,12 +164,15 @@ class ModelBackupSetting:

      def validate(self) -> bool:
          """Validates the backup settings details. Returns True if valid, False otherwise."""
-         return all([
-             isinstance(self.is_backup_enabled, bool),
-             not self.backup_region or isinstance(self.backup_region, str),
-             isinstance(self.customer_notification_type, str) and self.customer_notification_type in
-             CustomerNotificationType.values()
-         ])
+         return all(
+             [
+                 isinstance(self.is_backup_enabled, bool),
+                 not self.backup_region or isinstance(self.backup_region, str),
+                 isinstance(self.customer_notification_type, str)
+                 and self.customer_notification_type
+                 in CustomerNotificationType.values(),
+             ]
+         )

      def __repr__(self):
          return self.to_yaml()
@@ -198,15 +199,15 @@ class ModelRetentionSetting:
      """

      def __init__(
-         self,
-         archive_after_days: Optional[int] = None,
-         delete_after_days: Optional[int] = None,
-         customer_notification_type: Optional[CustomerNotificationType] = None,
+         self,
+         archive_after_days: Optional[int] = None,
+         delete_after_days: Optional[int] = None,
+         customer_notification_type: Optional[CustomerNotificationType] = None,
      ):
          self.archive_after_days = archive_after_days
          self.delete_after_days = delete_after_days
          self.customer_notification_type = (
-             customer_notification_type or CustomerNotificationType.NONE
+             customer_notification_type or CustomerNotificationType.NONE
          )

      def to_dict(self) -> Dict:
@@ -223,10 +224,7 @@ class ModelRetentionSetting:
          return cls(
              archive_after_days=data.get("archive_after_days"),
              delete_after_days=data.get("delete_after_days"),
-             customer_notification_type=CustomerNotificationType(
-                 data.get("customer_notification_type")
-             )
-             or None,
+             customer_notification_type=data.get("customer_notification_type") or None,
          )

      def to_json(self) -> str:
@@ -245,13 +243,23 @@ class ModelRetentionSetting:

      def validate(self) -> bool:
          """Validates the retention settings details. Returns True if valid, False otherwise."""
-         return all([
-             self.archive_after_days is None or (
-                 isinstance(self.archive_after_days, int) and self.archive_after_days >= 0),
-             self.delete_after_days is None or (isinstance(self.delete_after_days, int) and self.delete_after_days >= 0),
-             isinstance(self.customer_notification_type, str) and self.customer_notification_type in
-             CustomerNotificationType.values()
-         ])
+         return all(
+             [
+                 self.archive_after_days is None
+                 or (
+                     isinstance(self.archive_after_days, int)
+                     and self.archive_after_days >= 0
+                 ),
+                 self.delete_after_days is None
+                 or (
+                     isinstance(self.delete_after_days, int)
+                     and self.delete_after_days >= 0
+                 ),
+                 isinstance(self.customer_notification_type, str)
+                 and self.customer_notification_type
+                 in CustomerNotificationType.values(),
+             ]
+         )

      def __repr__(self):
          return self.to_yaml()
@@ -278,13 +286,13 @@ class ModelRetentionOperationDetails:
      """

      def __init__(
-         self,
-         archive_state: Optional[SettingStatus] = None,
-         archive_state_details: Optional[str] = None,
-         delete_state: Optional[SettingStatus] = None,
-         delete_state_details: Optional[str] = None,
-         time_archival_scheduled: Optional[int] = None,
-         time_deletion_scheduled: Optional[int] = None,
+         self,
+         archive_state: Optional[SettingStatus] = None,
+         archive_state_details: Optional[str] = None,
+         delete_state: Optional[SettingStatus] = None,
+         delete_state_details: Optional[str] = None,
+         time_archival_scheduled: Optional[int] = None,
+         time_deletion_scheduled: Optional[int] = None,
      ):
          self.archive_state = archive_state
          self.archive_state_details = archive_state_details
@@ -308,9 +316,9 @@ class ModelRetentionOperationDetails:
      def from_dict(cls, data: Dict) -> "ModelRetentionOperationDetails":
          """Constructs retention operation details from a dictionary."""
          return cls(
-             archive_state=SettingStatus(data.get("archive_state")) or None,
+             archive_state=data.get("archive_state") or None,
              archive_state_details=data.get("archive_state_details"),
-             delete_state=SettingStatus(data.get("delete_state")) or None,
+             delete_state=data.get("delete_state") or None,
              delete_state_details=data.get("delete_state_details"),
              time_archival_scheduled=data.get("time_archival_scheduled"),
              time_deletion_scheduled=data.get("time_deletion_scheduled"),
@@ -334,8 +342,10 @@ class ModelRetentionOperationDetails:
          """Validates the retention operation details."""
          return all(
              [
-                 self.archive_state is None or self.archive_state in SettingStatus.values(),
-                 self.delete_state is None or self.delete_state in SettingStatus.values(),
+                 self.archive_state is None
+                 or self.archive_state in SettingStatus.values(),
+                 self.delete_state is None
+                 or self.delete_state in SettingStatus.values(),
                  self.time_archival_scheduled is None
                  or isinstance(self.time_archival_scheduled, int),
                  self.time_deletion_scheduled is None
@@ -368,10 +378,10 @@ class ModelBackupOperationDetails:
      """

      def __init__(
-         self,
-         backup_state: Optional[SettingStatus] = None,
-         backup_state_details: Optional[str] = None,
-         time_last_backup: Optional[int] = None,
+         self,
+         backup_state: Optional[SettingStatus] = None,
+         backup_state_details: Optional[str] = None,
+         time_last_backup: Optional[int] = None,
      ):
          self.backup_state = backup_state
          self.backup_state_details = backup_state_details
@@ -389,7 +399,7 @@ class ModelBackupOperationDetails:
      def from_dict(cls, data: Dict) -> "ModelBackupOperationDetails":
          """Constructs backup operation details from a dictionary."""
          return cls(
-             backup_state=SettingStatus(data.get("backup_state")) or None,
+             backup_state=data.get("backup_state") or None,
              backup_state_details=data.get("backup_state_details"),
              time_last_backup=data.get("time_last_backup"),
          )
@@ -411,8 +421,14 @@ class ModelBackupOperationDetails:
      def validate(self) -> bool:
          """Validates the backup operation details."""
          return not (
-             (self.backup_state is not None and self.backup_state not in SettingStatus.values()) or
-             (self.time_last_backup is not None and not isinstance(self.time_last_backup, int))
+             (
+                 self.backup_state is not None
+                 and self.backup_state not in SettingStatus.values()
+             )
+             or (
+                 self.time_last_backup is not None
+                 and not isinstance(self.time_last_backup, int)
+             )
          )

      def __repr__(self):
@@ -1042,7 +1058,7 @@ class DataScienceModel(Builder):
          elif json_string:
              json_data = json.loads(json_string)
          elif json_uri:
-             with open(json_uri, "r") as json_file:
+             with open(json_uri) as json_file:
                  json_data = json.load(json_file)
          else:
              raise ValueError("Must provide either a valid json string or URI location.")
@@ -1077,7 +1093,7 @@ class DataScienceModel(Builder):
          return self.get_spec(self.CONST_RETENTION_SETTING)

      def with_retention_setting(
-         self, retention_setting: Union[Dict, ModelRetentionSetting]
+         self, retention_setting: Union[Dict, ModelRetentionSetting]
      ) -> "DataScienceModel":
          """
          Sets the retention setting details for the model.
@@ -1106,7 +1122,7 @@ class DataScienceModel(Builder):
          return self.get_spec(self.CONST_BACKUP_SETTING)

      def with_backup_setting(
-         self, backup_setting: Union[Dict, ModelBackupSetting]
+         self, backup_setting: Union[Dict, ModelBackupSetting]
      ) -> "DataScienceModel":
          """
          Sets the model's backup setting details.
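A usage sketch of the builder-style setters touched above, assuming the retention/backup classes are imported from ads.model.datascience_model as defined in this file; the display name and numeric values are illustrative only:

    from ads.model.datascience_model import (
        CustomerNotificationType,
        DataScienceModel,
        ModelBackupSetting,
        ModelRetentionSetting,
    )

    model = (
        DataScienceModel()
        .with_display_name("example-model")  # illustrative
        .with_backup_setting(
            ModelBackupSetting(
                is_backup_enabled=True,
                backup_region="us-ashburn-1",
                customer_notification_type=CustomerNotificationType.ON_FAILURE,
            )
        )
        .with_retention_setting(
            ModelRetentionSetting(
                archive_after_days=30,
                delete_after_days=90,
                customer_notification_type=CustomerNotificationType.ALL,
            )
        )
    )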
@@ -1368,8 +1384,8 @@ class DataScienceModel(Builder):
          shutil.rmtree(self.local_copy_dir, ignore_errors=True)

      def restore_model(
-         self,
-         restore_model_for_hours_specified: Optional[int] = None,
+         self,
+         restore_model_for_hours_specified: Optional[int] = None,
      ) -> None:
          """
          Restore archived model artifact.
@@ -1398,8 +1414,12 @@ class DataScienceModel(Builder):

          # Optional: Validate restore_model_for_hours_specified
          if restore_model_for_hours_specified is not None and (
-             not isinstance(restore_model_for_hours_specified, int) or restore_model_for_hours_specified <= 0):
-             raise ValueError("restore_model_for_hours_specified must be a positive integer.")
+             not isinstance(restore_model_for_hours_specified, int)
+             or restore_model_for_hours_specified <= 0
+         ):
+             raise ValueError(
+                 "restore_model_for_hours_specified must be a positive integer."
+             )

          self.dsc_model.restore_archived_model_artifact(
              restore_model_for_hours_specified=restore_model_for_hours_specified,
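A minimal sketch of the restore flow whose input validation was tightened above; the model OCID is a placeholder and from_id is assumed to load an existing archived model:

    from ads.model.datascience_model import DataScienceModel

    model = DataScienceModel.from_id("ocid1.datasciencemodel.oc1..<unique_id>")

    # Must be a positive integer; anything else now raises ValueError.
    model.restore_model(restore_model_for_hours_specified=24)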
@@ -1721,7 +1741,7 @@ class DataScienceModel(Builder):
              self.CONST_BACKUP_SETTING: ModelBackupSetting.to_dict,
              self.CONST_RETENTION_SETTING: ModelRetentionSetting.to_dict,
              self.CONST_BACKUP_OPERATION_DETAILS: ModelBackupOperationDetails.to_dict,
-             self.CONST_RETENTION_OPERATION_DETAILS: ModelRetentionOperationDetails.to_dict
+             self.CONST_RETENTION_OPERATION_DETAILS: ModelRetentionOperationDetails.to_dict,
          }

          # Update the main properties
@@ -1758,7 +1778,7 @@ class DataScienceModel(Builder):
          artifact_info = self.dsc_model.get_artifact_info()
          _, file_name_info = cgi.parse_header(artifact_info["Content-Disposition"])

-         if self.dsc_model.is_model_by_reference():
+         if self.dsc_model._is_model_by_reference():
              _, file_extension = os.path.splitext(file_name_info["filename"])
              if file_extension.lower() == ".json":
                  bucket_uri, _ = self._download_file_description_artifact()
ads/model/generic_model.py
@@ -1,6 +1,6 @@
  #!/usr/bin/env python

- # Copyright (c) 2022, 2024 Oracle and/or its affiliates.
+ # Copyright (c) 2022, 2025 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

  import inspect
@@ -63,7 +63,7 @@ from ads.model.model_introspect import (
      ModelIntrospect,
  )
  from ads.model.model_metadata import (
-     ExtendedEnumMeta,
+     ExtendedEnum,
      Framework,
      MetadataCustomCategory,
      ModelCustomMetadata,
@@ -146,7 +146,7 @@ class ModelDeploymentRuntimeType:
      CONTAINER = "container"


- class DataScienceModelType(str, metaclass=ExtendedEnumMeta):
+ class DataScienceModelType(ExtendedEnum):
      MODEL_DEPLOYMENT = "datasciencemodeldeployment"
      MODEL = "datasciencemodel"

ads/model/model_metadata.py
@@ -1,6 +1,6 @@
  #!/usr/bin/env python

- # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
+ # Copyright (c) 2021, 2025 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

  import json
@@ -21,7 +21,7 @@ from oci.util import to_dict
  from ads.common import logger
  from ads.common.error import ChangesNotCommitted
- from ads.common.extended_enum import ExtendedEnumMeta
+ from ads.common.extended_enum import ExtendedEnum
  from ads.common.object_storage_details import ObjectStorageDetails
  from ads.common.serializer import DataClassSerializable
  from ads.dataset import factory
@@ -81,19 +81,19 @@ class MetadataDescriptionTooLong(ValueError):
  )


- class MetadataCustomPrintColumns(str, metaclass=ExtendedEnumMeta):
+ class MetadataCustomPrintColumns(ExtendedEnum):
      KEY = "Key"
      VALUE = "Value"
      DESCRIPTION = "Description"
      CATEGORY = "Category"


- class MetadataTaxonomyPrintColumns(str, metaclass=ExtendedEnumMeta):
+ class MetadataTaxonomyPrintColumns(ExtendedEnum):
      KEY = "Key"
      VALUE = "Value"


- class MetadataTaxonomyKeys(str, metaclass=ExtendedEnumMeta):
+ class MetadataTaxonomyKeys(ExtendedEnum):
      USE_CASE_TYPE = "UseCaseType"
      FRAMEWORK = "Framework"
      FRAMEWORK_VERSION = "FrameworkVersion"
@@ -102,7 +102,7 @@ class MetadataTaxonomyKeys(str, metaclass=ExtendedEnumMeta):
      ARTIFACT_TEST_RESULT = "ArtifactTestResults"


- class MetadataCustomKeys(str, metaclass=ExtendedEnumMeta):
+ class MetadataCustomKeys(ExtendedEnum):
      SLUG_NAME = "SlugName"
      CONDA_ENVIRONMENT = "CondaEnvironment"
      CONDA_ENVIRONMENT_PATH = "CondaEnvironmentPath"
@@ -121,7 +121,7 @@ class MetadataCustomKeys(str, metaclass=ExtendedEnumMeta):
      MODEL_FILE_NAME = "ModelFileName"


- class MetadataCustomCategory(str, metaclass=ExtendedEnumMeta):
+ class MetadataCustomCategory(ExtendedEnum):
      PERFORMANCE = "Performance"
      TRAINING_PROFILE = "Training Profile"
      TRAINING_AND_VALIDATION_DATASETS = "Training and Validation Datasets"
@@ -129,7 +129,7 @@ class MetadataCustomCategory(str, metaclass=ExtendedEnumMeta):
      OTHER = "Other"


- class UseCaseType(str, metaclass=ExtendedEnumMeta):
+ class UseCaseType(ExtendedEnum):
      BINARY_CLASSIFICATION = "binary_classification"
      REGRESSION = "regression"
      MULTINOMIAL_CLASSIFICATION = "multinomial_classification"
@@ -146,7 +146,7 @@ class UseCaseType(str, metaclass=ExtendedEnumMeta):
      OTHER = "other"


- class Framework(str, metaclass=ExtendedEnumMeta):
+ class Framework(ExtendedEnum):
      SCIKIT_LEARN = "scikit-learn"
      XGBOOST = "xgboost"
      TENSORFLOW = "tensorflow"
@@ -1509,7 +1509,10 @@ class ModelTaxonomyMetadata(ModelMetadata):
          metadata = cls()
          for oci_item in metadata_list:
              item = ModelTaxonomyMetadataItem._from_oci_metadata(oci_item)
-             metadata[item.key].update(value=item.value)
+             if item.key in metadata.keys:
+                 metadata[item.key].update(value=item.value)
+             else:
+                 metadata._items.add(item)
          return metadata

      def to_dataframe(self) -> pd.DataFrame:
@@ -1562,7 +1565,10 @@ class ModelTaxonomyMetadata(ModelMetadata):
          metadata = cls()
          for item in data["data"]:
              item = ModelTaxonomyMetadataItem.from_dict(item)
-             metadata[item.key].update(value=item.value)
+             if item.key in metadata.keys:
+                 metadata[item.key].update(value=item.value)
+             else:
+                 metadata._items.add(item)
          return metadata

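The two hunks above change how taxonomy metadata is rebuilt: keys outside the default taxonomy are now appended as new items instead of failing the metadata[item.key] lookup. A small sketch of the assumed effect, taking the item dict layout to be {"key": ..., "value": ...} (an assumption) and using a hypothetical non-default key:

    from ads.model.model_metadata import ModelTaxonomyMetadata

    payload = {
        "data": [
            {"key": "UseCaseType", "value": "binary_classification"},
            {"key": "CustomTaxonomyKey", "value": "anything"},  # hypothetical non-default key
        ]
    }

    # Previously the second entry would fail the metadata[item.key] lookup;
    # with this change it is added as a new item (assumed behavior per the diff).
    metadata = ModelTaxonomyMetadata.from_dict(payload)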