oracle-ads 2.11.19__py3-none-any.whl → 2.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/llm/__init__.py +10 -4
- ads/llm/chat_template.py +31 -0
- ads/llm/guardrails/base.py +3 -2
- ads/llm/guardrails/huggingface.py +1 -1
- ads/llm/langchain/plugins/chat_models/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +924 -0
- ads/llm/langchain/plugins/llms/__init__.py +5 -0
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +939 -0
- ads/llm/requirements.txt +2 -2
- ads/llm/serialize.py +3 -6
- ads/llm/templates/tool_chat_template_hermes.jinja +130 -0
- ads/llm/templates/tool_chat_template_mistral_parallel.jinja +94 -0
- {oracle_ads-2.11.19.dist-info → oracle_ads-2.12.0.dist-info}/METADATA +6 -4
- {oracle_ads-2.11.19.dist-info → oracle_ads-2.12.0.dist-info}/RECORD +17 -15
- ads/llm/langchain/plugins/base.py +0 -118
- ads/llm/langchain/plugins/contant.py +0 -44
- ads/llm/langchain/plugins/embeddings.py +0 -64
- ads/llm/langchain/plugins/llm_gen_ai.py +0 -301
- ads/llm/langchain/plugins/llm_md.py +0 -316
- {oracle_ads-2.11.19.dist-info → oracle_ads-2.12.0.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.11.19.dist-info → oracle_ads-2.12.0.dist-info}/WHEEL +0 -0
- {oracle_ads-2.11.19.dist-info → oracle_ads-2.12.0.dist-info}/entry_points.txt +0 -0
ads/llm/__init__.py
CHANGED
@@ -6,10 +6,16 @@
|
|
6
6
|
|
7
7
|
try:
|
8
8
|
import langchain
|
9
|
-
from ads.llm.langchain.plugins.
|
10
|
-
|
11
|
-
|
12
|
-
|
9
|
+
from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
|
10
|
+
OCIModelDeploymentVLLM,
|
11
|
+
OCIModelDeploymentTGI,
|
12
|
+
)
|
13
|
+
from ads.llm.langchain.plugins.chat_models.oci_data_science import (
|
14
|
+
ChatOCIModelDeployment,
|
15
|
+
ChatOCIModelDeploymentVLLM,
|
16
|
+
ChatOCIModelDeploymentTGI,
|
17
|
+
)
|
18
|
+
from ads.llm.chat_template import ChatTemplates
|
13
19
|
except ImportError as ex:
|
14
20
|
if ex.name == "langchain":
|
15
21
|
raise ImportError(
|
ads/llm/chat_template.py
ADDED
@@ -0,0 +1,31 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*--
|
3
|
+
|
4
|
+
# Copyright (c) 2023 Oracle and/or its affiliates.
|
5
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
|
+
|
7
|
+
|
8
|
+
import os
|
9
|
+
|
10
|
+
|
11
|
+
class ChatTemplates:
|
12
|
+
"""Contains chat templates."""
|
13
|
+
|
14
|
+
@staticmethod
|
15
|
+
def _read_template(filename):
|
16
|
+
with open(
|
17
|
+
os.path.join(os.path.dirname(__file__), "templates", filename),
|
18
|
+
mode="r",
|
19
|
+
encoding="utf-8",
|
20
|
+
) as f:
|
21
|
+
return f.read()
|
22
|
+
|
23
|
+
@staticmethod
|
24
|
+
def mistral():
|
25
|
+
"""Chat template for auto tool calling with Mistral model deploy with vLLM."""
|
26
|
+
return ChatTemplates._read_template("tool_chat_template_mistral_parallel.jinja")
|
27
|
+
|
28
|
+
@staticmethod
|
29
|
+
def hermes():
|
30
|
+
"""Chat template for auto tool calling with Hermes model deploy with vLLM."""
|
31
|
+
return ChatTemplates._read_template("tool_chat_template_hermes.jinja")
|
ads/llm/guardrails/base.py
CHANGED
@@ -14,7 +14,7 @@ import sys
|
|
14
14
|
from typing import Any, List, Dict, Tuple
|
15
15
|
from langchain.schema.prompt import PromptValue
|
16
16
|
from langchain.tools.base import BaseTool, ToolException
|
17
|
-
from
|
17
|
+
from pydantic import BaseModel, model_validator
|
18
18
|
|
19
19
|
|
20
20
|
class RunInfo(BaseModel):
|
@@ -190,7 +190,8 @@ class Guardrail(BaseTool):
|
|
190
190
|
This is used by the ``apply_filter()`` method.
|
191
191
|
"""
|
192
192
|
|
193
|
-
@
|
193
|
+
@model_validator(mode="before")
|
194
|
+
@classmethod
|
194
195
|
def default_name(cls, values):
|
195
196
|
"""Sets the default name of the guardrail."""
|
196
197
|
if not values.get("name"):
|