oracle-ads 2.12.11__py3-none-any.whl → 2.13.1rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. ads/aqua/app.py +23 -10
  2. ads/aqua/common/enums.py +19 -14
  3. ads/aqua/common/errors.py +3 -4
  4. ads/aqua/common/utils.py +2 -2
  5. ads/aqua/constants.py +1 -0
  6. ads/aqua/evaluation/constants.py +7 -7
  7. ads/aqua/evaluation/errors.py +3 -4
  8. ads/aqua/extension/model_handler.py +23 -0
  9. ads/aqua/extension/models/ws_models.py +5 -6
  10. ads/aqua/finetuning/constants.py +3 -3
  11. ads/aqua/model/constants.py +7 -7
  12. ads/aqua/model/enums.py +4 -5
  13. ads/aqua/model/model.py +22 -0
  14. ads/aqua/modeldeployment/entities.py +3 -1
  15. ads/common/auth.py +33 -20
  16. ads/common/extended_enum.py +52 -44
  17. ads/llm/__init__.py +11 -8
  18. ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
  19. ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
  20. ads/model/artifact_downloader.py +3 -4
  21. ads/model/datascience_model.py +84 -64
  22. ads/model/generic_model.py +3 -3
  23. ads/model/model_metadata.py +17 -11
  24. ads/model/service/oci_datascience_model.py +12 -14
  25. ads/opctl/anomaly_detection.py +11 -0
  26. ads/opctl/backend/marketplace/helm_helper.py +13 -14
  27. ads/opctl/cli.py +4 -5
  28. ads/opctl/cmds.py +28 -32
  29. ads/opctl/config/merger.py +8 -11
  30. ads/opctl/config/resolver.py +25 -30
  31. ads/opctl/forecast.py +11 -0
  32. ads/opctl/operator/cli.py +9 -9
  33. ads/opctl/operator/common/backend_factory.py +56 -60
  34. ads/opctl/operator/common/const.py +5 -5
  35. ads/opctl/operator/lowcode/anomaly/const.py +8 -9
  36. ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +43 -48
  37. ads/opctl/operator/lowcode/forecast/__main__.py +5 -5
  38. ads/opctl/operator/lowcode/forecast/const.py +6 -6
  39. ads/opctl/operator/lowcode/forecast/model/arima.py +6 -3
  40. ads/opctl/operator/lowcode/forecast/model/automlx.py +53 -31
  41. ads/opctl/operator/lowcode/forecast/model/base_model.py +57 -30
  42. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +60 -2
  43. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +5 -2
  44. ads/opctl/operator/lowcode/forecast/model/prophet.py +28 -15
  45. ads/opctl/operator/lowcode/forecast/whatifserve/score.py +19 -11
  46. ads/opctl/operator/lowcode/pii/constant.py +6 -7
  47. ads/opctl/operator/lowcode/recommender/constant.py +12 -7
  48. ads/opctl/operator/runtime/marketplace_runtime.py +4 -10
  49. ads/opctl/operator/runtime/runtime.py +4 -6
  50. ads/pipeline/ads_pipeline_run.py +13 -25
  51. ads/pipeline/visualizer/graph_renderer.py +3 -4
  52. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/METADATA +6 -6
  53. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/RECORD +56 -52
  54. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/LICENSE.txt +0 -0
  55. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/WHEEL +0 -0
  56. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/entry_points.txt +0 -0
ads/common/auth.py CHANGED
@@ -1,23 +1,25 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8; -*-
3
2
 
4
- # Copyright (c) 2021, 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2021, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
  import copy
8
- from datetime import datetime
9
7
  import os
10
- from dataclasses import dataclass
11
8
  import time
9
+ from dataclasses import dataclass
10
+ from datetime import datetime
12
11
  from typing import Any, Callable, Dict, Optional, Union
13
12
 
14
- import ads.telemetry
15
13
  import oci
14
+ from oci.config import (
15
+ DEFAULT_LOCATION, # "~/.oci/config"
16
+ DEFAULT_PROFILE, # "DEFAULT"
17
+ )
18
+
19
+ import ads.telemetry
16
20
  from ads.common import logger
