oracle-ads 2.12.11__py3-none-any.whl → 2.13.1rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/app.py +23 -10
- ads/aqua/common/enums.py +19 -14
- ads/aqua/common/errors.py +3 -4
- ads/aqua/common/utils.py +2 -2
- ads/aqua/constants.py +1 -0
- ads/aqua/evaluation/constants.py +7 -7
- ads/aqua/evaluation/errors.py +3 -4
- ads/aqua/extension/model_handler.py +23 -0
- ads/aqua/extension/models/ws_models.py +5 -6
- ads/aqua/finetuning/constants.py +3 -3
- ads/aqua/model/constants.py +7 -7
- ads/aqua/model/enums.py +4 -5
- ads/aqua/model/model.py +22 -0
- ads/aqua/modeldeployment/entities.py +3 -1
- ads/common/auth.py +33 -20
- ads/common/extended_enum.py +52 -44
- ads/llm/__init__.py +11 -8
- ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
- ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
- ads/model/artifact_downloader.py +3 -4
- ads/model/datascience_model.py +84 -64
- ads/model/generic_model.py +3 -3
- ads/model/model_metadata.py +17 -11
- ads/model/service/oci_datascience_model.py +12 -14
- ads/opctl/anomaly_detection.py +11 -0
- ads/opctl/backend/marketplace/helm_helper.py +13 -14
- ads/opctl/cli.py +4 -5
- ads/opctl/cmds.py +28 -32
- ads/opctl/config/merger.py +8 -11
- ads/opctl/config/resolver.py +25 -30
- ads/opctl/forecast.py +11 -0
- ads/opctl/operator/cli.py +9 -9
- ads/opctl/operator/common/backend_factory.py +56 -60
- ads/opctl/operator/common/const.py +5 -5
- ads/opctl/operator/lowcode/anomaly/const.py +8 -9
- ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +43 -48
- ads/opctl/operator/lowcode/forecast/__main__.py +5 -5
- ads/opctl/operator/lowcode/forecast/const.py +6 -6
- ads/opctl/operator/lowcode/forecast/model/arima.py +6 -3
- ads/opctl/operator/lowcode/forecast/model/automlx.py +53 -31
- ads/opctl/operator/lowcode/forecast/model/base_model.py +57 -30
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +60 -2
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +5 -2
- ads/opctl/operator/lowcode/forecast/model/prophet.py +28 -15
- ads/opctl/operator/lowcode/forecast/whatifserve/score.py +19 -11
- ads/opctl/operator/lowcode/pii/constant.py +6 -7
- ads/opctl/operator/lowcode/recommender/constant.py +12 -7
- ads/opctl/operator/runtime/marketplace_runtime.py +4 -10
- ads/opctl/operator/runtime/runtime.py +4 -6
- ads/pipeline/ads_pipeline_run.py +13 -25
- ads/pipeline/visualizer/graph_renderer.py +3 -4
- {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/METADATA +6 -6
- {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/RECORD +56 -52
- {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/WHEEL +0 -0
- {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1rc0.dist-info}/entry_points.txt +0 -0
ads/common/auth.py
CHANGED
@@ -1,23 +1,25 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8; -*-
|
3
2
|
|
4
|
-
# Copyright (c) 2021,
|
3
|
+
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
6
|
import copy
|
8
|
-
from datetime import datetime
|
9
7
|
import os
|
10
|
-
from dataclasses import dataclass
|
11
8
|
import time
|
9
|
+
from dataclasses import dataclass
|
10
|
+
from datetime import datetime
|
12
11
|
from typing import Any, Callable, Dict, Optional, Union
|
13
12
|
|
14
|
-
import ads.telemetry
|
15
13
|
import oci
|
14
|
+
from oci.config import (
|
15
|
+
DEFAULT_LOCATION, # "~/.oci/config"
|
16
|
+
DEFAULT_PROFILE, # "DEFAULT"
|
17
|
+
)
|
18
|
+
|
19
|
+
import ads.telemetry
|
16
20
|
from ads.common import logger
|
17
21
|
from ads.common.decorator.deprecate import deprecated
|
18
|
-
from ads.common.extended_enum import
|
19
|
-
from oci.config import DEFAULT_LOCATION # "~/.oci/config"
|
20
|
-
from oci.config import DEFAULT_PROFILE # "DEFAULT"
|
22
|
+
from ads.common.extended_enum import ExtendedEnum
|
21
23
|
|
22
24
|
SECURITY_TOKEN_LEFT_TIME = 600
|
23
25
|
|
@@ -26,7 +28,7 @@ class SecurityTokenError(Exception): # pragma: no cover
|
|
26
28
|
pass
|
27
29
|
|
28
30
|
|
29
|
-
class AuthType(
|
31
|
+
class AuthType(ExtendedEnum):
|
30
32
|
API_KEY = "api_key"
|
31
33
|
RESOURCE_PRINCIPAL = "resource_principal"
|
32
34
|
INSTANCE_PRINCIPAL = "instance_principal"
|
@@ -73,7 +75,11 @@ class AuthState(metaclass=SingletonMeta):
|
|
73
75
|
self.oci_key_profile = self.oci_key_profile or os.environ.get(
|
74
76
|
"OCI_CONFIG_PROFILE", DEFAULT_PROFILE
|
75
77
|
)
|
76
|
-
self.oci_config =
|
78
|
+
self.oci_config = (
|
79
|
+
self.oci_config or {"region": os.environ["OCI_RESOURCE_REGION"]}
|
80
|
+
if os.environ.get("OCI_RESOURCE_REGION")
|
81
|
+
else {}
|
82
|
+
)
|
77
83
|
self.oci_signer_kwargs = self.oci_signer_kwargs or {}
|
78
84
|
self.oci_client_kwargs = self.oci_client_kwargs or {}
|
79
85
|
|
@@ -82,7 +88,9 @@ def set_auth(
|
|
82
88
|
auth: Optional[str] = AuthType.API_KEY,
|
83
89
|
oci_config_location: Optional[str] = DEFAULT_LOCATION,
|
84
90
|
profile: Optional[str] = DEFAULT_PROFILE,
|
85
|
-
config: Optional[Dict] = {"region": os.environ["OCI_RESOURCE_REGION"]}
|
91
|
+
config: Optional[Dict] = {"region": os.environ["OCI_RESOURCE_REGION"]}
|
92
|
+
if os.environ.get("OCI_RESOURCE_REGION")
|
93
|
+
else {},
|
86
94
|
signer: Optional[Any] = None,
|
87
95
|
signer_callable: Optional[Callable] = None,
|
88
96
|
signer_kwargs: Optional[Dict] = {},
|
@@ -202,8 +210,8 @@ def set_auth(
|
|
202
210
|
oci_config_location != DEFAULT_LOCATION or profile != DEFAULT_PROFILE
|
203
211
|
):
|
204
212
|
raise ValueError(
|
205
|
-
|
206
|
-
|
213
|
+
"'config' and 'oci_config_location', 'profile' pair are mutually exclusive."
|
214
|
+
"Please specify 'config' OR 'oci_config_location', 'profile' pair."
|
207
215
|
)
|
208
216
|
|
209
217
|
auth_state.oci_config = config
|
@@ -621,7 +629,7 @@ class APIKey(AuthSignerGenerator):
|
|
621
629
|
)
|
622
630
|
|
623
631
|
oci.config.validate_config(configuration)
|
624
|
-
logger.debug(
|
632
|
+
logger.debug("Using 'api_key' authentication.")
|
625
633
|
return {
|
626
634
|
"config": configuration,
|
627
635
|
"signer": oci.signer.Signer(
|
@@ -684,14 +692,19 @@ class ResourcePrincipal(AuthSignerGenerator):
|
|
684
692
|
"signer": oci.auth.signers.get_resource_principals_signer(),
|
685
693
|
"client_kwargs": self.client_kwargs,
|
686
694
|
}
|
687
|
-
logger.debug(
|
695
|
+
logger.debug("Using 'resource_principal' authentication.")
|
688
696
|
return signer_dict
|
689
697
|
|
690
698
|
@staticmethod
|
691
699
|
def supported():
|
692
700
|
return any(
|
693
701
|
os.environ.get(var)
|
694
|
-
for var in [
|
702
|
+
for var in [
|
703
|
+
"JOB_RUN_OCID",
|
704
|
+
"NB_SESSION_OCID",
|
705
|
+
"DATAFLOW_RUN_ID",
|
706
|
+
"PIPELINE_RUN_OCID",
|
707
|
+
]
|
695
708
|
)
|
696
709
|
|
697
710
|
|
@@ -747,7 +760,7 @@ class InstancePrincipal(AuthSignerGenerator):
|
|
747
760
|
),
|
748
761
|
"client_kwargs": self.client_kwargs,
|
749
762
|
}
|
750
|
-
logger.debug(
|
763
|
+
logger.debug("Using 'instance_principal' authentication.")
|
751
764
|
return signer_dict
|
752
765
|
|
753
766
|
|
@@ -814,7 +827,7 @@ class SecurityToken(AuthSignerGenerator):
|
|
814
827
|
oci.config.from_file(self.oci_config_location, self.oci_key_profile)
|
815
828
|
)
|
816
829
|
|
817
|
-
logger.debug(
|
830
|
+
logger.debug("Using 'security_token' authentication.")
|
818
831
|
|
819
832
|
for parameter in self.SECURITY_TOKEN_REQUIRED:
|
820
833
|
if parameter not in configuration:
|
@@ -903,7 +916,7 @@ class SecurityToken(AuthSignerGenerator):
|
|
903
916
|
raise ValueError("Invalid `security_token_file`. Specify a valid path.")
|
904
917
|
try:
|
905
918
|
token = None
|
906
|
-
with open(expanded_path
|
919
|
+
with open(expanded_path) as f:
|
907
920
|
token = f.read()
|
908
921
|
return token
|
909
922
|
except:
|
@@ -1023,7 +1036,7 @@ class OCIAuthContext:
|
|
1023
1036
|
logger.debug(f"OCI profile set to {self.profile}")
|
1024
1037
|
else:
|
1025
1038
|
ads.set_auth(auth=AuthType.RESOURCE_PRINCIPAL)
|
1026
|
-
logger.debug(
|
1039
|
+
logger.debug("OCI auth set to resource principal")
|
1027
1040
|
return self
|
1028
1041
|
|
1029
1042
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
ads/common/extended_enum.py
CHANGED
@@ -1,73 +1,81 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8; -*-
|
3
2
|
|
4
|
-
# Copyright (c) 2022,
|
3
|
+
# Copyright (c) 2022, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
6
|
|
8
7
|
from abc import ABCMeta
|
9
|
-
from enum import Enum
|
10
8
|
|
11
9
|
|
12
10
|
class ExtendedEnumMeta(ABCMeta):
|
13
|
-
"""
|
11
|
+
"""
|
12
|
+
A helper metaclass to extend functionality of a generic "Enum-like" class.
|
14
13
|
|
15
14
|
Methods
|
16
15
|
-------
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
... KEY2 = "value2"
|
25
|
-
>>> print(TestEnum.KEY1) # "value1"
|
16
|
+
__contains__(cls, item) -> bool:
|
17
|
+
Checks if `item` is among the attribute values of the class.
|
18
|
+
Case-insensitive if `item` is a string.
|
19
|
+
values(cls) -> tuple:
|
20
|
+
Returns the tuple of class attribute values.
|
21
|
+
keys(cls) -> tuple:
|
22
|
+
Returns the tuple of class attribute names.
|
26
23
|
"""
|
27
24
|
|
28
|
-
def __contains__(cls,
|
29
|
-
|
30
|
-
|
31
|
-
def values(cls) -> list:
|
32
|
-
"""Gets the list of class attributes values.
|
25
|
+
def __contains__(cls, item: object) -> bool:
|
26
|
+
"""
|
27
|
+
Checks if `item` is a member of the class's values.
|
33
28
|
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
29
|
+
- If `item` is a string, does a case-insensitive match against any string
|
30
|
+
values stored in the class.
|
31
|
+
- Otherwise, does a direct membership test.
|
32
|
+
"""
|
33
|
+
# Gather the attribute values
|
34
|
+
attr_values = cls.values()
|
35
|
+
|
36
|
+
# If item is a string, compare case-insensitively to any str-type values
|
37
|
+
if isinstance(item, str):
|
38
|
+
return any(
|
39
|
+
isinstance(val, str) and val.lower() == item.lower()
|
40
|
+
for val in attr_values
|
41
|
+
)
|
42
|
+
else:
|
43
|
+
# For non-string items (e.g., int), do a direct membership check
|
44
|
+
return item in attr_values
|
45
|
+
|
46
|
+
def __iter__(cls):
|
47
|
+
# Make the class iterable by returning an iterator over its values
|
48
|
+
return iter(cls.values())
|
49
|
+
|
50
|
+
def values(cls) -> tuple:
|
51
|
+
"""
|
52
|
+
Gets the tuple of class attribute values, excluding private or special
|
53
|
+
attributes and any callables (methods, etc.).
|
38
54
|
"""
|
39
55
|
return tuple(
|
40
|
-
value
|
56
|
+
value
|
57
|
+
for key, value in cls.__dict__.items()
|
58
|
+
if not key.startswith("_") and not callable(value)
|
41
59
|
)
|
42
60
|
|
43
|
-
def keys(cls) ->
|
44
|
-
"""
|
45
|
-
|
46
|
-
|
47
|
-
-------
|
48
|
-
list
|
49
|
-
The list of class attributes names.
|
61
|
+
def keys(cls) -> tuple:
|
62
|
+
"""
|
63
|
+
Gets the tuple of class attribute names, excluding private or special
|
64
|
+
attributes and any callables (methods, etc.).
|
50
65
|
"""
|
51
66
|
return tuple(
|
52
|
-
key
|
67
|
+
key
|
68
|
+
for key, value in cls.__dict__.items()
|
69
|
+
if not key.startswith("_") and not callable(value)
|
53
70
|
)
|
54
71
|
|
55
72
|
|
56
|
-
class ExtendedEnum(
|
73
|
+
class ExtendedEnum(metaclass=ExtendedEnumMeta):
|
57
74
|
"""The base class to extend functionality of a generic Enum.
|
58
75
|
|
59
76
|
Examples
|
60
77
|
--------
|
61
|
-
>>> class TestEnum(
|
62
|
-
... KEY1 = "
|
63
|
-
... KEY2 = "
|
64
|
-
>>> print(TestEnum.KEY1.value) # "value1"
|
78
|
+
>>> class TestEnum(ExtendedEnum):
|
79
|
+
... KEY1 = "v1"
|
80
|
+
... KEY2 = "v2"
|
65
81
|
"""
|
66
|
-
|
67
|
-
@classmethod
|
68
|
-
def values(cls):
|
69
|
-
return sorted(map(lambda c: c.value, cls))
|
70
|
-
|
71
|
-
@classmethod
|
72
|
-
def keys(cls):
|
73
|
-
return sorted(map(lambda c: c.name, cls))
|
ads/llm/__init__.py
CHANGED
@@ -1,21 +1,24 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c)
|
3
|
+
# Copyright (c) 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
6
|
try:
|
8
7
|
import langchain
|
9
|
-
|
10
|
-
|
11
|
-
OCIModelDeploymentTGI,
|
12
|
-
)
|
8
|
+
|
9
|
+
from ads.llm.chat_template import ChatTemplates
|
13
10
|
from ads.llm.langchain.plugins.chat_models.oci_data_science import (
|
14
11
|
ChatOCIModelDeployment,
|
15
|
-
ChatOCIModelDeploymentVLLM,
|
16
12
|
ChatOCIModelDeploymentTGI,
|
13
|
+
ChatOCIModelDeploymentVLLM,
|
14
|
+
)
|
15
|
+
from ads.llm.langchain.plugins.embeddings.oci_data_science_model_deployment_endpoint import (
|
16
|
+
OCIDataScienceEmbedding,
|
17
|
+
)
|
18
|
+
from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
|
19
|
+
OCIModelDeploymentTGI,
|
20
|
+
OCIModelDeploymentVLLM,
|
17
21
|
)
|
18
|
-
from ads.llm.chat_template import ChatTemplates
|
19
22
|
except ImportError as ex:
|
20
23
|
if ex.name == "langchain":
|
21
24
|
raise ImportError(
|
@@ -0,0 +1,184 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
|
3
|
+
# Copyright (c) 2025 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
from typing import Any, Callable, Dict, List, Mapping, Optional
|
7
|
+
|
8
|
+
import requests
|
9
|
+
from langchain_core.embeddings import Embeddings
|
10
|
+
from langchain_core.language_models.llms import create_base_retry_decorator
|
11
|
+
from pydantic import BaseModel, Field
|
12
|
+
|
13
|
+
DEFAULT_HEADER = {
|
14
|
+
"Content-Type": "application/json",
|
15
|
+
}
|
16
|
+
|
17
|
+
|
18
|
+
class TokenExpiredError(Exception):
|
19
|
+
pass
|
20
|
+
|
21
|
+
|
22
|
+
def _create_retry_decorator(llm) -> Callable[[Any], Any]:
|
23
|
+
"""Creates a retry decorator."""
|
24
|
+
errors = [requests.exceptions.ConnectTimeout, TokenExpiredError]
|
25
|
+
decorator = create_base_retry_decorator(
|
26
|
+
error_types=errors, max_retries=llm.max_retries
|
27
|
+
)
|
28
|
+
return decorator
|
29
|
+
|
30
|
+
|
31
|
+
class OCIDataScienceEmbedding(BaseModel, Embeddings):
|
32
|
+
"""Embedding model deployed on OCI Data Science Model Deployment.
|
33
|
+
|
34
|
+
Example:
|
35
|
+
|
36
|
+
.. code-block:: python
|
37
|
+
|
38
|
+
from ads.llm import OCIDataScienceEmbedding
|
39
|
+
|
40
|
+
embeddings = OCIDataScienceEmbedding(
|
41
|
+
endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<md_ocid>/predict",
|
42
|
+
)
|
43
|
+
""" # noqa: E501
|
44
|
+
|
45
|
+
auth: dict = Field(default_factory=dict, exclude=True)
|
46
|
+
"""ADS auth dictionary for OCI authentication:
|
47
|
+
https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html.
|
48
|
+
This can be generated by calling `ads.common.auth.api_keys()`
|
49
|
+
or `ads.common.auth.resource_principal()`. If this is not
|
50
|
+
provided then the `ads.common.default_signer()` will be used."""
|
51
|
+
|
52
|
+
endpoint: str = ""
|
53
|
+
"""The uri of the endpoint from the deployed Model Deployment model."""
|
54
|
+
|
55
|
+
model_kwargs: Optional[Dict] = None
|
56
|
+
"""Keyword arguments to pass to the model."""
|
57
|
+
|
58
|
+
endpoint_kwargs: Optional[Dict] = None
|
59
|
+
"""Optional attributes (except for headers) passed to the request.post
|
60
|
+
function.
|
61
|
+
"""
|
62
|
+
|
63
|
+
max_retries: int = 1
|
64
|
+
"""The maximum number of retries to make when generating."""
|
65
|
+
|
66
|
+
@property
|
67
|
+
def _identifying_params(self) -> Mapping[str, Any]:
|
68
|
+
"""Get the identifying parameters."""
|
69
|
+
_model_kwargs = self.model_kwargs or {}
|
70
|
+
return {
|
71
|
+
**{"endpoint": self.endpoint},
|
72
|
+
**{"model_kwargs": _model_kwargs},
|
73
|
+
}
|
74
|
+
|
75
|
+
def _embed_with_retry(self, **kwargs) -> Any:
|
76
|
+
"""Use tenacity to retry the call."""
|
77
|
+
retry_decorator = _create_retry_decorator(self)
|
78
|
+
|
79
|
+
@retry_decorator
|
80
|
+
def _completion_with_retry(**kwargs: Any) -> Any:
|
81
|
+
try:
|
82
|
+
response = requests.post(self.endpoint, **kwargs)
|
83
|
+
response.raise_for_status()
|
84
|
+
return response
|
85
|
+
except requests.exceptions.HTTPError as http_err:
|
86
|
+
if response.status_code == 401 and self._refresh_signer():
|
87
|
+
raise TokenExpiredError() from http_err
|
88
|
+
else:
|
89
|
+
raise ValueError(
|
90
|
+
f"Server error: {str(http_err)}. Message: {response.text}"
|
91
|
+
) from http_err
|
92
|
+
except Exception as e:
|
93
|
+
raise ValueError(f"Error occurs by inference endpoint: {str(e)}") from e
|
94
|
+
|
95
|
+
return _completion_with_retry(**kwargs)
|
96
|
+
|
97
|
+
def _embedding(self, texts: List[str]) -> List[List[float]]:
|
98
|
+
"""Call out to OCI Data Science Model Deployment Endpoint.
|
99
|
+
|
100
|
+
Args:
|
101
|
+
texts: A list of texts to embed.
|
102
|
+
|
103
|
+
Returns:
|
104
|
+
A list of list of floats representing the embeddings, or None if an
|
105
|
+
error occurs.
|
106
|
+
"""
|
107
|
+
_model_kwargs = self.model_kwargs or {}
|
108
|
+
body = self._construct_request_body(texts, _model_kwargs)
|
109
|
+
request_kwargs = self._construct_request_kwargs(body)
|
110
|
+
response = self._embed_with_retry(**request_kwargs)
|
111
|
+
return self._proceses_response(response)
|
112
|
+
|
113
|
+
def _construct_request_kwargs(self, body: Any) -> dict:
|
114
|
+
"""Constructs the request kwargs as a dictionary."""
|
115
|
+
from ads.model.common.utils import _is_json_serializable
|
116
|
+
|
117
|
+
_endpoint_kwargs = self.endpoint_kwargs or {}
|
118
|
+
headers = _endpoint_kwargs.pop("headers", DEFAULT_HEADER)
|
119
|
+
return (
|
120
|
+
dict(
|
121
|
+
headers=headers,
|
122
|
+
json=body,
|
123
|
+
auth=self.auth.get("signer"),
|
124
|
+
**_endpoint_kwargs,
|
125
|
+
)
|
126
|
+
if _is_json_serializable(body)
|
127
|
+
else dict(
|
128
|
+
headers=headers,
|
129
|
+
data=body,
|
130
|
+
auth=self.auth.get("signer"),
|
131
|
+
**_endpoint_kwargs,
|
132
|
+
)
|
133
|
+
)
|
134
|
+
|
135
|
+
def _construct_request_body(self, texts: List[str], params: dict) -> Any:
|
136
|
+
"""Constructs the request body."""
|
137
|
+
return {"input": texts}
|
138
|
+
|
139
|
+
def _proceses_response(self, response: requests.Response) -> List[List[float]]:
|
140
|
+
"""Extracts results from requests.Response."""
|
141
|
+
try:
|
142
|
+
res_json = response.json()
|
143
|
+
embeddings = res_json["data"][0]["embedding"]
|
144
|
+
except Exception as e:
|
145
|
+
raise ValueError(
|
146
|
+
f"Error raised by inference API: {e}.\nResponse: {response.text}"
|
147
|
+
) from e
|
148
|
+
return embeddings
|
149
|
+
|
150
|
+
def embed_documents(
|
151
|
+
self,
|
152
|
+
texts: List[str],
|
153
|
+
chunk_size: Optional[int] = None,
|
154
|
+
) -> List[List[float]]:
|
155
|
+
"""Compute doc embeddings using OCI Data Science Model Deployment Endpoint.
|
156
|
+
|
157
|
+
Args:
|
158
|
+
texts: The list of texts to embed.
|
159
|
+
chunk_size: The chunk size defines how many input texts will
|
160
|
+
be grouped together as request. If None, will use the
|
161
|
+
chunk size specified by the class.
|
162
|
+
|
163
|
+
Returns:
|
164
|
+
List of embeddings, one for each text.
|
165
|
+
"""
|
166
|
+
results = []
|
167
|
+
_chunk_size = (
|
168
|
+
len(texts) if (not chunk_size or chunk_size > len(texts)) else chunk_size
|
169
|
+
)
|
170
|
+
for i in range(0, len(texts), _chunk_size):
|
171
|
+
response = self._embedding(texts[i : i + _chunk_size])
|
172
|
+
results.extend(response)
|
173
|
+
return results
|
174
|
+
|
175
|
+
def embed_query(self, text: str) -> List[float]:
|
176
|
+
"""Compute query embeddings using OCI Data Science Model Deployment Endpoint.
|
177
|
+
|
178
|
+
Args:
|
179
|
+
text: The text to embed.
|
180
|
+
|
181
|
+
Returns:
|
182
|
+
Embeddings for the text.
|
183
|
+
"""
|
184
|
+
return self._embedding([text])[0]
|
ads/model/artifact_downloader.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8; -*-
|
3
2
|
|
4
3
|
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
@@ -12,9 +11,9 @@ from typing import Dict, Optional
|
|
12
11
|
from zipfile import ZipFile
|
13
12
|
|
14
13
|
from ads.common import utils
|
14
|
+
from ads.common.object_storage_details import ObjectStorageDetails
|
15
15
|
from ads.common.utils import extract_region
|
16
16
|
from ads.model.service.oci_datascience_model import OCIDataScienceModel
|
17
|
-
from ads.common.object_storage_details import ObjectStorageDetails
|
18
17
|
|
19
18
|
|
20
19
|
class ArtifactDownloader(ABC):
|
@@ -169,9 +168,9 @@ class LargeArtifactDownloader(ArtifactDownloader):
|
|
169
168
|
|
170
169
|
def _download(self):
|
171
170
|
"""Downloads model artifacts."""
|
172
|
-
self.progress.update(
|
171
|
+
self.progress.update("Importing model artifacts from catalog")
|
173
172
|
|
174
|
-
if self.dsc_model.
|
173
|
+
if self.dsc_model._is_model_by_reference() and self.model_file_description:
|
175
174
|
self.download_from_model_file_description()
|
176
175
|
self.progress.update()
|
177
176
|
return
|