oracle-ads 2.12.2__py3-none-any.whl → 2.12.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/common/enums.py +9 -0
- ads/aqua/common/utils.py +83 -6
- ads/aqua/config/config.py +0 -16
- ads/aqua/constants.py +2 -0
- ads/aqua/evaluation/entities.py +45 -50
- ads/aqua/evaluation/evaluation.py +26 -61
- ads/aqua/extension/deployment_handler.py +35 -0
- ads/aqua/extension/errors.py +1 -0
- ads/aqua/extension/evaluation_handler.py +0 -5
- ads/aqua/extension/finetune_handler.py +1 -2
- ads/aqua/extension/model_handler.py +38 -1
- ads/aqua/extension/ui_handler.py +22 -3
- ads/aqua/finetuning/entities.py +5 -4
- ads/aqua/finetuning/finetuning.py +13 -8
- ads/aqua/model/constants.py +1 -0
- ads/aqua/model/entities.py +2 -0
- ads/aqua/model/model.py +350 -140
- ads/aqua/modeldeployment/deployment.py +118 -62
- ads/aqua/modeldeployment/entities.py +10 -2
- ads/aqua/ui.py +29 -16
- ads/config.py +3 -8
- ads/dataset/dataset.py +2 -2
- ads/dataset/factory.py +1 -1
- ads/llm/deploy.py +6 -0
- ads/llm/guardrails/base.py +0 -1
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +118 -41
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +18 -14
- ads/llm/templates/score_chain.jinja2 +0 -1
- ads/model/datascience_model.py +519 -16
- ads/model/deployment/model_deployment.py +13 -0
- ads/model/deployment/model_deployment_infrastructure.py +34 -0
- ads/model/generic_model.py +10 -0
- ads/model/model_properties.py +1 -0
- ads/model/service/oci_datascience_model.py +28 -0
- ads/opctl/operator/lowcode/anomaly/const.py +66 -1
- ads/opctl/operator/lowcode/anomaly/model/anomaly_merlion.py +161 -0
- ads/opctl/operator/lowcode/anomaly/model/autots.py +30 -15
- ads/opctl/operator/lowcode/anomaly/model/factory.py +15 -3
- ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +1 -1
- ads/opctl/operator/lowcode/anomaly/schema.yaml +10 -0
- ads/opctl/operator/lowcode/anomaly/utils.py +3 -0
- ads/opctl/operator/lowcode/forecast/cmd.py +3 -9
- ads/opctl/operator/lowcode/forecast/const.py +3 -4
- ads/opctl/operator/lowcode/forecast/model/factory.py +13 -12
- ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +4 -3
- ads/opctl/operator/lowcode/forecast/operator_config.py +17 -10
- ads/opctl/operator/lowcode/forecast/schema.yaml +2 -2
- ads/oracledb/oracle_db.py +32 -20
- ads/secrets/adb.py +28 -6
- {oracle_ads-2.12.2.dist-info → oracle_ads-2.12.4.dist-info}/METADATA +3 -2
- {oracle_ads-2.12.2.dist-info → oracle_ads-2.12.4.dist-info}/RECORD +54 -55
- {oracle_ads-2.12.2.dist-info → oracle_ads-2.12.4.dist-info}/WHEEL +1 -1
- ads/aqua/config/deployment_config_defaults.json +0 -38
- ads/aqua/config/resource_limit_names.json +0 -9
- {oracle_ads-2.12.2.dist-info → oracle_ads-2.12.4.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.12.2.dist-info → oracle_ads-2.12.4.dist-info}/entry_points.txt +0 -0
@@ -3,23 +3,24 @@
|
|
3
3
|
|
4
4
|
# Copyright (c) 2023 Oracle and/or its affiliates.
|
5
5
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
|
+
"""Chat model for OCI data science model deployment endpoint."""
|
6
7
|
|
7
|
-
|
8
|
+
import importlib
|
8
9
|
import json
|
9
10
|
import logging
|
10
11
|
from operator import itemgetter
|
11
12
|
from typing import (
|
12
13
|
Any,
|
13
14
|
AsyncIterator,
|
15
|
+
Callable,
|
14
16
|
Dict,
|
15
17
|
Iterator,
|
16
18
|
List,
|
17
19
|
Literal,
|
18
20
|
Optional,
|
21
|
+
Sequence,
|
19
22
|
Type,
|
20
23
|
Union,
|
21
|
-
Sequence,
|
22
|
-
Callable,
|
23
24
|
)
|
24
25
|
|
25
26
|
from langchain_core.callbacks import (
|
@@ -33,21 +34,16 @@ from langchain_core.language_models.chat_models import (
|
|
33
34
|
generate_from_stream,
|
34
35
|
)
|
35
36
|
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
|
36
|
-
from langchain_core.tools import BaseTool
|
37
37
|
from langchain_core.output_parsers import (
|
38
38
|
JsonOutputParser,
|
39
39
|
PydanticOutputParser,
|
40
40
|
)
|
41
41
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
42
42
|
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
43
|
+
from langchain_core.tools import BaseTool
|
43
44
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
44
|
-
from
|
45
|
-
_convert_delta_to_message_chunk,
|
46
|
-
_convert_message_to_dict,
|
47
|
-
_convert_dict_to_message,
|
48
|
-
)
|
45
|
+
from pydantic import BaseModel, Field, model_validator
|
49
46
|
|
50
|
-
from pydantic import BaseModel, Field
|
51
47
|
from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
|
52
48
|
DEFAULT_MODEL_NAME,
|
53
49
|
BaseOCIModelDeployment,
|
@@ -63,15 +59,40 @@ def _is_pydantic_class(obj: Any) -> bool:
|
|
63
59
|
class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
64
60
|
"""OCI Data Science Model Deployment chat model integration.
|
65
61
|
|
66
|
-
|
67
|
-
|
62
|
+
Setup:
|
63
|
+
Install ``oracle-ads`` and ``langchain-openai``.
|
68
64
|
|
69
|
-
|
70
|
-
credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html
|
65
|
+
.. code-block:: bash
|
71
66
|
|
72
|
-
|
73
|
-
|
74
|
-
|
67
|
+
pip install -U oracle-ads langchain-openai
|
68
|
+
|
69
|
+
Use `ads.set_auth()` to configure authentication.
|
70
|
+
For example, to use OCI resource_principal for authentication:
|
71
|
+
|
72
|
+
.. code-block:: python
|
73
|
+
|
74
|
+
import ads
|
75
|
+
ads.set_auth("resource_principal")
|
76
|
+
|
77
|
+
For more details on authentication, see:
|
78
|
+
https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html
|
79
|
+
|
80
|
+
Make sure to have the required policies to access the OCI Data
|
81
|
+
Science Model Deployment endpoint. See:
|
82
|
+
https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm
|
83
|
+
|
84
|
+
|
85
|
+
Key init args - completion params:
|
86
|
+
endpoint: str
|
87
|
+
The OCI model deployment endpoint.
|
88
|
+
temperature: float
|
89
|
+
Sampling temperature.
|
90
|
+
max_tokens: Optional[int]
|
91
|
+
Max number of tokens to generate.
|
92
|
+
|
93
|
+
Key init args — client params:
|
94
|
+
auth: dict
|
95
|
+
ADS auth dictionary for OCI authentication.
|
75
96
|
|
76
97
|
Instantiate:
|
77
98
|
.. code-block:: python
|
@@ -79,7 +100,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
79
100
|
from langchain_community.chat_models import ChatOCIModelDeployment
|
80
101
|
|
81
102
|
chat = ChatOCIModelDeployment(
|
82
|
-
endpoint="https://modeldeployment
|
103
|
+
endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
|
83
104
|
model="odsc-llm",
|
84
105
|
streaming=True,
|
85
106
|
max_retries=3,
|
@@ -94,7 +115,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
94
115
|
.. code-block:: python
|
95
116
|
|
96
117
|
messages = [
|
97
|
-
("system", "
|
118
|
+
("system", "Translate the user sentence to French."),
|
98
119
|
("human", "Hello World!"),
|
99
120
|
]
|
100
121
|
chat.invoke(messages)
|
@@ -102,7 +123,19 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
102
123
|
.. code-block:: python
|
103
124
|
|
104
125
|
AIMessage(
|
105
|
-
content='Bonjour le monde!',
|
126
|
+
content='Bonjour le monde!',
|
127
|
+
response_metadata={
|
128
|
+
'token_usage': {
|
129
|
+
'prompt_tokens': 40,
|
130
|
+
'total_tokens': 50,
|
131
|
+
'completion_tokens': 10
|
132
|
+
},
|
133
|
+
'model_name': 'odsc-llm',
|
134
|
+
'system_fingerprint': '',
|
135
|
+
'finish_reason': 'stop'
|
136
|
+
},
|
137
|
+
id='run-cbed62da-e1b3-4abd-9df3-ec89d69ca012-0'
|
138
|
+
)
|
106
139
|
|
107
140
|
Streaming:
|
108
141
|
.. code-block:: python
|
@@ -112,18 +145,18 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
112
145
|
|
113
146
|
.. code-block:: python
|
114
147
|
|
115
|
-
content='' id='run-
|
116
|
-
content='\n' id='run-
|
117
|
-
content='B' id='run-
|
118
|
-
content='on' id='run-
|
119
|
-
content='j' id='run-
|
120
|
-
content='our' id='run-
|
121
|
-
content=' le' id='run-
|
122
|
-
content=' monde' id='run-
|
123
|
-
content='!' id='run-
|
124
|
-
content='' response_metadata={'finish_reason': 'stop'} id='run-
|
125
|
-
|
126
|
-
|
148
|
+
content='' id='run-02c6-c43f-42de'
|
149
|
+
content='\n' id='run-02c6-c43f-42de'
|
150
|
+
content='B' id='run-02c6-c43f-42de'
|
151
|
+
content='on' id='run-02c6-c43f-42de'
|
152
|
+
content='j' id='run-02c6-c43f-42de'
|
153
|
+
content='our' id='run-02c6-c43f-42de'
|
154
|
+
content=' le' id='run-02c6-c43f-42de'
|
155
|
+
content=' monde' id='run-02c6-c43f-42de'
|
156
|
+
content='!' id='run-02c6-c43f-42de'
|
157
|
+
content='' response_metadata={'finish_reason': 'stop'} id='run-02c6-c43f-42de'
|
158
|
+
|
159
|
+
Async:
|
127
160
|
.. code-block:: python
|
128
161
|
|
129
162
|
await chat.ainvoke(messages)
|
@@ -133,7 +166,11 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
133
166
|
|
134
167
|
.. code-block:: python
|
135
168
|
|
136
|
-
AIMessage(
|
169
|
+
AIMessage(
|
170
|
+
content='Bonjour le monde!',
|
171
|
+
response_metadata={'finish_reason': 'stop'},
|
172
|
+
id='run-8657a105-96b7-4bb6-b98e-b69ca420e5d1-0'
|
173
|
+
)
|
137
174
|
|
138
175
|
Structured output:
|
139
176
|
.. code-block:: python
|
@@ -147,19 +184,22 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
147
184
|
|
148
185
|
structured_llm = chat.with_structured_output(Joke, method="json_mode")
|
149
186
|
structured_llm.invoke(
|
150
|
-
"Tell me a joke about cats,
|
187
|
+
"Tell me a joke about cats, "
|
188
|
+
"respond in JSON with `setup` and `punchline` keys"
|
151
189
|
)
|
152
190
|
|
153
191
|
.. code-block:: python
|
154
192
|
|
155
|
-
Joke(
|
193
|
+
Joke(
|
194
|
+
setup='Why did the cat get stuck in the tree?',
|
195
|
+
punchline='Because it was chasing its tail!'
|
196
|
+
)
|
156
197
|
|
157
198
|
See ``ChatOCIModelDeployment.with_structured_output()`` for more.
|
158
199
|
|
159
200
|
Customized Usage:
|
160
|
-
|
161
|
-
|
162
|
-
`_construct_json_body` for satisfying customized needed.
|
201
|
+
You can inherit from base class and overwrite the `_process_response`,
|
202
|
+
`_process_stream_response`, `_construct_json_body` for customized usage.
|
163
203
|
|
164
204
|
.. code-block:: python
|
165
205
|
|
@@ -180,12 +220,31 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
180
220
|
}
|
181
221
|
|
182
222
|
chat = MyChatModel(
|
183
|
-
endpoint=f"https://modeldeployment
|
223
|
+
endpoint=f"https://modeldeployment.<region>.oci.customer-oci.com/{ocid}/predict",
|
184
224
|
model="odsc-llm",
|
185
225
|
}
|
186
226
|
|
187
227
|
chat.invoke("tell me a joke")
|
188
228
|
|
229
|
+
Response metadata
|
230
|
+
.. code-block:: python
|
231
|
+
|
232
|
+
ai_msg = chat.invoke(messages)
|
233
|
+
ai_msg.response_metadata
|
234
|
+
|
235
|
+
.. code-block:: python
|
236
|
+
|
237
|
+
{
|
238
|
+
'token_usage': {
|
239
|
+
'prompt_tokens': 40,
|
240
|
+
'total_tokens': 50,
|
241
|
+
'completion_tokens': 10
|
242
|
+
},
|
243
|
+
'model_name': 'odsc-llm',
|
244
|
+
'system_fingerprint': '',
|
245
|
+
'finish_reason': 'stop'
|
246
|
+
}
|
247
|
+
|
189
248
|
""" # noqa: E501
|
190
249
|
|
191
250
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
@@ -198,6 +257,17 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
198
257
|
"""Stop words to use when generating. Model output is cut off
|
199
258
|
at the first occurrence of any of these substrings."""
|
200
259
|
|
260
|
+
@model_validator(mode="before")
|
261
|
+
@classmethod
|
262
|
+
def validate_openai(cls, values: Any) -> Any:
|
263
|
+
"""Checks if langchain_openai is installed."""
|
264
|
+
if not importlib.util.find_spec("langchain_openai"):
|
265
|
+
raise ImportError(
|
266
|
+
"Could not import langchain_openai package. "
|
267
|
+
"Please install it with `pip install langchain_openai`."
|
268
|
+
)
|
269
|
+
return values
|
270
|
+
|
201
271
|
@property
|
202
272
|
def _llm_type(self) -> str:
|
203
273
|
"""Return type of llm."""
|
@@ -552,6 +622,8 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
552
622
|
converted messages and additional parameters.
|
553
623
|
|
554
624
|
"""
|
625
|
+
from langchain_openai.chat_models.base import _convert_message_to_dict
|
626
|
+
|
555
627
|
return {
|
556
628
|
"messages": [_convert_message_to_dict(m) for m in messages],
|
557
629
|
**params,
|
@@ -578,6 +650,8 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
578
650
|
ValueError: If the response JSON is not well-formed or does not
|
579
651
|
contain the expected structure.
|
580
652
|
"""
|
653
|
+
from langchain_openai.chat_models.base import _convert_delta_to_message_chunk
|
654
|
+
|
581
655
|
try:
|
582
656
|
choice = response_json["choices"][0]
|
583
657
|
if not isinstance(choice, dict):
|
@@ -616,6 +690,8 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
616
690
|
contain the expected structure.
|
617
691
|
|
618
692
|
"""
|
693
|
+
from langchain_openai.chat_models.base import _convert_dict_to_message
|
694
|
+
|
619
695
|
generations = []
|
620
696
|
try:
|
621
697
|
choices = response_json["choices"]
|
@@ -760,8 +836,9 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
|
|
760
836
|
tool_choice: Optional[str] = None
|
761
837
|
"""Whether to use tool calling.
|
762
838
|
Defaults to None, tool calling is disabled.
|
763
|
-
Tool calling requires model support and vLLM to be configured
|
764
|
-
|
839
|
+
Tool calling requires model support and the vLLM to be configured
|
840
|
+
with `--tool-call-parser`.
|
841
|
+
Set this to `auto` for the model to make tool calls automatically.
|
765
842
|
Set this to `required` to force the model to always call one or more tools.
|
766
843
|
"""
|
767
844
|
|
@@ -5,8 +5,11 @@
|
|
5
5
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
6
|
|
7
7
|
|
8
|
+
"""LLM for OCI data science model deployment endpoint."""
|
9
|
+
|
8
10
|
import json
|
9
11
|
import logging
|
12
|
+
import traceback
|
10
13
|
from typing import (
|
11
14
|
Any,
|
12
15
|
AsyncIterator,
|
@@ -21,7 +24,6 @@ from typing import (
|
|
21
24
|
|
22
25
|
import aiohttp
|
23
26
|
import requests
|
24
|
-
import traceback
|
25
27
|
from langchain_core.callbacks import (
|
26
28
|
AsyncCallbackManagerForLLMRun,
|
27
29
|
CallbackManagerForLLMRun,
|
@@ -29,9 +31,10 @@ from langchain_core.callbacks import (
|
|
29
31
|
from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
|
30
32
|
from langchain_core.load.serializable import Serializable
|
31
33
|
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
32
|
-
from langchain_core.utils import get_from_dict_or_env
|
34
|
+
from langchain_core.utils import get_from_dict_or_env
|
35
|
+
from pydantic import Field, model_validator
|
36
|
+
|
33
37
|
from langchain_community.utilities.requests import Requests
|
34
|
-
from pydantic import Field
|
35
38
|
|
36
39
|
logger = logging.getLogger(__name__)
|
37
40
|
|
@@ -83,11 +86,12 @@ class BaseOCIModelDeployment(Serializable):
|
|
83
86
|
max_retries: int = 3
|
84
87
|
"""Maximum number of retries to make when generating."""
|
85
88
|
|
86
|
-
@
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
89
|
+
@model_validator(mode="before")
|
90
|
+
@classmethod
|
91
|
+
def validate_environment(cls, values: Dict) -> Dict:
|
92
|
+
"""Checks if oracle-ads is installed and
|
93
|
+
get credentials/endpoint from environment.
|
94
|
+
"""
|
91
95
|
try:
|
92
96
|
import ads
|
93
97
|
|
@@ -256,7 +260,7 @@ class BaseOCIModelDeployment(Serializable):
|
|
256
260
|
if hasattr(response, "status_code")
|
257
261
|
else response.status
|
258
262
|
)
|
259
|
-
if status_code
|
263
|
+
if status_code in [401, 404] and self._refresh_signer():
|
260
264
|
raise TokenExpiredError() from http_err
|
261
265
|
|
262
266
|
raise ServerError(
|
@@ -353,6 +357,11 @@ class BaseOCIModelDeployment(Serializable):
|
|
353
357
|
self.auth["signer"].refresh_security_token()
|
354
358
|
return True
|
355
359
|
return False
|
360
|
+
|
361
|
+
@classmethod
|
362
|
+
def is_lc_serializable(cls) -> bool:
|
363
|
+
"""Return whether this model can be serialized by LangChain."""
|
364
|
+
return True
|
356
365
|
|
357
366
|
|
358
367
|
class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
|
@@ -445,11 +454,6 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
|
|
445
454
|
"""Return type of llm."""
|
446
455
|
return "oci_model_deployment_endpoint"
|
447
456
|
|
448
|
-
@classmethod
|
449
|
-
def is_lc_serializable(cls) -> bool:
|
450
|
-
"""Return whether this model can be serialized by Langchain."""
|
451
|
-
return True
|
452
|
-
|
453
457
|
@property
|
454
458
|
def _default_params(self) -> Dict[str, Any]:
|
455
459
|
"""Get the default parameters."""
|