17
21
  from ads.common.decorator.deprecate import deprecated
18
- from ads.common.extended_enum import ExtendedEnumMeta
19
- from oci.config import DEFAULT_LOCATION # "~/.oci/config"
20
- from oci.config import DEFAULT_PROFILE # "DEFAULT"
22
+ from ads.common.extended_enum import ExtendedEnum
21
23
 
22
24
  SECURITY_TOKEN_LEFT_TIME = 600
23
25
 
@@ -26,7 +28,7 @@ class SecurityTokenError(Exception): # pragma: no cover
26
28
  pass
27
29
 
28
30
 
29
- class AuthType(str, metaclass=ExtendedEnumMeta):
31
+ class AuthType(ExtendedEnum):
30
32
  API_KEY = "api_key"
31
33
  RESOURCE_PRINCIPAL = "resource_principal"
32
34
  INSTANCE_PRINCIPAL = "instance_principal"
@@ -73,7 +75,11 @@ class AuthState(metaclass=SingletonMeta):
73
75
  self.oci_key_profile = self.oci_key_profile or os.environ.get(
74
76
  "OCI_CONFIG_PROFILE", DEFAULT_PROFILE
75
77
  )
76
- self.oci_config = self.oci_config or {"region": os.environ["OCI_RESOURCE_REGION"]} if os.environ.get("OCI_RESOURCE_REGION") else {}
78
+ self.oci_config = (
79
+ self.oci_config or {"region": os.environ["OCI_RESOURCE_REGION"]}
80
+ if os.environ.get("OCI_RESOURCE_REGION")
81
+ else {}
82
+ )
77
83
  self.oci_signer_kwargs = self.oci_signer_kwargs or {}
78
84
  self.oci_client_kwargs = self.oci_client_kwargs or {}
79
85
 
@@ -82,7 +88,9 @@ def set_auth(
82
88
  auth: Optional[str] = AuthType.API_KEY,
83
89
  oci_config_location: Optional[str] = DEFAULT_LOCATION,
84
90
  profile: Optional[str] = DEFAULT_PROFILE,
85
- config: Optional[Dict] = {"region": os.environ["OCI_RESOURCE_REGION"]} if os.environ.get("OCI_RESOURCE_REGION") else {},
91
+ config: Optional[Dict] = {"region": os.environ["OCI_RESOURCE_REGION"]}
92
+ if os.environ.get("OCI_RESOURCE_REGION")
93
+ else {},
86
94
  signer: Optional[Any] = None,
87
95
  signer_callable: Optional[Callable] = None,
88
96
  signer_kwargs: Optional[Dict] = {},
@@ -202,8 +210,8 @@ def set_auth(
202
210
  oci_config_location != DEFAULT_LOCATION or profile != DEFAULT_PROFILE
203
211
  ):
204
212
  raise ValueError(
205
- f"'config' and 'oci_config_location', 'profile' pair are mutually exclusive."
206
- f"Please specify 'config' OR 'oci_config_location', 'profile' pair."
213
+ "'config' and 'oci_config_location', 'profile' pair are mutually exclusive."
214
+ "Please specify 'config' OR 'oci_config_location', 'profile' pair."
207
215
  )
208
216
 
209
217
  auth_state.oci_config = config
@@ -621,7 +629,7 @@ class APIKey(AuthSignerGenerator):
621
629
  )
622
630
 
623
631
  oci.config.validate_config(configuration)
