oracle-ads 2.12.6__py3-none-any.whl → 2.12.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/common/utils.py +4 -1
- ads/aqua/constants.py +1 -0
- ads/aqua/evaluation/entities.py +2 -2
- ads/aqua/evaluation/evaluation.py +2 -6
- ads/aqua/extension/model_handler.py +4 -0
- ads/aqua/model/entities.py +2 -0
- ads/aqua/model/model.py +25 -19
- ads/llm/autogen/__init__.py +0 -0
- ads/llm/autogen/client_v02.py +282 -0
- {oracle_ads-2.12.6.dist-info → oracle_ads-2.12.7.dist-info}/METADATA +3 -2
- {oracle_ads-2.12.6.dist-info → oracle_ads-2.12.7.dist-info}/RECORD +14 -12
- {oracle_ads-2.12.6.dist-info → oracle_ads-2.12.7.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.12.6.dist-info → oracle_ads-2.12.7.dist-info}/WHEEL +0 -0
- {oracle_ads-2.12.6.dist-info → oracle_ads-2.12.7.dist-info}/entry_points.txt +0 -0
ads/aqua/common/utils.py
CHANGED
@@ -788,13 +788,14 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
|
|
788
788
|
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
|
789
789
|
|
790
790
|
|
791
|
-
def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
|
791
|
+
def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str:
|
792
792
|
"""Upload the local folder to the object storage
|
793
793
|
|
794
794
|
Args:
|
795
795
|
os_path (str): object storage URI with prefix. This is the path to upload
|
796
796
|
local_dir (str): Local directory where the object is downloaded
|
797
797
|
model_name (str): Name of the huggingface model
|
798
|
+
exclude_pattern (optional, str): The matching pattern of files to be excluded from uploading.
|
798
799
|
Retuns:
|
799
800
|
str: Object name inside the bucket
|
800
801
|
"""
|
@@ -804,6 +805,8 @@ def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
|
|
804
805
|
auth_state = AuthState()
|
805
806
|
object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/"
|
806
807
|
command = f"oci os object bulk-upload --src-dir {local_dir} --prefix {object_path} -bn {os_details.bucket} -ns {os_details.namespace} --auth {auth_state.oci_iam_type} --profile {auth_state.oci_key_profile} --no-overwrite"
|
808
|
+
if exclude_pattern:
|
809
|
+
command += f" --exclude {exclude_pattern}"
|
807
810
|
try:
|
808
811
|
logger.info(f"Running: {command}")
|
809
812
|
subprocess.check_call(shlex.split(command))
|
ads/aqua/constants.py
CHANGED
@@ -35,6 +35,7 @@ AQUA_MODEL_ARTIFACT_CONFIG = "config.json"
|
|
35
35
|
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME = "_name_or_path"
|
36
36
|
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE = "model_type"
|
37
37
|
AQUA_MODEL_ARTIFACT_FILE = "model_file"
|
38
|
+
HF_METADATA_FOLDER = ".cache/"
|
38
39
|
HF_LOGIN_DEFAULT_TIMEOUT = 2
|
39
40
|
|
40
41
|
TRAINING_METRICS_FINAL = "training_metrics_final"
|
ads/aqua/evaluation/entities.py
CHANGED
@@ -83,7 +83,7 @@ class CreateAquaEvaluationDetails(Serializable):
|
|
83
83
|
ocpus: Optional[float] = None
|
84
84
|
log_group_id: Optional[str] = None
|
85
85
|
log_id: Optional[str] = None
|
86
|
-
metrics: Optional[List[str]] = None
|
86
|
+
metrics: Optional[List[Dict[str, Any]]] = None
|
87
87
|
force_overwrite: Optional[bool] = False
|
88
88
|
|
89
89
|
class Config:
|
@@ -140,7 +140,7 @@ class AquaEvaluationCommands(Serializable):
|
|
140
140
|
evaluation_id: str
|
141
141
|
evaluation_target_id: str
|
142
142
|
input_data: Dict[str, Any]
|
143
|
-
metrics: List[str]
|
143
|
+
metrics: List[Dict[str, Any]]
|
144
144
|
output_dir: str
|
145
145
|
params: Dict[str, Any]
|
146
146
|
|
ads/aqua/evaluation/evaluation.py
CHANGED
@@ -159,7 +159,8 @@ class AquaEvaluationApp(AquaApp):
|
|
159
159
|
create_aqua_evaluation_details = CreateAquaEvaluationDetails(**kwargs)
|
160
160
|
except Exception as ex:
|
161
161
|
custom_errors = {
|
162
|
-
".".join(map(str, e["loc"])): e["msg"]
|
162
|
+
".".join(map(str, e["loc"])): e["msg"]
|
163
|
+
for e in json.loads(ex.json())
|
163
164
|
}
|
164
165
|
raise AquaValueError(
|
165
166
|
f"Invalid create evaluation parameters. Error details: {custom_errors}."
|
@@ -619,11 +620,6 @@ class AquaEvaluationApp(AquaApp):
|
|
619
620
|
evaluation_id=evaluation_id,
|
620
621
|
evaluation_target_id=evaluation_source_id,
|
621
622
|
input_data={
|
622
|
-
"columns": {
|
623
|
-
"prompt": "prompt",
|
624
|
-
"completion": "completion",
|
625
|
-
"category": "category",
|
626
|
-
},
|
627
623
|
"format": Path(dataset_path).suffix,
|
628
624
|
"url": dataset_path,
|
629
625
|
},
|
ads/aqua/extension/model_handler.py
CHANGED
@@ -129,6 +129,8 @@ class AquaModelHandler(AquaAPIhandler):
|
|
129
129
|
str(input_data.get("download_from_hf", "false")).lower() == "true"
|
130
130
|
)
|
131
131
|
inference_container_uri = input_data.get("inference_container_uri")
|
132
|
+
allow_patterns = input_data.get("allow_patterns")
|
133
|
+
ignore_patterns = input_data.get("ignore_patterns")
|
132
134
|
|
133
135
|
return self.finish(
|
134
136
|
AquaModelApp().register(
|
@@ -141,6 +143,8 @@ class AquaModelHandler(AquaAPIhandler):
|
|
141
143
|
project_id=project_id,
|
142
144
|
model_file=model_file,
|
143
145
|
inference_container_uri=inference_container_uri,
|
146
|
+
allow_patterns=allow_patterns,
|
147
|
+
ignore_patterns=ignore_patterns,
|
144
148
|
)
|
145
149
|
)
|
146
150
|
|
ads/aqua/model/entities.py
CHANGED
@@ -289,6 +289,8 @@ class ImportModelDetails(CLIBuilderMixin):
|
|
289
289
|
project_id: Optional[str] = None
|
290
290
|
model_file: Optional[str] = None
|
291
291
|
inference_container_uri: Optional[str] = None
|
292
|
+
allow_patterns: Optional[List[str]] = None
|
293
|
+
ignore_patterns: Optional[List[str]] = None
|
292
294
|
|
293
295
|
def __post_init__(self):
|
294
296
|
self._command = "model register"
|
ads/aqua/model/model.py
CHANGED
@@ -40,6 +40,7 @@ from ads.aqua.constants import (
|
|
40
40
|
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
|
41
41
|
AQUA_MODEL_ARTIFACT_FILE,
|
42
42
|
AQUA_MODEL_TYPE_CUSTOM,
|
43
|
+
HF_METADATA_FOLDER,
|
43
44
|
LICENSE_TXT,
|
44
45
|
MODEL_BY_REFERENCE_OSS_PATH_KEY,
|
45
46
|
README,
|
@@ -1274,6 +1275,8 @@ class AquaModelApp(AquaApp):
|
|
1274
1275
|
model_name: str,
|
1275
1276
|
os_path: str,
|
1276
1277
|
local_dir: str = None,
|
1278
|
+
allow_patterns: List[str] = None,
|
1279
|
+
ignore_patterns: List[str] = None,
|
1277
1280
|
) -> str:
|
1278
1281
|
"""This helper function downloads the model artifact from Hugging Face to a local folder, then uploads
|
1279
1282
|
to object storage location.
|
@@ -1283,6 +1286,12 @@ class AquaModelApp(AquaApp):
|
|
1283
1286
|
model_name (str): The huggingface model name.
|
1284
1287
|
os_path (str): The OS path where the model files are located.
|
1285
1288
|
local_dir (str): The local temp dir to store the huggingface model.
|
1289
|
+
allow_patterns (list): Model files matching at least one pattern are downloaded.
|
1290
|
+
Example: ["*.json"] will download all .json files. ["folder/*"] will download all files under `folder`.
|
1291
|
+
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
|
1292
|
+
ignore_patterns (list): Model files matching any of the patterns are not downloaded.
|
1293
|
+
Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
|
1294
|
+
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
|
1286
1295
|
|
1287
1296
|
Returns
|
1288
1297
|
-------
|
@@ -1293,30 +1302,19 @@ class AquaModelApp(AquaApp):
|
|
1293
1302
|
if not local_dir:
|
1294
1303
|
local_dir = os.path.join(os.path.expanduser("~"), "cached-model")
|
1295
1304
|
local_dir = os.path.join(local_dir, model_name)
|
1296
|
-
retry = 10
|
1297
|
-
i = 0
|
1298
|
-
huggingface_download_err_message = None
|
1299
|
-
while i < retry:
|
1300
|
-
try:
|
1301
|
-
# Download to cache folder. The while loop retries when there is a network failure
|
1302
|
-
snapshot_download(repo_id=model_name)
|
1303
|
-
except Exception as e:
|
1304
|
-
huggingface_download_err_message = str(e)
|
1305
|
-
i += 1
|
1306
|
-
else:
|
1307
|
-
break
|
1308
|
-
if i == retry:
|
1309
|
-
raise Exception(
|
1310
|
-
f"Could not download the model {model_name} from https://huggingface.co with message {huggingface_download_err_message}"
|
1311
|
-
)
|
1312
1305
|
os.makedirs(local_dir, exist_ok=True)
|
1313
|
-
|
1314
|
-
|
1315
|
-
|
1306
|
+
snapshot_download(
|
1307
|
+
repo_id=model_name,
|
1308
|
+
local_dir=local_dir,
|
1309
|
+
allow_patterns=allow_patterns,
|
1310
|
+
ignore_patterns=ignore_patterns,
|
1311
|
+
)
|
1312
|
+
# Upload to object storage and skip .cache/huggingface/ folder
|
1316
1313
|
model_artifact_path = upload_folder(
|
1317
1314
|
os_path=os_path,
|
1318
1315
|
local_dir=local_dir,
|
1319
1316
|
model_name=model_name,
|
1317
|
+
exclude_pattern=f"{HF_METADATA_FOLDER}*"
|
1320
1318
|
)
|
1321
1319
|
|
1322
1320
|
return model_artifact_path
|
@@ -1335,6 +1333,12 @@ class AquaModelApp(AquaApp):
|
|
1335
1333
|
os_path (str): Object storage destination URI to store the downloaded model. Format: oci://bucket-name@namespace/prefix
|
1336
1334
|
inference_container (str): selects service defaults
|
1337
1335
|
finetuning_container (str): selects service defaults
|
1336
|
+
allow_patterns (list): Model files matching at least one pattern are downloaded.
|
1337
|
+
Example: ["*.json"] will download all .json files. ["folder/*"] will download all files under `folder`.
|
1338
|
+
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
|
1339
|
+
ignore_patterns (list): Model files matching any of the patterns are not downloaded.
|
1340
|
+
Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
|
1341
|
+
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
|
1338
1342
|
|
1339
1343
|
Returns:
|
1340
1344
|
AquaModel:
|
@@ -1381,6 +1385,8 @@ class AquaModelApp(AquaApp):
|
|
1381
1385
|
model_name=model_name,
|
1382
1386
|
os_path=import_model_details.os_path,
|
1383
1387
|
local_dir=import_model_details.local_dir,
|
1388
|
+
allow_patterns=import_model_details.allow_patterns,
|
1389
|
+
ignore_patterns=import_model_details.ignore_patterns,
|
1384
1390
|
).rstrip("/")
|
1385
1391
|
else:
|
1386
1392
|
artifact_path = import_model_details.os_path.rstrip("/")
|
ads/llm/autogen/__init__.py
ADDED
File without changes
|
ads/llm/autogen/client_v02.py
ADDED
@@ -0,0 +1,282 @@
|
|
1
|
+
# coding: utf-8
|
2
|
+
# Copyright (c) 2016, 2024, Oracle and/or its affiliates. All rights reserved.
|
3
|
+
# This software is dual-licensed to you under the Universal Permissive License (UPL) 1.0 as shown at https://oss.oracle.com/licenses/upl or Apache License 2.0 as shown at http://www.apache.org/licenses/LICENSE-2.0. You may choose either license.
|
4
|
+
|
5
|
+
"""This module contains the custom LLM client for AutoGen v0.2 to use LangChain chat models.
|
6
|
+
https://microsoft.github.io/autogen/0.2/blog/2024/01/26/Custom-Models/
|
7
|
+
|
8
|
+
To use the custom client:
|
9
|
+
1. Prepare the LLM config, including the parameters for initializing the LangChain client.
|
10
|
+
2. Register the custom LLM
|
11
|
+
|
12
|
+
The LLM config should config the following keys:
|
13
|
+
* model_client_cls: Required by AutoGen to identify the custom client. It should be "LangChainModelClient"
|
14
|
+
* langchain_cls: LangChain class including the full import path.
|
15
|
+
* model: Name of the model to be used by AutoGen
|
16
|
+
* client_params: A dictionary containing the parameters to initialize the LangChain chat model.
|
17
|
+
|
18
|
+
Although the `LangChainModelClient` is designed to be generic and can potentially support any LangChain chat model,
|
19
|
+
the invocation depends on the server API spec and it may not be compatible with some implementations.
|
20
|
+
|
21
|
+
Following is an example config for OCI Generative AI service:
|
22
|
+
{
|
23
|
+
"model_client_cls": "LangChainModelClient",
|
24
|
+
"langchain_cls": "langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI",
|
25
|
+
"model": "cohere.command-r-plus",
|
26
|
+
# client_params will be used to initialize the LangChain ChatOCIGenAI class.
|
27
|
+
"client_params": {
|
28
|
+
"model_id": "cohere.command-r-plus",
|
29
|
+
"compartment_id": COMPARTMENT_OCID,
|
30
|
+
"model_kwargs": {"temperature": 0, "max_tokens": 2048},
|
31
|
+
# Update the authentication method as needed
|
32
|
+
"auth_type": "SECURITY_TOKEN",
|
33
|
+
"auth_profile": "DEFAULT",
|
34
|
+
# You may need to specify `service_endpoint` if the service is in a different region.
|
35
|
+
},
|
36
|
+
}
|
37
|
+
|
38
|
+
Following is an example config for OCI Data Science Model Deployment:
|
39
|
+
{
|
40
|
+
"model_client_cls": "LangChainModelClient",
|
41
|
+
"langchain_cls": "ads.llm.ChatOCIModelDeploymentVLLM",
|
42
|
+
"model": "odsc-llm",
|
43
|
+
"endpoint": "https://MODEL_DEPLOYMENT_URL/predict",
|
44
|
+
"model_kwargs": {"temperature": 0.1, "max_tokens": 2048},
|
45
|
+
# function_call_params will only be added to the API call when function/tools are added.
|
46
|
+
"function_call_params": {
|
47
|
+
"tool_choice": "auto",
|
48
|
+
"chat_template": ChatTemplates.mistral(),
|
49
|
+
},
|
50
|
+
}
|
51
|
+
|
52
|
+
Note that if `client_params` is not specified in the config, all arguments from the config except
|
53
|
+
`model_client_cls` and `langchain_cls`, and `function_call_params`, will be used to initialize
|
54
|
+
the LangChain chat model.
|
55
|
+
|
56
|
+
The `function_call_params` will only be used for function/tool calling when tools are specified.
|
57
|
+
|
58
|
+
To register the custom client:
|
59
|
+
|
60
|
+
from ads.llm.autogen.client_v02 import LangChainModelClient, register_custom_client
|
61
|
+
register_custom_client(LangChainModelClient)
|
62
|
+
|
63
|
+
Once registered with ADS, the custom LLM class will be auto-registered for all new agents.
|
64
|
+
There is no need to call `register_model_client()` on each agent.
|
65
|
+
|
66
|
+
References:
|
67
|
+
https://microsoft.github.io/autogen/0.2/docs/notebooks/agentchat_huggingface_langchain/
|
68
|
+
https://github.com/microsoft/autogen/blob/0.2/notebook/agentchat_custom_model.ipynb
|
69
|
+
|
70
|
+
"""
|
71
|
+
import copy
|
72
|
+
import importlib
|
73
|
+
import json
|
74
|
+
import logging
|
75
|
+
from typing import Any, Dict, List, Union
|
76
|
+
from types import SimpleNamespace
|
77
|
+
|
78
|
+
from autogen import ModelClient
|
79
|
+
from autogen.oai.client import OpenAIWrapper, PlaceHolderClient
|
80
|
+
from langchain_core.messages import AIMessage
|
81
|
+
|
82
|
+
|
83
|
+
logger = logging.getLogger(__name__)
|
84
|
+
|
85
|
+
# custom_clients is a dictionary mapping the name of the class to the actual class
|
86
|
+
custom_clients = {}
|
87
|
+
|
88
|
+
# There is a bug in GroupChat when using custom client:
|
89
|
+
# https://github.com/microsoft/autogen/issues/2956
|
90
|
+
# Here we will be patching the OpenAIWrapper to fix the issue.
|
91
|
+
# With this patch, you only need to register the client once with ADS.
|
92
|
+
# For example:
|
93
|
+
#
|
94
|
+
# from ads.llm.autogen.client_v02 import LangChainModelClient, register_custom_client
|
95
|
+
# register_custom_client(LangChainModelClient)
|
96
|
+
#
|
97
|
+
# This patch will auto-register the custom LLM to all new agents.
|
98
|
+
# So there is no need to call `register_model_client()` on each agent.
|
99
|
+
OpenAIWrapper._original_register_default_client = OpenAIWrapper._register_default_client
|
100
|
+
|
101
|
+
|
102
|
+
def _new_register_default_client(
|
103
|
+
self: OpenAIWrapper, config: Dict[str, Any], openai_config: Dict[str, Any]
|
104
|
+
) -> None:
|
105
|
+
"""This is a patched version of the _register_default_client() method
|
106
|
+
to automatically register custom client for agents.
|
107
|
+
"""
|
108
|
+
model_client_cls_name = config.get("model_client_cls")
|
109
|
+
if model_client_cls_name in custom_clients:
|
110
|
+
self._clients.append(PlaceHolderClient(config))
|
111
|
+
self.register_model_client(custom_clients[model_client_cls_name])
|
112
|
+
else:
|
113
|
+
self._original_register_default_client(
|
114
|
+
config=config, openai_config=openai_config
|
115
|
+
)
|
116
|
+
|
117
|
+
|
118
|
+
# Patch the _register_default_client() method
|
119
|
+
OpenAIWrapper._register_default_client = _new_register_default_client
|
120
|
+
|
121
|
+
|
122
|
+
def register_custom_client(client_class):
|
123
|
+
"""Registers custom client for AutoGen."""
|
124
|
+
if client_class.__name__ not in custom_clients:
|
125
|
+
custom_clients[client_class.__name__] = client_class
|
126
|
+
|
127
|
+
|
128
|
+
def _convert_to_langchain_tool(tool):
|
129
|
+
"""Converts the OpenAI tool spec to LangChain tool spec."""
|
130
|
+
if tool["type"] == "function":
|
131
|
+
tool = tool["function"]
|
132
|
+
required = tool["parameters"].get("required", [])
|
133
|
+
properties = copy.deepcopy(tool["parameters"]["properties"])
|
134
|
+
for key in properties.keys():
|
135
|
+
val = properties[key]
|
136
|
+
val["default"] = key in required
|
137
|
+
return {
|
138
|
+
"title": tool["name"],
|
139
|
+
"description": tool["description"],
|
140
|
+
"properties": properties,
|
141
|
+
}
|
142
|
+
raise NotImplementedError(f"Type {tool['type']} is not supported.")
|
143
|
+
|
144
|
+
|
145
|
+
def _convert_to_openai_tool_call(tool_call):
|
146
|
+
"""Converts the LangChain tool call in AI message to OpenAI tool call."""
|
147
|
+
return {
|
148
|
+
"id": tool_call.get("id"),
|
149
|
+
"function": {
|
150
|
+
"name": tool_call.get("name"),
|
151
|
+
"arguments": (
|
152
|
+
""
|
153
|
+
if tool_call.get("args") is None
|
154
|
+
else json.dumps(tool_call.get("args"))
|
155
|
+
),
|
156
|
+
},
|
157
|
+
"type": "function",
|
158
|
+
}
|
159
|
+
|
160
|
+
|
161
|
+
class Message(AIMessage):
|
162
|
+
"""Represents message returned from the LLM."""
|
163
|
+
|
164
|
+
@classmethod
|
165
|
+
def from_message(cls, message: AIMessage):
|
166
|
+
"""Converts from LangChain AIMessage."""
|
167
|
+
message = copy.deepcopy(message)
|
168
|
+
message.__class__ = cls
|
169
|
+
message.tool_calls = [
|
170
|
+
_convert_to_openai_tool_call(tool) for tool in message.tool_calls
|
171
|
+
]
|
172
|
+
return message
|
173
|
+
|
174
|
+
@property
|
175
|
+
def function_call(self):
|
176
|
+
"""Function calls."""
|
177
|
+
return self.tool_calls
|
178
|
+
|
179
|
+
|
180
|
+
class LangChainModelClient(ModelClient):
|
181
|
+
"""Represents a model client wrapping a LangChain chat model."""
|
182
|
+
|
183
|
+
def __init__(self, config: dict, **kwargs) -> None:
|
184
|
+
super().__init__()
|
185
|
+
logger.info("LangChain model client config: %s", str(config))
|
186
|
+
# Make a copy of the config since we are popping some keys
|
187
|
+
config = copy.deepcopy(config)
|
188
|
+
# model_client_cls will always be LangChainModelClient
|
189
|
+
self.client_class = config.pop("model_client_cls")
|
190
|
+
|
191
|
+
# model_name is used in constructing the response.
|
192
|
+
self.model_name = config.get("model", "")
|
193
|
+
|
194
|
+
# If the config specified function_call_params,
|
195
|
+
# Pop the params and use them only for tool calling.
|
196
|
+
self.function_call_params = config.pop("function_call_params", {})
|
197
|
+
|
198
|
+
# If the config specified invoke_params,
|
199
|
+
# Pop the params and use them only for invoking.
|
200
|
+
self.invoke_params = config.pop("invoke_params", {})
|
201
|
+
|
202
|
+
# Import the LangChain class
|
203
|
+
if "langchain_cls" not in config:
|
204
|
+
raise ValueError("Missing langchain_cls in LangChain Model Client config.")
|
205
|
+
module_cls = config.pop("langchain_cls")
|
206
|
+
module_name, cls_name = str(module_cls).rsplit(".", 1)
|
207
|
+
langchain_module = importlib.import_module(module_name)
|
208
|
+
langchain_cls = getattr(langchain_module, cls_name)
|
209
|
+
|
210
|
+
# If the config specified client_params,
|
211
|
+
# Only use the client_params to initialize the LangChain model.
|
212
|
+
# Otherwise, use the config
|
213
|
+
self.client_params = config.get("client_params", config)
|
214
|
+
|
215
|
+
# Initialize the LangChain client
|
216
|
+
self.model = langchain_cls(**self.client_params)
|
217
|
+
|
218
|
+
def create(self, params) -> ModelClient.ModelClientResponseProtocol:
|
219
|
+
"""Creates a LLM completion for a given config.
|
220
|
+
|
221
|
+
Parameters
|
222
|
+
----------
|
223
|
+
params : dict
|
224
|
+
OpenAI API compatible parameters, including all the keys from llm_config.
|
225
|
+
|
226
|
+
Returns
|
227
|
+
-------
|
228
|
+
ModelClientResponseProtocol
|
229
|
+
Response from LLM
|
230
|
+
|
231
|
+
"""
|
232
|
+
streaming = params.get("stream", False)
|
233
|
+
# TODO: num_of_responses
|
234
|
+
num_of_responses = params.get("n", 1)
|
235
|
+
messages = params.pop("messages", [])
|
236
|
+
|
237
|
+
invoke_params = copy.deepcopy(self.invoke_params)
|
238
|
+
|
239
|
+
tools = params.get("tools")
|
240
|
+
if tools:
|
241
|
+
model = self.model.bind_tools(
|
242
|
+
[_convert_to_langchain_tool(tool) for tool in tools]
|
243
|
+
)
|
244
|
+
# invoke_params["tools"] = tools
|
245
|
+
invoke_params.update(self.function_call_params)
|
246
|
+
else:
|
247
|
+
model = self.model
|
248
|
+
|
249
|
+
response = SimpleNamespace()
|
250
|
+
response.choices = []
|
251
|
+
response.model = self.model_name
|
252
|
+
|
253
|
+
if streaming and messages:
|
254
|
+
# If streaming is enabled and has messages, then iterate over the chunks of the response.
|
255
|
+
raise NotImplementedError()
|
256
|
+
else:
|
257
|
+
# If streaming is not enabled, send a regular chat completion request
|
258
|
+
ai_message = model.invoke(messages, **invoke_params)
|
259
|
+
choice = SimpleNamespace()
|
260
|
+
choice.message = Message.from_message(ai_message)
|
261
|
+
response.choices.append(choice)
|
262
|
+
return response
|
263
|
+
|
264
|
+
def message_retrieval(
|
265
|
+
self, response: ModelClient.ModelClientResponseProtocol
|
266
|
+
) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]:
|
267
|
+
"""
|
268
|
+
Retrieve and return a list of strings or a list of Choice.Message from the response.
|
269
|
+
|
270
|
+
NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
|
271
|
+
since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
|
272
|
+
"""
|
273
|
+
return [choice.message for choice in response.choices]
|
274
|
+
|
275
|
+
def cost(self, response: ModelClient.ModelClientResponseProtocol) -> float:
|
276
|
+
response.cost = 0
|
277
|
+
return 0
|
278
|
+
|
279
|
+
@staticmethod
|
280
|
+
def get_usage(response: ModelClient.ModelClientResponseProtocol) -> Dict:
|
281
|
+
"""Return usage summary of the response using RESPONSE_USAGE_KEYS."""
|
282
|
+
return {}
|
{oracle_ads-2.12.6.dist-info → oracle_ads-2.12.7.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: oracle_ads
|
3
|
-
Version: 2.12.6
|
3
|
+
Version: 2.12.7
|
4
4
|
Summary: Oracle Accelerated Data Science SDK
|
5
5
|
Keywords: Oracle Cloud Infrastructure,OCI,Machine Learning,ML,Artificial Intelligence,AI,Data Science,Cloud,Oracle
|
6
6
|
Author: Oracle Data Science
|
@@ -109,7 +109,7 @@ Requires-Dist: py-cpuinfo ; extra == "opctl"
|
|
109
109
|
Requires-Dist: rich ; extra == "opctl"
|
110
110
|
Requires-Dist: fire ; extra == "opctl"
|
111
111
|
Requires-Dist: cachetools ; extra == "opctl"
|
112
|
-
Requires-Dist: huggingface_hub==0.
|
112
|
+
Requires-Dist: huggingface_hub==0.26.2 ; extra == "opctl"
|
113
113
|
Requires-Dist: optuna==2.9.0 ; extra == "optuna"
|
114
114
|
Requires-Dist: oracle_ads[viz] ; extra == "optuna"
|
115
115
|
Requires-Dist: aiohttp ; extra == "pii"
|
@@ -130,6 +130,7 @@ Requires-Dist: pyspark>=3.0.0 ; extra == "spark"
|
|
130
130
|
Requires-Dist: oracle_ads[viz] ; extra == "tensorflow"
|
131
131
|
Requires-Dist: tensorflow<=2.15.1 ; extra == "tensorflow"
|
132
132
|
Requires-Dist: arff ; extra == "testsuite"
|
133
|
+
Requires-Dist: autogen-agentchat~=0.2 ; extra == "testsuite"
|
133
134
|
Requires-Dist: category_encoders==2.6.3 ; extra == "testsuite"
|
134
135
|
Requires-Dist: cohere==4.53 ; extra == "testsuite"
|
135
136
|
Requires-Dist: faiss-cpu ; extra == "testsuite"
|
{oracle_ads-2.12.6.dist-info → oracle_ads-2.12.7.dist-info}/RECORD
CHANGED
@@ -4,7 +4,7 @@ ads/config.py,sha256=WGFgS5-dxqC9_iRJKakn-mh9545gHJpWB_Y0hT5O3ec,8016
|
|
4
4
|
ads/aqua/__init__.py,sha256=IUKZAsxUGVicsyeSwsGwK6rAUJ1vIUW9ywduA3U22xc,1015
|
5
5
|
ads/aqua/app.py,sha256=BQuQ9RERU0rKmn3N3xicKzYaXOd7xBwX1aVuVLNgw98,11993
|
6
6
|
ads/aqua/cli.py,sha256=W-0kswzRDEilqHyw5GSMOrARgvOyPRtkEtpy54ew0Jo,3907
|
7
|
-
ads/aqua/constants.py,sha256=
|
7
|
+
ads/aqua/constants.py,sha256=fTPrRuWaZB1_THZ2I1nOrwW1pQGpvMC44--Ok5Myr5Y,2978
|
8
8
|
ads/aqua/data.py,sha256=7T7kdHGnEH6FXL_7jv_Da0CjEWXfjQZTFkaZWQikis4,932
|
9
9
|
ads/aqua/ui.py,sha256=hGl4btUsMImkpzZ-Ae_WVVaRqfpdG_gUeHKD9E1nKbE,26195
|
10
10
|
ads/aqua/common/__init__.py,sha256=rZrmh1nho40OCeabXCNWtze-mXi-PGKetcZdxZSn3_0,204
|
@@ -12,7 +12,7 @@ ads/aqua/common/decorator.py,sha256=JEN6Cy4DYgQbmIR3ShCjTuBMCnilDxq7jkYMJse1rcM,
|
|
12
12
|
ads/aqua/common/entities.py,sha256=UsP8CczuifLOLr_gAhulh8VmgGSFir3rli1MMQ-CZhk,537
|
13
13
|
ads/aqua/common/enums.py,sha256=HnaraHfkYmuqC5mEF7gyvQmqbOc6r_9EI2MF-cieb5o,2991
|
14
14
|
ads/aqua/common/errors.py,sha256=Ev2xbaqkDqeCYDx4ZgOKOoM0sXsOXP3GIV6N1lhIUxM,3085
|
15
|
-
ads/aqua/common/utils.py,sha256=
|
15
|
+
ads/aqua/common/utils.py,sha256=ipWRenYo3x_N9QN9pyverZXfxxd9fBIk4acmpZclwzY,37516
|
16
16
|
ads/aqua/config/__init__.py,sha256=2a_1LI4jWtJpbic5_v4EoOUTXCAH7cmsy9BW5prDHjU,179
|
17
17
|
ads/aqua/config/config.py,sha256=MNY4ttccaQdhxUyS1o367YIDl-U_AiSLVlgvzSd7JE4,944
|
18
18
|
ads/aqua/config/evaluation/__init__.py,sha256=2a_1LI4jWtJpbic5_v4EoOUTXCAH7cmsy9BW5prDHjU,179
|
@@ -26,9 +26,9 @@ ads/aqua/dummy_data/oci_models.json,sha256=mxUU8o3plmAFfr06fQmIQuiGe2qFFBlUB7QNP
|
|
26
26
|
ads/aqua/dummy_data/readme.md,sha256=AlBPt0HBSOFA5HbYVsFsdTm-BC3R5NRpcKrTxdjEnlI,1256
|
27
27
|
ads/aqua/evaluation/__init__.py,sha256=Fd7WL7MpQ1FtJjlftMY2KHli5cz1wr5MDu3hGmV89a0,298
|
28
28
|
ads/aqua/evaluation/constants.py,sha256=GvcXvPIw-VDKw4a8WNKs36uWdT-f7VJrWSpnnRnthGg,1533
|
29
|
-
ads/aqua/evaluation/entities.py,sha256=
|
29
|
+
ads/aqua/evaluation/entities.py,sha256=OqD2AfCO31ZO88hfORsjLdmJRqOjZrep2zVESEj6qJc,5488
|
30
30
|
ads/aqua/evaluation/errors.py,sha256=qzR63YEIA8haCh4HcBHFFm7j4g6jWDfGszqrPkXx9zQ,4564
|
31
|
-
ads/aqua/evaluation/evaluation.py,sha256=
|
31
|
+
ads/aqua/evaluation/evaluation.py,sha256=UGo6Ly148qw3br1tNo-fagvyipDi4P-2AEZ8T4m6GR4,57856
|
32
32
|
ads/aqua/extension/__init__.py,sha256=mRArjU6UZpZYVr0qHSSkPteA_CKcCZIczOFaK421m9o,1453
|
33
33
|
ads/aqua/extension/aqua_ws_msg_handler.py,sha256=soSRnIFx93JCFf6HsuF_BQEpJ2mre-IVQDUDKUKPijY,3392
|
34
34
|
ads/aqua/extension/base_handler.py,sha256=Zbb-uSNLljRU5NPOndn3_lx8MN_1yxlF2GHVpBT-kWk,5233
|
@@ -40,7 +40,7 @@ ads/aqua/extension/errors.py,sha256=ojDolyr3_0UCCwKqPtiZZyMQuX35jr8h8MQRP6HcBs4,
|
|
40
40
|
ads/aqua/extension/evaluation_handler.py,sha256=fJH73fa0xmkEiP8SxKL4A4dJgj-NoL3z_G-w_WW2zJs,4353
|
41
41
|
ads/aqua/extension/evaluation_ws_msg_handler.py,sha256=dv0iwOSTxYj1kQ1rPEoDmGgFBzLUCLXq5h7rpmY2T1M,2098
|
42
42
|
ads/aqua/extension/finetune_handler.py,sha256=abiDXNhkhtoV9hrYhCzwhDjdQKlqQ_KSqxKWntkvh3E,3288
|
43
|
-
ads/aqua/extension/model_handler.py,sha256=
|
43
|
+
ads/aqua/extension/model_handler.py,sha256=Ec7NiU3Xvp_sZEvCvN6aVqeoiFrOpJMhDI5xtP_pSuw,10612
|
44
44
|
ads/aqua/extension/models_ws_msg_handler.py,sha256=3CPfzWl1xfrE2Dpn_WYP9zY0kY5zlsAE8tU_6Y2-i18,1801
|
45
45
|
ads/aqua/extension/ui_handler.py,sha256=3TibTMeqcsSWfPsorspFrhIV0PRh8_4FoWpudycT80g,10664
|
46
46
|
ads/aqua/extension/ui_websocket_handler.py,sha256=oLFjaDrqkSERbhExdvxjLJX0oRcP-DVJ_aWn0qy0uvo,5084
|
@@ -53,9 +53,9 @@ ads/aqua/finetuning/entities.py,sha256=S7Ll_0WyWGh23my-6ow3vwHLDZqTel8CMCoE9oLow
|
|
53
53
|
ads/aqua/finetuning/finetuning.py,sha256=mwKl8KA2Artp0dXzjXxxKn_UBnkYpNXMYN7ykrZcyEM,25145
|
54
54
|
ads/aqua/model/__init__.py,sha256=j2iylvERdANxgrEDp7b_mLcKMz1CF5Go0qgYCiMwdos,278
|
55
55
|
ads/aqua/model/constants.py,sha256=H239zDu3koa3UTdw-uQveXHX2NDwidclVcS4QIrCTJo,1593
|
56
|
-
ads/aqua/model/entities.py,sha256=
|
56
|
+
ads/aqua/model/entities.py,sha256=wv1j18OG8NrmKLwIevyJ1ZVw965n3_3titOfwqyzlI8,9765
|
57
57
|
ads/aqua/model/enums.py,sha256=t8GbK2nblIPm3gClR8W31RmbtTuqpoSzoN4W3JfD6AI,1004
|
58
|
-
ads/aqua/model/model.py,sha256=
|
58
|
+
ads/aqua/model/model.py,sha256=pFG4lkaqtovSpiu3BOCGT7scMtXt4rwup9Rof6Hl_CU,63908
|
59
59
|
ads/aqua/modeldeployment/__init__.py,sha256=RJCfU1yazv3hVWi5rS08QVLTpTwZLnlC8wU8diwFjnM,391
|
60
60
|
ads/aqua/modeldeployment/constants.py,sha256=lJF77zwxmlECljDYjwFAMprAUR_zctZHmawiP-4alLg,296
|
61
61
|
ads/aqua/modeldeployment/deployment.py,sha256=8qx4cxzuln5FZpAXTZlvaHCio2fzFJxO4PrrAS1_b6A,30652
|
@@ -450,6 +450,8 @@ ads/llm/chat_template.py,sha256=t2QRfLLR_c_cq3JqABghWqiCSWjjuVc_mfEN-yVYG10,934
|
|
450
450
|
ads/llm/deploy.py,sha256=5oZipFWU6q_9dCyt3WE4ic-n5rNZgQsYU_3lS_Vp_nY,2275
|
451
451
|
ads/llm/requirements.txt,sha256=vaVwhWCteqmo0fRsEk6M8S1LQMjULU_Bt_syBAa2G-s,55
|
452
452
|
ads/llm/serialize.py,sha256=WjQNMPACyR8nIh1dB7BLFUmqUrumld6vt91lg1DWzWI,7281
|
453
|
+
ads/llm/autogen/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
454
|
+
ads/llm/autogen/client_v02.py,sha256=-8fH-u769txu9eCfGi8XDkQ09DMPl5cCOmmywOFUguc,11127
|
453
455
|
ads/llm/guardrails/__init__.py,sha256=sAqmLhogrLXb3xI7dPOj9HmSkpTnLh9wkzysuGd8AXk,204
|
454
456
|
ads/llm/guardrails/base.py,sha256=scli_YSqDbArIJW5sA5PLjCd6G8_-dNUcpTybvQvZnk,16468
|
455
457
|
ads/llm/guardrails/huggingface.py,sha256=4DFanCYb3R1SKYSFdcEyGH2ywQgf2yFDDZGJtOcoph0,1304
|
@@ -813,8 +815,8 @@ ads/type_discovery/unknown_detector.py,sha256=yZuYQReO7PUyoWZE7onhhtYaOg6088wf1y
|
|
813
815
|
ads/type_discovery/zipcode_detector.py,sha256=3AlETg_ZF4FT0u914WXvTT3F3Z6Vf51WiIt34yQMRbw,1421
|
814
816
|
ads/vault/__init__.py,sha256=x9tMdDAOdF5iDHk9u2di_K-ze5Nq068x25EWOBoWwqY,245
|
815
817
|
ads/vault/vault.py,sha256=hFBkpYE-Hfmzu1L0sQwUfYcGxpWmgG18JPndRl0NOXI,8624
|
816
|
-
oracle_ads-2.12.
|
817
|
-
oracle_ads-2.12.
|
818
|
-
oracle_ads-2.12.
|
819
|
-
oracle_ads-2.12.
|
820
|
-
oracle_ads-2.12.
|
818
|
+
oracle_ads-2.12.7.dist-info/entry_points.txt,sha256=9VFnjpQCsMORA4rVkvN8eH6D3uHjtegb9T911t8cqV0,35
|
819
|
+
oracle_ads-2.12.7.dist-info/LICENSE.txt,sha256=zoGmbfD1IdRKx834U0IzfFFFo5KoFK71TND3K9xqYqo,1845
|
820
|
+
oracle_ads-2.12.7.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
|
821
|
+
oracle_ads-2.12.7.dist-info/METADATA,sha256=npukk9HNdJhLDD1g3tKxGQ8AFtlqGADZzKovEmgm_u0,16282
|
822
|
+
oracle_ads-2.12.7.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|