oracle-ads 2.12.5__py3-none-any.whl → 2.12.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/common/decorator.py +10 -0
- ads/aqua/common/utils.py +4 -1
- ads/aqua/constants.py +1 -0
- ads/aqua/evaluation/entities.py +14 -4
- ads/aqua/evaluation/evaluation.py +2 -6
- ads/aqua/extension/aqua_ws_msg_handler.py +2 -0
- ads/aqua/extension/base_handler.py +2 -0
- ads/aqua/extension/model_handler.py +4 -0
- ads/aqua/finetuning/constants.py +3 -0
- ads/aqua/finetuning/finetuning.py +13 -2
- ads/aqua/model/entities.py +2 -0
- ads/aqua/model/model.py +25 -19
- ads/llm/autogen/__init__.py +0 -0
- ads/llm/autogen/client_v02.py +282 -0
- ads/opctl/operator/lowcode/anomaly/model/anomaly_merlion.py +6 -5
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +12 -8
- ads/opctl/operator/lowcode/anomaly/model/autots.py +6 -3
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +19 -7
- ads/opctl/operator/lowcode/anomaly/model/isolationforest.py +9 -10
- ads/opctl/operator/lowcode/anomaly/model/oneclasssvm.py +10 -11
- ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +6 -2
- ads/opctl/operator/lowcode/common/data.py +13 -11
- ads/opctl/operator/lowcode/forecast/model/arima.py +14 -12
- ads/opctl/operator/lowcode/forecast/model/automlx.py +26 -26
- ads/opctl/operator/lowcode/forecast/model/autots.py +16 -18
- ads/opctl/operator/lowcode/forecast/model/base_model.py +45 -36
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +36 -47
- ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +3 -0
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +30 -46
- ads/opctl/operator/lowcode/forecast/model/prophet.py +15 -20
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +25 -20
- ads/opctl/operator/lowcode/forecast/utils.py +30 -33
- ads/opctl/operator/lowcode/pii/model/report.py +11 -7
- ads/opctl/operator/lowcode/recommender/model/base_model.py +58 -45
- ads/opctl/operator/lowcode/recommender/model/svd.py +47 -29
- {oracle_ads-2.12.5.dist-info → oracle_ads-2.12.7.dist-info}/METADATA +7 -6
- {oracle_ads-2.12.5.dist-info → oracle_ads-2.12.7.dist-info}/RECORD +40 -38
- {oracle_ads-2.12.5.dist-info → oracle_ads-2.12.7.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.12.5.dist-info → oracle_ads-2.12.7.dist-info}/WHEEL +0 -0
- {oracle_ads-2.12.5.dist-info → oracle_ads-2.12.7.dist-info}/entry_points.txt +0 -0
ads/aqua/common/decorator.py
CHANGED
@@ -69,6 +69,16 @@ def handle_exceptions(func):
|
|
69
69
|
reason=error.message,
|
70
70
|
service_payload=error.args[0] if error.args else None,
|
71
71
|
exc_info=sys.exc_info(),
|
72
|
+
aqua_api_details=dict(
|
73
|
+
# __qualname__ gives information of class and name of api
|
74
|
+
aqua_api_name=func.__qualname__,
|
75
|
+
oci_api_name=getattr(
|
76
|
+
error, "operation_name", "Unknown OCI Operation"
|
77
|
+
),
|
78
|
+
service_endpoint=getattr(
|
79
|
+
error, "request_endpoint", "Unknown Request Endpoint"
|
80
|
+
)
|
81
|
+
)
|
72
82
|
)
|
73
83
|
except (
|
74
84
|
ClientError,
|
ads/aqua/common/utils.py
CHANGED
@@ -788,13 +788,14 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
|
|
788
788
|
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
|
789
789
|
|
790
790
|
|
791
|
-
def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
|
791
|
+
def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str:
|
792
792
|
"""Upload the local folder to the object storage
|
793
793
|
|
794
794
|
Args:
|
795
795
|
os_path (str): object storage URI with prefix. This is the path to upload
|
796
796
|
local_dir (str): Local directory where the object is downloaded
|
797
797
|
model_name (str): Name of the huggingface model
|
798
|
+
exclude_pattern (optional, str): The matching pattern of files to be excluded from uploading.
|
798
799
|
Returns:
|
799
800
|
str: Object name inside the bucket
|
800
801
|
"""
|
@@ -804,6 +805,8 @@ def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
|
|
804
805
|
auth_state = AuthState()
|
805
806
|
object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/"
|
806
807
|
command = f"oci os object bulk-upload --src-dir {local_dir} --prefix {object_path} -bn {os_details.bucket} -ns {os_details.namespace} --auth {auth_state.oci_iam_type} --profile {auth_state.oci_key_profile} --no-overwrite"
|
808
|
+
if exclude_pattern:
|
809
|
+
command += f" --exclude {exclude_pattern}"
|
807
810
|
try:
|
808
811
|
logger.info(f"Running: {command}")
|
809
812
|
subprocess.check_call(shlex.split(command))
|
ads/aqua/constants.py
CHANGED
@@ -35,6 +35,7 @@ AQUA_MODEL_ARTIFACT_CONFIG = "config.json"
|
|
35
35
|
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME = "_name_or_path"
|
36
36
|
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE = "model_type"
|
37
37
|
AQUA_MODEL_ARTIFACT_FILE = "model_file"
|
38
|
+
HF_METADATA_FOLDER = ".cache/"
|
38
39
|
HF_LOGIN_DEFAULT_TIMEOUT = 2
|
39
40
|
|
40
41
|
TRAINING_METRICS_FINAL = "training_metrics_final"
|
ads/aqua/evaluation/entities.py
CHANGED
@@ -9,11 +9,12 @@ aqua.evaluation.entities
|
|
9
9
|
This module contains dataclasses for aqua evaluation.
|
10
10
|
"""
|
11
11
|
|
12
|
+
from typing import Any, Dict, List, Optional
|
13
|
+
|
12
14
|
from pydantic import Field
|
13
|
-
from typing import Any, Dict, List, Optional, Union
|
14
15
|
|
15
|
-
from ads.aqua.data import AquaResourceIdentifier
|
16
16
|
from ads.aqua.config.utils.serializer import Serializable
|
17
|
+
from ads.aqua.data import AquaResourceIdentifier
|
17
18
|
|
18
19
|
|
19
20
|
class CreateAquaEvaluationDetails(Serializable):
|
@@ -82,11 +83,13 @@ class CreateAquaEvaluationDetails(Serializable):
|
|
82
83
|
ocpus: Optional[float] = None
|
83
84
|
log_group_id: Optional[str] = None
|
84
85
|
log_id: Optional[str] = None
|
85
|
-
metrics: Optional[List[str]] = None
|
86
|
+
metrics: Optional[List[Dict[str, Any]]] = None
|
86
87
|
force_overwrite: Optional[bool] = False
|
87
88
|
|
88
89
|
class Config:
|
89
90
|
extra = "ignore"
|
91
|
+
protected_namespaces = ()
|
92
|
+
|
90
93
|
|
91
94
|
class AquaEvalReport(Serializable):
|
92
95
|
evaluation_id: str = ""
|
@@ -95,6 +98,7 @@ class AquaEvalReport(Serializable):
|
|
95
98
|
class Config:
|
96
99
|
extra = "ignore"
|
97
100
|
|
101
|
+
|
98
102
|
class AquaEvalParams(Serializable):
|
99
103
|
shape: str = ""
|
100
104
|
dataset_path: str = ""
|
@@ -103,6 +107,7 @@ class AquaEvalParams(Serializable):
|
|
103
107
|
class Config:
|
104
108
|
extra = "allow"
|
105
109
|
|
110
|
+
|
106
111
|
class AquaEvalMetric(Serializable):
|
107
112
|
key: str
|
108
113
|
name: str
|
@@ -111,6 +116,7 @@ class AquaEvalMetric(Serializable):
|
|
111
116
|
class Config:
|
112
117
|
extra = "ignore"
|
113
118
|
|
119
|
+
|
114
120
|
class AquaEvalMetricSummary(Serializable):
|
115
121
|
metric: str = ""
|
116
122
|
score: str = ""
|
@@ -119,6 +125,7 @@ class AquaEvalMetricSummary(Serializable):
|
|
119
125
|
class Config:
|
120
126
|
extra = "ignore"
|
121
127
|
|
128
|
+
|
122
129
|
class AquaEvalMetrics(Serializable):
|
123
130
|
id: str
|
124
131
|
report: str
|
@@ -128,17 +135,19 @@ class AquaEvalMetrics(Serializable):
|
|
128
135
|
class Config:
|
129
136
|
extra = "ignore"
|
130
137
|
|
138
|
+
|
131
139
|
class AquaEvaluationCommands(Serializable):
|
132
140
|
evaluation_id: str
|
133
141
|
evaluation_target_id: str
|
134
142
|
input_data: Dict[str, Any]
|
135
|
-
metrics: List[str]
|
143
|
+
metrics: List[Dict[str, Any]]
|
136
144
|
output_dir: str
|
137
145
|
params: Dict[str, Any]
|
138
146
|
|
139
147
|
class Config:
|
140
148
|
extra = "ignore"
|
141
149
|
|
150
|
+
|
142
151
|
class AquaEvaluationSummary(Serializable):
|
143
152
|
"""Represents a summary of Aqua evaluation."""
|
144
153
|
|
@@ -157,6 +166,7 @@ class AquaEvaluationSummary(Serializable):
|
|
157
166
|
class Config:
|
158
167
|
extra = "ignore"
|
159
168
|
|
169
|
+
|
160
170
|
class AquaEvaluationDetail(AquaEvaluationSummary):
|
161
171
|
"""Represents the details of Aqua evaluation."""
|
162
172
|
|
@@ -159,7 +159,8 @@ class AquaEvaluationApp(AquaApp):
|
|
159
159
|
create_aqua_evaluation_details = CreateAquaEvaluationDetails(**kwargs)
|
160
160
|
except Exception as ex:
|
161
161
|
custom_errors = {
|
162
|
-
".".join(map(str, e["loc"])): e["msg"]
|
162
|
+
".".join(map(str, e["loc"])): e["msg"]
|
163
|
+
for e in json.loads(ex.json())
|
163
164
|
}
|
164
165
|
raise AquaValueError(
|
165
166
|
f"Invalid create evaluation parameters. Error details: {custom_errors}."
|
@@ -619,11 +620,6 @@ class AquaEvaluationApp(AquaApp):
|
|
619
620
|
evaluation_id=evaluation_id,
|
620
621
|
evaluation_target_id=evaluation_source_id,
|
621
622
|
input_data={
|
622
|
-
"columns": {
|
623
|
-
"prompt": "prompt",
|
624
|
-
"completion": "completion",
|
625
|
-
"category": "category",
|
626
|
-
},
|
627
623
|
"format": Path(dataset_path).suffix,
|
628
624
|
"url": dataset_path,
|
629
625
|
},
|
@@ -78,10 +78,12 @@ class AquaWSMsgHandler:
|
|
78
78
|
logger.warning(reply["message"])
|
79
79
|
# telemetry may not be present if there is an error while initializing
|
80
80
|
if hasattr(self, "telemetry"):
|
81
|
+
aqua_api_details = kwargs.get("aqua_api_details", {})
|
81
82
|
self.telemetry.record_event_async(
|
82
83
|
category="aqua/error",
|
83
84
|
action=str(status_code),
|
84
85
|
value=reason,
|
86
|
+
**aqua_api_details
|
85
87
|
)
|
86
88
|
response = AquaWsError(
|
87
89
|
status=status_code,
|
@@ -98,10 +98,12 @@ class AquaAPIhandler(APIHandler):
|
|
98
98
|
|
99
99
|
# telemetry may not be present if there is an error while initializing
|
100
100
|
if hasattr(self, "telemetry"):
|
101
|
+
aqua_api_details = kwargs.get("aqua_api_details", {})
|
101
102
|
self.telemetry.record_event_async(
|
102
103
|
category="aqua/error",
|
103
104
|
action=str(status_code),
|
104
105
|
value=reason,
|
106
|
+
**aqua_api_details
|
105
107
|
)
|
106
108
|
|
107
109
|
self.finish(json.dumps(reply))
|
@@ -129,6 +129,8 @@ class AquaModelHandler(AquaAPIhandler):
|
|
129
129
|
str(input_data.get("download_from_hf", "false")).lower() == "true"
|
130
130
|
)
|
131
131
|
inference_container_uri = input_data.get("inference_container_uri")
|
132
|
+
allow_patterns = input_data.get("allow_patterns")
|
133
|
+
ignore_patterns = input_data.get("ignore_patterns")
|
132
134
|
|
133
135
|
return self.finish(
|
134
136
|
AquaModelApp().register(
|
@@ -141,6 +143,8 @@ class AquaModelHandler(AquaAPIhandler):
|
|
141
143
|
project_id=project_id,
|
142
144
|
model_file=model_file,
|
143
145
|
inference_container_uri=inference_container_uri,
|
146
|
+
allow_patterns=allow_patterns,
|
147
|
+
ignore_patterns=ignore_patterns,
|
144
148
|
)
|
145
149
|
)
|
146
150
|
|
ads/aqua/finetuning/constants.py
CHANGED
@@ -15,3 +15,6 @@ class FineTuneCustomMetadata(str, metaclass=ExtendedEnumMeta):
|
|
15
15
|
SERVICE_MODEL_ARTIFACT_LOCATION = "artifact_location"
|
16
16
|
SERVICE_MODEL_DEPLOYMENT_CONTAINER = "deployment-container"
|
17
17
|
SERVICE_MODEL_FINE_TUNE_CONTAINER = "finetune-container"
|
18
|
+
|
19
|
+
|
20
|
+
ENV_AQUA_FINE_TUNING_CONTAINER = "AQUA_FINE_TUNING_CONTAINER"
|
@@ -31,7 +31,10 @@ from ads.aqua.constants import (
|
|
31
31
|
UNKNOWN_DICT,
|
32
32
|
)
|
33
33
|
from ads.aqua.data import AquaResourceIdentifier
|
34
|
-
from ads.aqua.finetuning.constants import
|
34
|
+
from ads.aqua.finetuning.constants import (
|
35
|
+
ENV_AQUA_FINE_TUNING_CONTAINER,
|
36
|
+
FineTuneCustomMetadata,
|
37
|
+
)
|
35
38
|
from ads.aqua.finetuning.entities import *
|
36
39
|
from ads.common.auth import default_signer
|
37
40
|
from ads.common.object_storage_details import ObjectStorageDetails
|
@@ -310,6 +313,15 @@ class AquaFineTuningApp(AquaApp):
|
|
310
313
|
except Exception:
|
311
314
|
pass
|
312
315
|
|
316
|
+
if not is_custom_container and ENV_AQUA_FINE_TUNING_CONTAINER in os.environ:
|
317
|
+
ft_container = os.environ[ENV_AQUA_FINE_TUNING_CONTAINER]
|
318
|
+
logger.info(
|
319
|
+
"Using container set by environment variable %s=%s",
|
320
|
+
ENV_AQUA_FINE_TUNING_CONTAINER,
|
321
|
+
ft_container,
|
322
|
+
)
|
323
|
+
is_custom_container = True
|
324
|
+
|
313
325
|
ft_parameters.batch_size = ft_parameters.batch_size or (
|
314
326
|
ft_config.get("shape", UNKNOWN_DICT)
|
315
327
|
.get(create_fine_tuning_details.shape_name, UNKNOWN_DICT)
|
@@ -559,7 +571,6 @@ class AquaFineTuningApp(AquaApp):
|
|
559
571
|
Dict:
|
560
572
|
A dict of allowed finetuning configs.
|
561
573
|
"""
|
562
|
-
|
563
574
|
config = self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG)
|
564
575
|
if not config:
|
565
576
|
logger.debug(
|
ads/aqua/model/entities.py
CHANGED
@@ -289,6 +289,8 @@ class ImportModelDetails(CLIBuilderMixin):
|
|
289
289
|
project_id: Optional[str] = None
|
290
290
|
model_file: Optional[str] = None
|
291
291
|
inference_container_uri: Optional[str] = None
|
292
|
+
allow_patterns: Optional[List[str]] = None
|
293
|
+
ignore_patterns: Optional[List[str]] = None
|
292
294
|
|
293
295
|
def __post_init__(self):
|
294
296
|
self._command = "model register"
|
ads/aqua/model/model.py
CHANGED
@@ -40,6 +40,7 @@ from ads.aqua.constants import (
|
|
40
40
|
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
|
41
41
|
AQUA_MODEL_ARTIFACT_FILE,
|
42
42
|
AQUA_MODEL_TYPE_CUSTOM,
|
43
|
+
HF_METADATA_FOLDER,
|
43
44
|
LICENSE_TXT,
|
44
45
|
MODEL_BY_REFERENCE_OSS_PATH_KEY,
|
45
46
|
README,
|
@@ -1274,6 +1275,8 @@ class AquaModelApp(AquaApp):
|
|
1274
1275
|
model_name: str,
|
1275
1276
|
os_path: str,
|
1276
1277
|
local_dir: str = None,
|
1278
|
+
allow_patterns: List[str] = None,
|
1279
|
+
ignore_patterns: List[str] = None,
|
1277
1280
|
) -> str:
|
1278
1281
|
"""This helper function downloads the model artifact from Hugging Face to a local folder, then uploads
|
1279
1282
|
to object storage location.
|
@@ -1283,6 +1286,12 @@ class AquaModelApp(AquaApp):
|
|
1283
1286
|
model_name (str): The huggingface model name.
|
1284
1287
|
os_path (str): The OS path where the model files are located.
|
1285
1288
|
local_dir (str): The local temp dir to store the huggingface model.
|
1289
|
+
allow_patterns (list): Model files matching at least one pattern are downloaded.
|
1290
|
+
Example: ["*.json"] will download all .json files. ["folder/*"] will download all files under `folder`.
|
1291
|
+
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
|
1292
|
+
ignore_patterns (list): Model files matching any of the patterns are not downloaded.
|
1293
|
+
Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
|
1294
|
+
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
|
1286
1295
|
|
1287
1296
|
Returns
|
1288
1297
|
-------
|
@@ -1293,30 +1302,19 @@ class AquaModelApp(AquaApp):
|
|
1293
1302
|
if not local_dir:
|
1294
1303
|
local_dir = os.path.join(os.path.expanduser("~"), "cached-model")
|
1295
1304
|
local_dir = os.path.join(local_dir, model_name)
|
1296
|
-
retry = 10
|
1297
|
-
i = 0
|
1298
|
-
huggingface_download_err_message = None
|
1299
|
-
while i < retry:
|
1300
|
-
try:
|
1301
|
-
# Download to cache folder. The while loop retries when there is a network failure
|
1302
|
-
snapshot_download(repo_id=model_name)
|
1303
|
-
except Exception as e:
|
1304
|
-
huggingface_download_err_message = str(e)
|
1305
|
-
i += 1
|
1306
|
-
else:
|
1307
|
-
break
|
1308
|
-
if i == retry:
|
1309
|
-
raise Exception(
|
1310
|
-
f"Could not download the model {model_name} from https://huggingface.co with message {huggingface_download_err_message}"
|
1311
|
-
)
|
1312
1305
|
os.makedirs(local_dir, exist_ok=True)
|
1313
|
-
|
1314
|
-
|
1315
|
-
|
1306
|
+
snapshot_download(
|
1307
|
+
repo_id=model_name,
|
1308
|
+
local_dir=local_dir,
|
1309
|
+
allow_patterns=allow_patterns,
|
1310
|
+
ignore_patterns=ignore_patterns,
|
1311
|
+
)
|
1312
|
+
# Upload to object storage and skip .cache/huggingface/ folder
|
1316
1313
|
model_artifact_path = upload_folder(
|
1317
1314
|
os_path=os_path,
|
1318
1315
|
local_dir=local_dir,
|
1319
1316
|
model_name=model_name,
|
1317
|
+
exclude_pattern=f"{HF_METADATA_FOLDER}*"
|
1320
1318
|
)
|
1321
1319
|
|
1322
1320
|
return model_artifact_path
|
@@ -1335,6 +1333,12 @@ class AquaModelApp(AquaApp):
|
|
1335
1333
|
os_path (str): Object storage destination URI to store the downloaded model. Format: oci://bucket-name@namespace/prefix
|
1336
1334
|
inference_container (str): selects service defaults
|
1337
1335
|
finetuning_container (str): selects service defaults
|
1336
|
+
allow_patterns (list): Model files matching at least one pattern are downloaded.
|
1337
|
+
Example: ["*.json"] will download all .json files. ["folder/*"] will download all files under `folder`.
|
1338
|
+
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
|
1339
|
+
ignore_patterns (list): Model files matching any of the patterns are not downloaded.
|
1340
|
+
Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
|
1341
|
+
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
|
1338
1342
|
|
1339
1343
|
Returns:
|
1340
1344
|
AquaModel:
|
@@ -1381,6 +1385,8 @@ class AquaModelApp(AquaApp):
|
|
1381
1385
|
model_name=model_name,
|
1382
1386
|
os_path=import_model_details.os_path,
|
1383
1387
|
local_dir=import_model_details.local_dir,
|
1388
|
+
allow_patterns=import_model_details.allow_patterns,
|
1389
|
+
ignore_patterns=import_model_details.ignore_patterns,
|
1384
1390
|
).rstrip("/")
|
1385
1391
|
else:
|
1386
1392
|
artifact_path = import_model_details.os_path.rstrip("/")
|
File without changes
|
@@ -0,0 +1,282 @@
|
|
1
|
+
# coding: utf-8
|
2
|
+
# Copyright (c) 2016, 2024, Oracle and/or its affiliates. All rights reserved.
|
3
|
+
# This software is dual-licensed to you under the Universal Permissive License (UPL) 1.0 as shown at https://oss.oracle.com/licenses/upl or Apache License 2.0 as shown at http://www.apache.org/licenses/LICENSE-2.0. You may choose either license.
|
4
|
+
|
5
|
+
"""This module contains the custom LLM client for AutoGen v0.2 to use LangChain chat models.
|
6
|
+
https://microsoft.github.io/autogen/0.2/blog/2024/01/26/Custom-Models/
|
7
|
+
|
8
|
+
To use the custom client:
|
9
|
+
1. Prepare the LLM config, including the parameters for initializing the LangChain client.
|
10
|
+
2. Register the custom LLM
|
11
|
+
|
12
|
+
The LLM config should contain the following keys:
|
13
|
+
* model_client_cls: Required by AutoGen to identify the custom client. It should be "LangChainModelClient"
|
14
|
+
* langchain_cls: LangChain class including the full import path.
|
15
|
+
* model: Name of the model to be used by AutoGen
|
16
|
+
* client_params: A dictionary containing the parameters to initialize the LangChain chat model.
|
17
|
+
|
18
|
+
Although the `LangChainModelClient` is designed to be generic and can potentially support any LangChain chat model,
|
19
|
+
the invocation depends on the server API spec and it may not be compatible with some implementations.
|
20
|
+
|
21
|
+
Following is an example config for OCI Generative AI service:
|
22
|
+
{
|
23
|
+
"model_client_cls": "LangChainModelClient",
|
24
|
+
"langchain_cls": "langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI",
|
25
|
+
"model": "cohere.command-r-plus",
|
26
|
+
# client_params will be used to initialize the LangChain ChatOCIGenAI class.
|
27
|
+
"client_params": {
|
28
|
+
"model_id": "cohere.command-r-plus",
|
29
|
+
"compartment_id": COMPARTMENT_OCID,
|
30
|
+
"model_kwargs": {"temperature": 0, "max_tokens": 2048},
|
31
|
+
# Update the authentication method as needed
|
32
|
+
"auth_type": "SECURITY_TOKEN",
|
33
|
+
"auth_profile": "DEFAULT",
|
34
|
+
# You may need to specify `service_endpoint` if the service is in a different region.
|
35
|
+
},
|
36
|
+
}
|
37
|
+
|
38
|
+
Following is an example config for OCI Data Science Model Deployment:
|
39
|
+
{
|
40
|
+
"model_client_cls": "LangChainModelClient",
|
41
|
+
"langchain_cls": "ads.llm.ChatOCIModelDeploymentVLLM",
|
42
|
+
"model": "odsc-llm",
|
43
|
+
"endpoint": "https://MODEL_DEPLOYMENT_URL/predict",
|
44
|
+
"model_kwargs": {"temperature": 0.1, "max_tokens": 2048},
|
45
|
+
# function_call_params will only be added to the API call when function/tools are added.
|
46
|
+
"function_call_params": {
|
47
|
+
"tool_choice": "auto",
|
48
|
+
"chat_template": ChatTemplates.mistral(),
|
49
|
+
},
|
50
|
+
}
|
51
|
+
|
52
|
+
Note that if `client_params` is not specified in the config, all arguments from the config except
|
53
|
+
`model_client_cls`, `langchain_cls`, `function_call_params`, and `invoke_params` will be used to initialize
|
54
|
+
the LangChain chat model.
|
55
|
+
|
56
|
+
The `function_call_params` will only be used for function/tool calling when tools are specified.
|
57
|
+
|
58
|
+
To register the custom client:
|
59
|
+
|
60
|
+
from ads.llm.autogen.client_v02 import LangChainModelClient, register_custom_client
|
61
|
+
register_custom_client(LangChainModelClient)
|
62
|
+
|
63
|
+
Once registered with ADS, the custom LLM class will be auto-registered for all new agents.
|
64
|
+
There is no need to call `register_model_client()` on each agent.
|
65
|
+
|
66
|
+
References:
|
67
|
+
https://microsoft.github.io/autogen/0.2/docs/notebooks/agentchat_huggingface_langchain/
|
68
|
+
https://github.com/microsoft/autogen/blob/0.2/notebook/agentchat_custom_model.ipynb
|
69
|
+
|
70
|
+
"""
|
71
|
+
import copy
|
72
|
+
import importlib
|
73
|
+
import json
|
74
|
+
import logging
|
75
|
+
from typing import Any, Dict, List, Union
|
76
|
+
from types import SimpleNamespace
|
77
|
+
|
78
|
+
from autogen import ModelClient
|
79
|
+
from autogen.oai.client import OpenAIWrapper, PlaceHolderClient
|
80
|
+
from langchain_core.messages import AIMessage
|
81
|
+
|
82
|
+
|
83
|
+
logger = logging.getLogger(__name__)
|
84
|
+
|
85
|
+
# custom_clients is a dictionary mapping the name of the class to the actual class
|
86
|
+
custom_clients = {}
|
87
|
+
|
88
|
+
# There is a bug in GroupChat when using custom client:
|
89
|
+
# https://github.com/microsoft/autogen/issues/2956
|
90
|
+
# Here we will be patching the OpenAIWrapper to fix the issue.
|
91
|
+
# With this patch, you only need to register the client once with ADS.
|
92
|
+
# For example:
|
93
|
+
#
|
94
|
+
# from ads.llm.autogen.client_v02 import LangChainModelClient, register_custom_client
|
95
|
+
# register_custom_client(LangChainModelClient)
|
96
|
+
#
|
97
|
+
# This patch will auto-register the custom LLM to all new agents.
|
98
|
+
# So there is no need to call `register_model_client()` on each agent.
|
99
|
+
OpenAIWrapper._original_register_default_client = OpenAIWrapper._register_default_client
|
100
|
+
|
101
|
+
|
102
|
+
def _new_register_default_client(
|
103
|
+
self: OpenAIWrapper, config: Dict[str, Any], openai_config: Dict[str, Any]
|
104
|
+
) -> None:
|
105
|
+
"""This is a patched version of the _register_default_client() method
|
106
|
+
to automatically register custom client for agents.
|
107
|
+
"""
|
108
|
+
model_client_cls_name = config.get("model_client_cls")
|
109
|
+
if model_client_cls_name in custom_clients:
|
110
|
+
self._clients.append(PlaceHolderClient(config))
|
111
|
+
self.register_model_client(custom_clients[model_client_cls_name])
|
112
|
+
else:
|
113
|
+
self._original_register_default_client(
|
114
|
+
config=config, openai_config=openai_config
|
115
|
+
)
|
116
|
+
|
117
|
+
|
118
|
+
# Patch the _register_default_client() method
|
119
|
+
OpenAIWrapper._register_default_client = _new_register_default_client
|
120
|
+
|
121
|
+
|
122
|
+
def register_custom_client(client_class):
|
123
|
+
"""Registers custom client for AutoGen."""
|
124
|
+
if client_class.__name__ not in custom_clients:
|
125
|
+
custom_clients[client_class.__name__] = client_class
|
126
|
+
|
127
|
+
|
128
|
+
def _convert_to_langchain_tool(tool):
|
129
|
+
"""Converts the OpenAI tool spec to LangChain tool spec."""
|
130
|
+
if tool["type"] == "function":
|
131
|
+
tool = tool["function"]
|
132
|
+
required = tool["parameters"].get("required", [])
|
133
|
+
properties = copy.deepcopy(tool["parameters"]["properties"])
|
134
|
+
for key in properties.keys():
|
135
|
+
val = properties[key]
|
136
|
+
val["default"] = key in required
|
137
|
+
return {
|
138
|
+
"title": tool["name"],
|
139
|
+
"description": tool["description"],
|
140
|
+
"properties": properties,
|
141
|
+
}
|
142
|
+
raise NotImplementedError(f"Type {tool['type']} is not supported.")
|
143
|
+
|
144
|
+
|
145
|
+
def _convert_to_openai_tool_call(tool_call):
|
146
|
+
"""Converts the LangChain tool call in AI message to OpenAI tool call."""
|
147
|
+
return {
|
148
|
+
"id": tool_call.get("id"),
|
149
|
+
"function": {
|
150
|
+
"name": tool_call.get("name"),
|
151
|
+
"arguments": (
|
152
|
+
""
|
153
|
+
if tool_call.get("args") is None
|
154
|
+
else json.dumps(tool_call.get("args"))
|
155
|
+
),
|
156
|
+
},
|
157
|
+
"type": "function",
|
158
|
+
}
|
159
|
+
|
160
|
+
|
161
|
+
class Message(AIMessage):
|
162
|
+
"""Represents message returned from the LLM."""
|
163
|
+
|
164
|
+
@classmethod
|
165
|
+
def from_message(cls, message: AIMessage):
|
166
|
+
"""Converts from LangChain AIMessage."""
|
167
|
+
message = copy.deepcopy(message)
|
168
|
+
message.__class__ = cls
|
169
|
+
message.tool_calls = [
|
170
|
+
_convert_to_openai_tool_call(tool) for tool in message.tool_calls
|
171
|
+
]
|
172
|
+
return message
|
173
|
+
|
174
|
+
@property
|
175
|
+
def function_call(self):
|
176
|
+
"""Function calls."""
|
177
|
+
return self.tool_calls
|
178
|
+
|
179
|
+
|
180
|
+
class LangChainModelClient(ModelClient):
|
181
|
+
"""Represents a model client wrapping a LangChain chat model."""
|
182
|
+
|
183
|
+
def __init__(self, config: dict, **kwargs) -> None:
|
184
|
+
super().__init__()
|
185
|
+
logger.info("LangChain model client config: %s", str(config))
|
186
|
+
# Make a copy of the config since we are popping some keys
|
187
|
+
config = copy.deepcopy(config)
|
188
|
+
# model_client_cls will always be LangChainModelClient
|
189
|
+
self.client_class = config.pop("model_client_cls")
|
190
|
+
|
191
|
+
# model_name is used in constructing the response.
|
192
|
+
self.model_name = config.get("model", "")
|
193
|
+
|
194
|
+
# If the config specified function_call_params,
|
195
|
+
# Pop the params and use them only for tool calling.
|
196
|
+
self.function_call_params = config.pop("function_call_params", {})
|
197
|
+
|
198
|
+
# If the config specified invoke_params,
|
199
|
+
# Pop the params and use them only for invoking.
|
200
|
+
self.invoke_params = config.pop("invoke_params", {})
|
201
|
+
|
202
|
+
# Import the LangChain class
|
203
|
+
if "langchain_cls" not in config:
|
204
|
+
raise ValueError("Missing langchain_cls in LangChain Model Client config.")
|
205
|
+
module_cls = config.pop("langchain_cls")
|
206
|
+
module_name, cls_name = str(module_cls).rsplit(".", 1)
|
207
|
+
langchain_module = importlib.import_module(module_name)
|
208
|
+
langchain_cls = getattr(langchain_module, cls_name)
|
209
|
+
|
210
|
+
# If the config specified client_params,
|
211
|
+
# Only use the client_params to initialize the LangChain model.
|
212
|
+
# Otherwise, use the config
|
213
|
+
self.client_params = config.get("client_params", config)
|
214
|
+
|
215
|
+
# Initialize the LangChain client
|
216
|
+
self.model = langchain_cls(**self.client_params)
|
217
|
+
|
218
|
+
def create(self, params) -> ModelClient.ModelClientResponseProtocol:
|
219
|
+
"""Creates an LLM completion for a given config.
|
220
|
+
|
221
|
+
Parameters
|
222
|
+
----------
|
223
|
+
params : dict
|
224
|
+
OpenAI API compatible parameters, including all the keys from llm_config.
|
225
|
+
|
226
|
+
Returns
|
227
|
+
-------
|
228
|
+
ModelClientResponseProtocol
|
229
|
+
Response from LLM
|
230
|
+
|
231
|
+
"""
|
232
|
+
streaming = params.get("stream", False)
|
233
|
+
# TODO: num_of_responses
|
234
|
+
num_of_responses = params.get("n", 1)
|
235
|
+
messages = params.pop("messages", [])
|
236
|
+
|
237
|
+
invoke_params = copy.deepcopy(self.invoke_params)
|
238
|
+
|
239
|
+
tools = params.get("tools")
|
240
|
+
if tools:
|
241
|
+
model = self.model.bind_tools(
|
242
|
+
[_convert_to_langchain_tool(tool) for tool in tools]
|
243
|
+
)
|
244
|
+
# invoke_params["tools"] = tools
|
245
|
+
invoke_params.update(self.function_call_params)
|
246
|
+
else:
|
247
|
+
model = self.model
|
248
|
+
|
249
|
+
response = SimpleNamespace()
|
250
|
+
response.choices = []
|
251
|
+
response.model = self.model_name
|
252
|
+
|
253
|
+
if streaming and messages:
|
254
|
+
# Streaming is not implemented yet: a streaming request that has messages is rejected below.
|
255
|
+
raise NotImplementedError()
|
256
|
+
else:
|
257
|
+
# If streaming is not enabled, send a regular chat completion request
|
258
|
+
ai_message = model.invoke(messages, **invoke_params)
|
259
|
+
choice = SimpleNamespace()
|
260
|
+
choice.message = Message.from_message(ai_message)
|
261
|
+
response.choices.append(choice)
|
262
|
+
return response
|
263
|
+
|
264
|
+
def message_retrieval(
|
265
|
+
self, response: ModelClient.ModelClientResponseProtocol
|
266
|
+
) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]:
|
267
|
+
"""
|
268
|
+
Retrieve and return a list of strings or a list of Choice.Message from the response.
|
269
|
+
|
270
|
+
NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
|
271
|
+
since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
|
272
|
+
"""
|
273
|
+
return [choice.message for choice in response.choices]
|
274
|
+
|
275
|
+
def cost(self, response: ModelClient.ModelClientResponseProtocol) -> float:
|
276
|
+
response.cost = 0
|
277
|
+
return 0
|
278
|
+
|
279
|
+
@staticmethod
|
280
|
+
def get_usage(response: ModelClient.ModelClientResponseProtocol) -> Dict:
|
281
|
+
"""Return usage summary of the response using RESPONSE_USAGE_KEYS."""
|
282
|
+
return {}
|