624
- logger.debug(f"Using 'api_key' authentication.")
632
+ logger.debug("Using 'api_key' authentication.")
625
633
  return {
626
634
  "config": configuration,
627
635
  "signer": oci.signer.Signer(
@@ -684,14 +692,19 @@ class ResourcePrincipal(AuthSignerGenerator):
684
692
  "signer": oci.auth.signers.get_resource_principals_signer(),
685
693
  "client_kwargs": self.client_kwargs,
686
694
  }
687
- logger.debug(f"Using 'resource_principal' authentication.")
695
+ logger.debug("Using 'resource_principal' authentication.")
688
696
  return signer_dict
689
697
 
690
698
  @staticmethod
691
699
  def supported():
692
700
  return any(
693
701
  os.environ.get(var)
694
- for var in ['JOB_RUN_OCID', 'NB_SESSION_OCID', 'DATAFLOW_RUN_ID', 'PIPELINE_RUN_OCID']
702
+ for var in [
703
+ "JOB_RUN_OCID",
704
+ "NB_SESSION_OCID",
705
+ "DATAFLOW_RUN_ID",
706
+ "PIPELINE_RUN_OCID",
707
+ ]
695
708
  )
696
709
 
697
710
 
@@ -747,7 +760,7 @@ class InstancePrincipal(AuthSignerGenerator):
747
760
  ),
748
761
  "client_kwargs": self.client_kwargs,
749
762
  }
750
- logger.debug(f"Using 'instance_principal' authentication.")
763
+ logger.debug("Using 'instance_principal' authentication.")
751
764
  return signer_dict
752
765
 
753
766
 
@@ -814,7 +827,7 @@ class SecurityToken(AuthSignerGenerator):
814
827
  oci.config.from_file(self.oci_config_location, self.oci_key_profile)
815
828
  )
816
829
 
817
- logger.debug(f"Using 'security_token' authentication.")
830
+ logger.debug("Using 'security_token' authentication.")
818
831
 
819
832
  for parameter in self.SECURITY_TOKEN_REQUIRED:
820
833
  if parameter not in configuration:
@@ -903,7 +916,7 @@ class SecurityToken(AuthSignerGenerator):
903
916
  raise ValueError("Invalid `security_token_file`. Specify a valid path.")
904
917
  try:
905
918
  token = None
906
- with open(expanded_path, "r") as f:
919
+ with open(expanded_path) as f:
907
920
  token = f.read()
908
921
  return token
909
922
  except:
@@ -1023,7 +1036,7 @@ class OCIAuthContext:
1023
1036
  logger.debug(f"OCI profile set to {self.profile}")
1024
1037
  else:
1025
1038
  ads.set_auth(auth=AuthType.RESOURCE_PRINCIPAL)
1026
- logger.debug(f"OCI auth set to resource principal")
1039
+ logger.debug("OCI auth set to resource principal")
1027
1040
  return self
1028
1041
 
1029
1042
  def __exit__(self, exc_type, exc_val, exc_tb):
ads/common/extended_enum.py CHANGED
@@ -1,73 +1,81 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8; -*-
3
2
 
4
- # Copyright (c) 2022, 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2022, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
 
8
7
  from abc import ABCMeta
9
- from enum import Enum
10
8
 
11
9
 
12
10
def _is_enum_member(name: str, value: object) -> bool:
    """Return True if the class attribute ``name``/``value`` is an enum member.

    Private/special names (leading ``_``) and methods are not members.
    ``classmethod`` and ``property`` objects (and ``staticmethod`` objects
    before Python 3.10) are not callable, so ``callable()`` alone would let
    them leak into ``values()``/``keys()``; they are excluded explicitly.
    """
    if name.startswith("_"):
        return False
    return not (
        callable(value) or isinstance(value, (classmethod, staticmethod, property))
    )


class ExtendedEnumMeta(ABCMeta):
    """
    A helper metaclass to extend functionality of a generic "Enum-like" class.

    Members are the public attributes declared directly on the class body
    (only ``cls.__dict__`` is inspected, so attributes inherited from base
    classes are not members). Names starting with ``_``, methods, and
    ``classmethod``/``staticmethod``/``property`` descriptors are skipped.

    Methods
    -------
    __contains__(cls, item) -> bool:
        Checks if `item` is among the attribute values of the class.
        Case-insensitive if `item` is a string.
    __iter__(cls):
        Iterates over the attribute values of the class.
    values(cls) -> tuple:
        Returns the tuple of class attribute values.
    keys(cls) -> tuple:
        Returns the tuple of class attribute names.

    Examples
    --------
    >>> class TestEnum(metaclass=ExtendedEnumMeta):
    ...     KEY1 = "value1"
    ...     KEY2 = "value2"
    >>> TestEnum.values()
    ('value1', 'value2')
    >>> "VALUE1" in TestEnum
    True
    """

    def __contains__(cls, item: object) -> bool:
        """
        Checks if `item` is a member of the class's values.

        - If `item` is a string, does a case-insensitive match against any
          string values stored in the class.
        - Otherwise, does a direct membership test.
        """
        attr_values = cls.values()

        if isinstance(item, str):
            # Lower-case the needle once instead of on every comparison.
            needle = item.lower()
            return any(
                isinstance(val, str) and val.lower() == needle for val in attr_values
            )
        # For non-string items (e.g., int), do a direct membership check.
        return item in attr_values

    def __iter__(cls):
        """Make the class iterable by returning an iterator over its values."""
        return iter(cls.values())

    def values(cls) -> tuple:
        """
        Gets the tuple of class attribute values, excluding private or special
        attributes, methods, and method/property descriptors.
        """
        return tuple(
            value
            for key, value in cls.__dict__.items()
            if _is_enum_member(key, value)
        )

    def keys(cls) -> tuple:
        """
        Gets the tuple of class attribute names, excluding private or special
        attributes, methods, and method/property descriptors.
        """
        return tuple(
            key for key, value in cls.__dict__.items() if _is_enum_member(key, value)
        )
54
71
 
55
72
 
56
class ExtendedEnum(metaclass=ExtendedEnumMeta):
    """Base class for "Enum-like" groups of constants.

    Subclass it and declare the members as plain class attributes. Unlike
    ``enum.Enum``, the attributes stay ordinary values (``TestEnum.KEY1`` is
    the string ``"v1"`` itself). The metaclass adds ``values()``, ``keys()``,
    iteration over the values, and ``in`` checks that are case-insensitive
    for strings.

    Examples
    --------
    >>> class TestEnum(ExtendedEnum):
    ...     KEY1 = "v1"
    ...     KEY2 = "v2"
    >>> TestEnum.KEY1
    'v1'
    >>> TestEnum.values()
    ('v1', 'v2')
    >>> "V2" in TestEnum
    True
    """
ads/llm/__init__.py CHANGED
@@ -1,21 +1,24 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
  try:
8
7
  import langchain
9
- from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
10
- OCIModelDeploymentVLLM,
11
- OCIModelDeploymentTGI,
12
- )
8
+
9
+ from ads.llm.chat_template import ChatTemplates
13
10
  from ads.llm.langchain.plugins.chat_models.oci_data_science import (
14
11
  ChatOCIModelDeployment,
15
- ChatOCIModelDeploymentVLLM,
16
12
  ChatOCIModelDeploymentTGI,
13
+ ChatOCIModelDeploymentVLLM,
14
+ )
15
+ from ads.llm.langchain.plugins.embeddings.oci_data_science_model_deployment_endpoint import (
16
+ OCIDataScienceEmbedding,
17
+ )
18
+ from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
19
+ OCIModelDeploymentTGI,
20
+ OCIModelDeploymentVLLM,
17
21
  )
18
- from ads.llm.chat_template import ChatTemplates
19
22
  except ImportError as ex:
20
23
  if ex.name == "langchain":
21
24
  raise ImportError(
ads/llm/langchain/plugins/embeddings/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) 2025 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py ADDED
@@ -0,0 +1,184 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) 2025 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+
6
+ from typing import Any, Callable, Dict, List, Mapping, Optional
7
+
8
+ import requests
9
+ from langchain_core.embeddings import Embeddings
10
+ from langchain_core.language_models.llms import create_base_retry_decorator
11
+ from pydantic import BaseModel, Field
12
+
13
# Request headers used when ``endpoint_kwargs`` does not provide "headers".
DEFAULT_HEADER = {"Content-Type": "application/json"}
16
+
17
+
18
class TokenExpiredError(Exception):
    """Signals an HTTP 401 response after the signer was refreshed.

    It is listed as a retryable error type by ``_create_retry_decorator``,
    so raising it makes the request be attempted again.
    """
20
+
21
+
22
def _create_retry_decorator(llm) -> Callable[[Any], Any]:
    """Build the retry decorator used around endpoint calls.

    Connection timeouts and ``TokenExpiredError`` are treated as retryable.
    The retry budget comes from ``llm.max_retries``.
    """
    retryable = [requests.exceptions.ConnectTimeout, TokenExpiredError]
    return create_base_retry_decorator(
        error_types=retryable,
        max_retries=llm.max_retries,
    )
29
+
30
+
31
class OCIDataScienceEmbedding(BaseModel, Embeddings):
    """Embedding model deployed on OCI Data Science Model Deployment.

    Texts are sent as a ``POST`` to ``endpoint`` with the JSON body
    ``{"input": [...texts...], **model_kwargs}``. The response is parsed as
    ``{"data": [{"embedding": [...]}, ...]}`` with one entry per input text
    (an OpenAI-style embeddings payload; entries are assumed to be returned
    in input order — TODO confirm against the serving container).

    Example:

        .. code-block:: python

            from ads.llm import OCIDataScienceEmbedding

            embeddings = OCIDataScienceEmbedding(
                endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<md_ocid>/predict",
            )
    """  # noqa: E501

    auth: dict = Field(default_factory=dict, exclude=True)
    """ADS auth dictionary for OCI authentication:
    https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html.
    This can be generated by calling `ads.common.auth.api_keys()`
    or `ads.common.auth.resource_principal()`. If this is not
    provided then the `ads.common.auth.default_signer()` will be used."""

    endpoint: str = ""
    """The uri of the endpoint from the deployed Model Deployment model."""

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model (merged into the request body)."""

    endpoint_kwargs: Optional[Dict] = None
    """Optional attributes (except for headers) passed to the request.post
    function.
    """

    max_retries: int = 1
    """The maximum number of retries to make when generating."""

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"endpoint": self.endpoint, "model_kwargs": self.model_kwargs or {}}

    def _get_signer(self) -> Any:
        """Return the signer used to authenticate requests.

        When no ``auth`` dictionary was provided, the ADS default signer is
        created once and stored on the instance, as documented for ``auth``.
        """
        if not self.auth:
            from ads.common.auth import default_signer

            self.auth = default_signer()
        return self.auth.get("signer")

    def _refresh_signer(self) -> bool:
        """Refresh the signer's security token when the signer supports it.

        Returns:
            True if a refresh was performed (so retrying may succeed),
            False otherwise.
        """
        signer = self.auth.get("signer") if self.auth else None
        if signer is not None and hasattr(signer, "refresh_security_token"):
            signer.refresh_security_token()
            return True
        return False

    def _embed_with_retry(self, **kwargs) -> Any:
        """Use tenacity to retry the call."""
        retry_decorator = _create_retry_decorator(self)

        @retry_decorator
        def _completion_with_retry(**kwargs: Any) -> Any:
            try:
                response = requests.post(self.endpoint, **kwargs)
                response.raise_for_status()
                return response
            except requests.exceptions.HTTPError as http_err:
                # Use the response attached to the error so it is always bound.
                err_response = http_err.response
                if (
                    err_response is not None
                    and err_response.status_code == 401
                    and self._refresh_signer()
                ):
                    raise TokenExpiredError() from http_err
                message = err_response.text if err_response is not None else ""
                raise ValueError(
                    f"Server error: {str(http_err)}. Message: {message}"
                ) from http_err
            except requests.exceptions.ConnectTimeout:
                # Propagate unchanged: the retry decorator retries this type.
                raise
            except Exception as e:
                raise ValueError(f"Error occurs by inference endpoint: {str(e)}") from e

        return _completion_with_retry(**kwargs)

    def _embedding(self, texts: List[str]) -> List[List[float]]:
        """Call out to OCI Data Science Model Deployment Endpoint.

        Args:
            texts: A list of texts to embed.

        Returns:
            A list of list of floats representing the embeddings, one per text.

        Raises:
            ValueError: If the request fails or the response cannot be parsed.
        """
        _model_kwargs = self.model_kwargs or {}
        body = self._construct_request_body(texts, _model_kwargs)
        request_kwargs = self._construct_request_kwargs(body)
        response = self._embed_with_retry(**request_kwargs)
        return self._proceses_response(response)

    def _construct_request_kwargs(self, body: Any) -> dict:
        """Constructs the request kwargs as a dictionary.

        ``endpoint_kwargs`` is copied before ``headers`` is popped from it, so
        custom headers are not lost for subsequent requests.
        """
        from ads.model.common.utils import _is_json_serializable

        _endpoint_kwargs = dict(self.endpoint_kwargs or {})
        headers = _endpoint_kwargs.pop("headers", DEFAULT_HEADER)
        payload_key = "json" if _is_json_serializable(body) else "data"
        return dict(
            headers=headers,
            auth=self._get_signer(),
            **{payload_key: body},
            **_endpoint_kwargs,
        )

    def _construct_request_body(self, texts: List[str], params: dict) -> Any:
        """Constructs the request body.

        ``params`` (the model kwargs) are merged into the payload. The
        ``input`` key always carries ``texts`` and cannot be overridden.
        """
        return {**params, "input": texts}

    def _proceses_response(self, response: requests.Response) -> List[List[float]]:
        """Extracts one embedding per input text from requests.Response."""
        try:
            res_json = response.json()
            embeddings = [item["embedding"] for item in res_json["data"]]
        except Exception as e:
            raise ValueError(
                f"Error raised by inference API: {e}.\nResponse: {response.text}"
            ) from e
        return embeddings

    def embed_documents(
        self,
        texts: List[str],
        chunk_size: Optional[int] = None,
    ) -> List[List[float]]:
        """Compute doc embeddings using OCI Data Science Model Deployment Endpoint.

        Args:
            texts: The list of texts to embed.
            chunk_size: How many input texts are grouped together per request.
                If None, not positive, or larger than ``len(texts)``, all
                texts are sent in a single request.

        Returns:
            List of embeddings, one for each text.
        """
        if not texts:
            # Nothing to embed; also avoids a zero step in range() below.
            return []
        if chunk_size and 0 < chunk_size <= len(texts):
            _chunk_size = chunk_size
        else:
            _chunk_size = len(texts)
        results: List[List[float]] = []
        for start in range(0, len(texts), _chunk_size):
            results.extend(self._embedding(texts[start : start + _chunk_size]))
        return results

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using OCI Data Science Model Deployment Endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self._embedding([text])[0]
ads/model/artifact_downloader.py CHANGED
@@ -1,5 +1,4 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8; -*-
3
2
 
4
3
  # Copyright (c) 2022, 2024 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -12,9 +11,9 @@ from typing import Dict, Optional
12
11
  from zipfile import ZipFile
13
12
 
14
13
  from ads.common import utils
14
+ from ads.common.object_storage_details import ObjectStorageDetails
15
15
  from ads.common.utils import extract_region
16
16
  from ads.model.service.oci_datascience_model import OCIDataScienceModel
17
- from ads.common.object_storage_details import ObjectStorageDetails
18
17
 
19
18
 
20
19
  class ArtifactDownloader(ABC):
@@ -169,9 +168,9 @@ class LargeArtifactDownloader(ArtifactDownloader):
169
168
 
170
169
  def _download(self):
171
170
  """Downloads model artifacts."""
172
- self.progress.update(f"Importing model artifacts from catalog")
171
+ self.progress.update("Importing model artifacts from catalog")
173
172
 
174
- if self.dsc_model.is_model_by_reference() and self.model_file_description:
173
+ if self.dsc_model._is_model_by_reference() and self.model_file_description:
175
174
  self.download_from_model_file_description()
176
175
  self.progress.update()
177
176
  return