opengradient 0.3.24__py3-none-any.whl → 0.3.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opengradient/__init__.py +125 -98
- opengradient/account.py +6 -4
- opengradient/cli.py +151 -154
- opengradient/client.py +300 -362
- opengradient/defaults.py +7 -7
- opengradient/exceptions.py +25 -0
- opengradient/llm/__init__.py +7 -10
- opengradient/llm/og_langchain.py +34 -51
- opengradient/llm/og_openai.py +54 -61
- opengradient/mltools/__init__.py +2 -7
- opengradient/mltools/model_tool.py +20 -26
- opengradient/proto/infer_pb2.py +24 -29
- opengradient/proto/infer_pb2_grpc.py +95 -86
- opengradient/types.py +39 -35
- opengradient/utils.py +30 -31
- {opengradient-0.3.24.dist-info → opengradient-0.3.26.dist-info}/METADATA +5 -92
- opengradient-0.3.26.dist-info/RECORD +26 -0
- opengradient-0.3.24.dist-info/RECORD +0 -26
- {opengradient-0.3.24.dist-info → opengradient-0.3.26.dist-info}/LICENSE +0 -0
- {opengradient-0.3.24.dist-info → opengradient-0.3.26.dist-info}/WHEEL +0 -0
- {opengradient-0.3.24.dist-info → opengradient-0.3.26.dist-info}/entry_points.txt +0 -0
- {opengradient-0.3.24.dist-info → opengradient-0.3.26.dist-info}/top_level.txt +0 -0
opengradient/defaults.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
# Default variables
|
|
2
|
-
DEFAULT_RPC_URL="http://18.218.115.248:8545"
|
|
3
|
-
DEFAULT_OG_FAUCET_URL="http://18.218.115.248:8080/?address="
|
|
4
|
-
DEFAULT_HUB_SIGNUP_URL="https://hub.opengradient.ai/signup"
|
|
5
|
-
DEFAULT_INFERENCE_CONTRACT_ADDRESS="0x3fDCb0394CF4919ff4361f4EbA0750cEc2e3bBc7"
|
|
6
|
-
DEFAULT_BLOCKCHAIN_EXPLORER="http://3.145.62.2/tx/"
|
|
7
|
-
DEFAULT_IMAGE_GEN_HOST="18.217.25.69"
|
|
8
|
-
DEFAULT_IMAGE_GEN_PORT=5125
|
|
2
|
+
DEFAULT_RPC_URL = "http://18.218.115.248:8545"
|
|
3
|
+
DEFAULT_OG_FAUCET_URL = "http://18.218.115.248:8080/?address="
|
|
4
|
+
DEFAULT_HUB_SIGNUP_URL = "https://hub.opengradient.ai/signup"
|
|
5
|
+
DEFAULT_INFERENCE_CONTRACT_ADDRESS = "0x3fDCb0394CF4919ff4361f4EbA0750cEc2e3bBc7"
|
|
6
|
+
DEFAULT_BLOCKCHAIN_EXPLORER = "http://3.145.62.2/tx/"
|
|
7
|
+
DEFAULT_IMAGE_GEN_HOST = "18.217.25.69"
|
|
8
|
+
DEFAULT_IMAGE_GEN_PORT = 5125
|
opengradient/exceptions.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
class OpenGradientError(Exception):
|
|
2
2
|
"""Base exception for OpenGradient SDK"""
|
|
3
|
+
|
|
3
4
|
def __init__(self, message, status_code=None, response=None):
|
|
4
5
|
self.message = message
|
|
5
6
|
self.status_code = status_code
|
|
@@ -9,69 +10,93 @@ class OpenGradientError(Exception):
|
|
|
9
10
|
def __str__(self):
|
|
10
11
|
return f"{self.message} (Status code: {self.status_code})"
|
|
11
12
|
|
|
13
|
+
|
|
12
14
|
class FileNotFoundError(OpenGradientError):
|
|
13
15
|
"""Raised when a file is not found"""
|
|
16
|
+
|
|
14
17
|
def __init__(self, file_path):
|
|
15
18
|
super().__init__(f"File not found: {file_path}")
|
|
16
19
|
self.file_path = file_path
|
|
17
20
|
|
|
21
|
+
|
|
18
22
|
class UploadError(OpenGradientError):
|
|
19
23
|
"""Raised when there's an error during file upload"""
|
|
24
|
+
|
|
20
25
|
def __init__(self, message, file_path=None, **kwargs):
|
|
21
26
|
super().__init__(message, **kwargs)
|
|
22
27
|
self.file_path = file_path
|
|
23
28
|
|
|
29
|
+
|
|
24
30
|
class InferenceError(OpenGradientError):
|
|
25
31
|
"""Raised when there's an error during inference"""
|
|
32
|
+
|
|
26
33
|
def __init__(self, message, model_cid=None, **kwargs):
|
|
27
34
|
super().__init__(message, **kwargs)
|
|
28
35
|
self.model_cid = model_cid
|
|
29
36
|
|
|
37
|
+
|
|
30
38
|
class ResultRetrievalError(OpenGradientError):
|
|
31
39
|
"""Raised when there's an error retrieving results"""
|
|
40
|
+
|
|
32
41
|
def __init__(self, message, inference_cid=None, **kwargs):
|
|
33
42
|
super().__init__(message, **kwargs)
|
|
34
43
|
self.inference_cid = inference_cid
|
|
35
44
|
|
|
45
|
+
|
|
36
46
|
class AuthenticationError(OpenGradientError):
|
|
37
47
|
"""Raised when there's an authentication error"""
|
|
48
|
+
|
|
38
49
|
def __init__(self, message="Authentication failed", **kwargs):
|
|
39
50
|
super().__init__(message, **kwargs)
|
|
40
51
|
|
|
52
|
+
|
|
41
53
|
class RateLimitError(OpenGradientError):
|
|
42
54
|
"""Raised when API rate limit is exceeded"""
|
|
55
|
+
|
|
43
56
|
def __init__(self, message="Rate limit exceeded", retry_after=None, **kwargs):
|
|
44
57
|
super().__init__(message, **kwargs)
|
|
45
58
|
self.retry_after = retry_after
|
|
46
59
|
|
|
60
|
+
|
|
47
61
|
class InvalidInputError(OpenGradientError):
|
|
48
62
|
"""Raised when invalid input is provided"""
|
|
63
|
+
|
|
49
64
|
def __init__(self, message, invalid_fields=None, **kwargs):
|
|
50
65
|
super().__init__(message, **kwargs)
|
|
51
66
|
self.invalid_fields = invalid_fields or []
|
|
52
67
|
|
|
68
|
+
|
|
53
69
|
class ServerError(OpenGradientError):
|
|
54
70
|
"""Raised when a server error occurs"""
|
|
71
|
+
|
|
55
72
|
pass
|
|
56
73
|
|
|
74
|
+
|
|
57
75
|
class TimeoutError(OpenGradientError):
|
|
58
76
|
"""Raised when a request times out"""
|
|
77
|
+
|
|
59
78
|
def __init__(self, message="Request timed out", timeout=None, **kwargs):
|
|
60
79
|
super().__init__(message, **kwargs)
|
|
61
80
|
self.timeout = timeout
|
|
62
81
|
|
|
82
|
+
|
|
63
83
|
class NetworkError(OpenGradientError):
|
|
64
84
|
"""Raised when a network error occurs"""
|
|
85
|
+
|
|
65
86
|
pass
|
|
66
87
|
|
|
88
|
+
|
|
67
89
|
class UnsupportedModelError(OpenGradientError):
|
|
68
90
|
"""Raised when an unsupported model type is used"""
|
|
91
|
+
|
|
69
92
|
def __init__(self, model_type):
|
|
70
93
|
super().__init__(f"Unsupported model type: {model_type}")
|
|
71
94
|
self.model_type = model_type
|
|
72
95
|
|
|
96
|
+
|
|
73
97
|
class InsufficientCreditsError(OpenGradientError):
|
|
74
98
|
"""Raised when the user has insufficient credits for the operation"""
|
|
99
|
+
|
|
75
100
|
def __init__(self, message="Insufficient credits", required_credits=None, available_credits=None, **kwargs):
|
|
76
101
|
super().__init__(message, **kwargs)
|
|
77
102
|
self.required_credits = required_credits
|
opengradient/llm/__init__.py
CHANGED
|
@@ -9,15 +9,14 @@ into existing applications and agent frameworks.
|
|
|
9
9
|
from .og_langchain import *
|
|
10
10
|
from .og_openai import *
|
|
11
11
|
|
|
12
|
+
|
|
12
13
|
def langchain_adapter(private_key: str, model_cid: str, max_tokens: int = 300) -> OpenGradientChatModel:
|
|
13
14
|
"""
|
|
14
15
|
Returns an OpenGradient LLM that implements LangChain's LLM interface
|
|
15
16
|
and can be plugged into LangChain agents.
|
|
16
17
|
"""
|
|
17
|
-
return OpenGradientChatModel(
|
|
18
|
-
|
|
19
|
-
model_cid=model_cid,
|
|
20
|
-
max_tokens=max_tokens)
|
|
18
|
+
return OpenGradientChatModel(private_key=private_key, model_cid=model_cid, max_tokens=max_tokens)
|
|
19
|
+
|
|
21
20
|
|
|
22
21
|
def openai_adapter(private_key: str) -> OpenGradientOpenAIClient:
|
|
23
22
|
"""
|
|
@@ -27,12 +26,10 @@ def openai_adapter(private_key: str) -> OpenGradientOpenAIClient:
|
|
|
27
26
|
"""
|
|
28
27
|
return OpenGradientOpenAIClient(private_key=private_key)
|
|
29
28
|
|
|
29
|
+
|
|
30
30
|
__all__ = [
|
|
31
|
-
|
|
32
|
-
|
|
31
|
+
"langchain_adapter",
|
|
32
|
+
"openai_adapter",
|
|
33
33
|
]
|
|
34
34
|
|
|
35
|
-
__pdoc__ = {
|
|
36
|
-
'og_langchain': False,
|
|
37
|
-
'og_openai': False
|
|
38
|
-
}
|
|
35
|
+
__pdoc__ = {"og_langchain": False, "og_openai": False}
|
opengradient/llm/og_langchain.py
CHANGED
|
@@ -1,24 +1,22 @@
|
|
|
1
|
-
from typing import List, Dict, Optional, Any, Sequence, Union
|
|
2
1
|
import json
|
|
2
|
+
from typing import Any, Dict, List, Optional, Sequence, Union
|
|
3
3
|
|
|
4
4
|
from langchain.chat_models.base import BaseChatModel
|
|
5
5
|
from langchain.schema import (
|
|
6
6
|
AIMessage,
|
|
7
|
-
HumanMessage,
|
|
8
|
-
SystemMessage,
|
|
9
7
|
BaseMessage,
|
|
10
|
-
ChatResult,
|
|
11
8
|
ChatGeneration,
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
9
|
+
ChatResult,
|
|
10
|
+
HumanMessage,
|
|
11
|
+
SystemMessage,
|
|
15
12
|
)
|
|
16
13
|
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
|
17
|
-
from langchain_core.tools import BaseTool
|
|
18
14
|
from langchain_core.messages import ToolCall
|
|
15
|
+
from langchain_core.messages.tool import ToolMessage
|
|
16
|
+
from langchain_core.tools import BaseTool
|
|
19
17
|
|
|
20
18
|
from opengradient import Client, LlmInferenceMode
|
|
21
|
-
from opengradient.defaults import
|
|
19
|
+
from opengradient.defaults import DEFAULT_INFERENCE_CONTRACT_ADDRESS, DEFAULT_RPC_URL
|
|
22
20
|
|
|
23
21
|
|
|
24
22
|
class OpenGradientChatModel(BaseChatModel):
|
|
@@ -32,18 +30,15 @@ class OpenGradientChatModel(BaseChatModel):
|
|
|
32
30
|
def __init__(self, private_key: str, model_cid: str, max_tokens: int = 300):
|
|
33
31
|
super().__init__()
|
|
34
32
|
self.client = Client(
|
|
35
|
-
private_key=private_key,
|
|
36
|
-
|
|
37
|
-
contract_address=DEFAULT_INFERENCE_CONTRACT_ADDRESS,
|
|
38
|
-
email=None,
|
|
39
|
-
password=None)
|
|
33
|
+
private_key=private_key, rpc_url=DEFAULT_RPC_URL, contract_address=DEFAULT_INFERENCE_CONTRACT_ADDRESS, email=None, password=None
|
|
34
|
+
)
|
|
40
35
|
self.model_cid = model_cid
|
|
41
36
|
self.max_tokens = max_tokens
|
|
42
37
|
|
|
43
38
|
@property
|
|
44
39
|
def _llm_type(self) -> str:
|
|
45
40
|
return "opengradient"
|
|
46
|
-
|
|
41
|
+
|
|
47
42
|
def bind_tools(
|
|
48
43
|
self,
|
|
49
44
|
tools: Sequence[Union[BaseTool, Dict]],
|
|
@@ -52,17 +47,19 @@ class OpenGradientChatModel(BaseChatModel):
|
|
|
52
47
|
tool_dicts = []
|
|
53
48
|
for tool in tools:
|
|
54
49
|
if isinstance(tool, BaseTool):
|
|
55
|
-
tool_dicts.append(
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
"
|
|
59
|
-
|
|
60
|
-
|
|
50
|
+
tool_dicts.append(
|
|
51
|
+
{
|
|
52
|
+
"type": "function",
|
|
53
|
+
"function": {
|
|
54
|
+
"name": tool.name,
|
|
55
|
+
"description": tool.description,
|
|
56
|
+
"parameters": tool.args_schema.schema() if hasattr(tool, "args_schema") else {},
|
|
57
|
+
},
|
|
61
58
|
}
|
|
62
|
-
|
|
59
|
+
)
|
|
63
60
|
else:
|
|
64
61
|
tool_dicts.append(tool)
|
|
65
|
-
|
|
62
|
+
|
|
66
63
|
self.tools = tool_dicts
|
|
67
64
|
return self
|
|
68
65
|
|
|
@@ -73,7 +70,6 @@ class OpenGradientChatModel(BaseChatModel):
|
|
|
73
70
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
74
71
|
**kwargs: Any,
|
|
75
72
|
) -> ChatResult:
|
|
76
|
-
|
|
77
73
|
sdk_messages = []
|
|
78
74
|
for message in messages:
|
|
79
75
|
if isinstance(message, SystemMessage):
|
|
@@ -81,14 +77,15 @@ class OpenGradientChatModel(BaseChatModel):
|
|
|
81
77
|
elif isinstance(message, HumanMessage):
|
|
82
78
|
sdk_messages.append({"role": "user", "content": message.content})
|
|
83
79
|
elif isinstance(message, AIMessage):
|
|
84
|
-
sdk_messages.append(
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
"
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
}
|
|
80
|
+
sdk_messages.append(
|
|
81
|
+
{
|
|
82
|
+
"role": "assistant",
|
|
83
|
+
"content": message.content,
|
|
84
|
+
"tool_calls": [
|
|
85
|
+
{"id": call["id"], "name": call["name"], "arguments": json.dumps(call["args"])} for call in message.tool_calls
|
|
86
|
+
],
|
|
87
|
+
}
|
|
88
|
+
)
|
|
92
89
|
elif isinstance(message, ToolMessage):
|
|
93
90
|
sdk_messages.append({"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id})
|
|
94
91
|
else:
|
|
@@ -100,33 +97,19 @@ class OpenGradientChatModel(BaseChatModel):
|
|
|
100
97
|
stop_sequence=stop,
|
|
101
98
|
max_tokens=self.max_tokens,
|
|
102
99
|
tools=self.tools,
|
|
103
|
-
inference_mode=LlmInferenceMode.VANILLA
|
|
100
|
+
inference_mode=LlmInferenceMode.VANILLA,
|
|
104
101
|
)
|
|
105
102
|
|
|
106
103
|
if "tool_calls" in chat_response and chat_response["tool_calls"]:
|
|
107
104
|
tool_calls = []
|
|
108
105
|
for tool_call in chat_response["tool_calls"]:
|
|
109
|
-
tool_calls.append(
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
name=tool_call["name"],
|
|
113
|
-
args=json.loads(tool_call["arguments"])
|
|
114
|
-
)
|
|
115
|
-
)
|
|
116
|
-
|
|
117
|
-
message = AIMessage(
|
|
118
|
-
content='',
|
|
119
|
-
tool_calls=tool_calls
|
|
120
|
-
)
|
|
106
|
+
tool_calls.append(ToolCall(id=tool_call.get("id", ""), name=tool_call["name"], args=json.loads(tool_call["arguments"])))
|
|
107
|
+
|
|
108
|
+
message = AIMessage(content="", tool_calls=tool_calls)
|
|
121
109
|
else:
|
|
122
110
|
message = AIMessage(content=chat_response["content"])
|
|
123
111
|
|
|
124
|
-
return ChatResult(
|
|
125
|
-
generations=[ChatGeneration(
|
|
126
|
-
message=message,
|
|
127
|
-
generation_info={"finish_reason": finish_reason}
|
|
128
|
-
)]
|
|
129
|
-
)
|
|
112
|
+
return ChatResult(generations=[ChatGeneration(message=message, generation_info={"finish_reason": finish_reason})])
|
|
130
113
|
|
|
131
114
|
@property
|
|
132
115
|
def _identifying_params(self) -> Dict[str, Any]:
|
opengradient/llm/og_openai.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
|
+
import time
|
|
2
|
+
import uuid
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
1
5
|
from openai.types.chat import ChatCompletion
|
|
6
|
+
|
|
2
7
|
import opengradient as og
|
|
3
|
-
from opengradient.defaults import
|
|
8
|
+
from opengradient.defaults import DEFAULT_INFERENCE_CONTRACT_ADDRESS, DEFAULT_RPC_URL
|
|
4
9
|
|
|
5
|
-
from typing import List
|
|
6
|
-
import time
|
|
7
|
-
import uuid
|
|
8
10
|
|
|
9
11
|
class OGCompletions(object):
|
|
10
12
|
client: og.Client
|
|
@@ -13,14 +15,14 @@ class OGCompletions(object):
|
|
|
13
15
|
self.client = client
|
|
14
16
|
|
|
15
17
|
def create(
|
|
16
|
-
self,
|
|
17
|
-
model: str,
|
|
18
|
-
messages: List[object],
|
|
19
|
-
tools: List[object],
|
|
20
|
-
tool_choice: str,
|
|
21
|
-
stream: bool = False,
|
|
22
|
-
parallel_tool_calls: bool = False
|
|
23
|
-
|
|
18
|
+
self,
|
|
19
|
+
model: str,
|
|
20
|
+
messages: List[object],
|
|
21
|
+
tools: List[object],
|
|
22
|
+
tool_choice: str,
|
|
23
|
+
stream: bool = False,
|
|
24
|
+
parallel_tool_calls: bool = False,
|
|
25
|
+
) -> ChatCompletion:
|
|
24
26
|
# convert OpenAI message format so it's compatible with the SDK
|
|
25
27
|
sdk_messages = OGCompletions.convert_to_abi_compatible(messages)
|
|
26
28
|
|
|
@@ -29,38 +31,32 @@ class OGCompletions(object):
|
|
|
29
31
|
messages=sdk_messages,
|
|
30
32
|
max_tokens=200,
|
|
31
33
|
tools=tools,
|
|
32
|
-
tool_choice=tool_choice,
|
|
34
|
+
tool_choice=tool_choice,
|
|
33
35
|
temperature=0.25,
|
|
34
|
-
inference_mode=og.LlmInferenceMode.VANILLA
|
|
36
|
+
inference_mode=og.LlmInferenceMode.VANILLA,
|
|
35
37
|
)
|
|
36
38
|
|
|
37
39
|
choice = {
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
40
|
+
"index": 0, # Add missing index field
|
|
41
|
+
"finish_reason": finish_reason,
|
|
42
|
+
"message": {
|
|
43
|
+
"role": chat_completion["role"],
|
|
44
|
+
"content": chat_completion["content"],
|
|
45
|
+
"tool_calls": [
|
|
44
46
|
{
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
}
|
|
47
|
+
"id": tool_call["id"],
|
|
48
|
+
"type": "function", # Add missing type field
|
|
49
|
+
"function": { # Add missing function field
|
|
50
|
+
"name": tool_call["name"],
|
|
51
|
+
"arguments": tool_call["arguments"],
|
|
52
|
+
},
|
|
51
53
|
}
|
|
52
|
-
for tool_call in chat_completion.get(
|
|
53
|
-
]
|
|
54
|
-
}
|
|
54
|
+
for tool_call in chat_completion.get("tool_calls", [])
|
|
55
|
+
],
|
|
56
|
+
},
|
|
55
57
|
}
|
|
56
58
|
|
|
57
|
-
return ChatCompletion(
|
|
58
|
-
id=str(uuid.uuid4()),
|
|
59
|
-
created=int(time.time()),
|
|
60
|
-
model=model,
|
|
61
|
-
object='chat.completion',
|
|
62
|
-
choices=[choice]
|
|
63
|
-
)
|
|
59
|
+
return ChatCompletion(id=str(uuid.uuid4()), created=int(time.time()), model=model, object="chat.completion", choices=[choice])
|
|
64
60
|
|
|
65
61
|
@staticmethod
|
|
66
62
|
@staticmethod
|
|
@@ -68,53 +64,50 @@ class OGCompletions(object):
|
|
|
68
64
|
sdk_messages = []
|
|
69
65
|
|
|
70
66
|
for message in messages:
|
|
71
|
-
role = message[
|
|
72
|
-
sdk_message = {
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
sdk_message[
|
|
78
|
-
elif role ==
|
|
79
|
-
sdk_message[
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
sdk_message['tool_call_id'] = message['tool_call_id']
|
|
83
|
-
elif role == 'assistant':
|
|
67
|
+
role = message["role"]
|
|
68
|
+
sdk_message = {"role": role}
|
|
69
|
+
|
|
70
|
+
if role == "system":
|
|
71
|
+
sdk_message["content"] = message["content"]
|
|
72
|
+
elif role == "user":
|
|
73
|
+
sdk_message["content"] = message["content"]
|
|
74
|
+
elif role == "tool":
|
|
75
|
+
sdk_message["content"] = message["content"]
|
|
76
|
+
sdk_message["tool_call_id"] = message["tool_call_id"]
|
|
77
|
+
elif role == "assistant":
|
|
84
78
|
flattened_calls = []
|
|
85
|
-
for tool_call in message[
|
|
79
|
+
for tool_call in message["tool_calls"]:
|
|
86
80
|
# OpenAI format
|
|
87
81
|
flattened_call = {
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
82
|
+
"id": tool_call["id"],
|
|
83
|
+
"name": tool_call["function"]["name"],
|
|
84
|
+
"arguments": tool_call["function"]["arguments"],
|
|
91
85
|
}
|
|
92
86
|
flattened_calls.append(flattened_call)
|
|
93
87
|
|
|
94
|
-
sdk_message[
|
|
95
|
-
sdk_message[
|
|
88
|
+
sdk_message["tool_calls"] = flattened_calls
|
|
89
|
+
sdk_message["content"] = message["content"]
|
|
96
90
|
|
|
97
91
|
sdk_messages.append(sdk_message)
|
|
98
92
|
|
|
99
93
|
return sdk_messages
|
|
100
94
|
|
|
95
|
+
|
|
101
96
|
class OGChat(object):
|
|
102
97
|
completions: OGCompletions
|
|
103
98
|
|
|
104
99
|
def __init__(self, client: og.Client):
|
|
105
100
|
self.completions = OGCompletions(client)
|
|
106
101
|
|
|
102
|
+
|
|
107
103
|
class OpenGradientOpenAIClient(object):
|
|
108
104
|
"""OpenAI client implementation"""
|
|
105
|
+
|
|
109
106
|
client: og.Client
|
|
110
107
|
chat: OGChat
|
|
111
108
|
|
|
112
109
|
def __init__(self, private_key: str):
|
|
113
110
|
self.client = og.Client(
|
|
114
|
-
private_key=private_key,
|
|
115
|
-
rpc_url=DEFAULT_RPC_URL,
|
|
116
|
-
contract_address=DEFAULT_INFERENCE_CONTRACT_ADDRESS,
|
|
117
|
-
email=None,
|
|
118
|
-
password=None
|
|
111
|
+
private_key=private_key, rpc_url=DEFAULT_RPC_URL, contract_address=DEFAULT_INFERENCE_CONTRACT_ADDRESS, email=None, password=None
|
|
119
112
|
)
|
|
120
|
-
self.chat = OGChat(self.client)
|
|
113
|
+
self.chat = OGChat(self.client)
|
opengradient/mltools/__init__.py
CHANGED
|
@@ -1,28 +1,32 @@
|
|
|
1
1
|
from enum import Enum
|
|
2
|
-
from typing import Callable, Dict,
|
|
2
|
+
from typing import Any, Callable, Dict, Type
|
|
3
3
|
|
|
4
|
-
from pydantic import BaseModel
|
|
5
4
|
from langchain_core.tools import BaseTool, StructuredTool
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
6
7
|
import opengradient as og
|
|
7
8
|
|
|
9
|
+
|
|
8
10
|
class ToolType(str, Enum):
|
|
9
11
|
"""Indicates the framework the tool is compatible with."""
|
|
10
12
|
|
|
11
13
|
LANGCHAIN = "langchain"
|
|
12
14
|
SWARM = "swarm"
|
|
13
|
-
|
|
15
|
+
|
|
14
16
|
def __str__(self) -> str:
|
|
15
17
|
return self.value
|
|
16
18
|
|
|
19
|
+
|
|
17
20
|
def create_og_model_tool(
|
|
18
21
|
tool_type: ToolType,
|
|
19
|
-
model_cid: str,
|
|
22
|
+
model_cid: str,
|
|
20
23
|
tool_name: str,
|
|
21
24
|
input_getter: Callable,
|
|
22
25
|
output_formatter: Callable[..., str],
|
|
23
26
|
input_schema: Type[BaseModel] = None,
|
|
24
|
-
tool_description: str = "Executes the given ML model",
|
|
25
|
-
inference_mode: og.InferenceMode= og.InferenceMode.VANILLA
|
|
27
|
+
tool_description: str = "Executes the given ML model",
|
|
28
|
+
inference_mode: og.InferenceMode = og.InferenceMode.VANILLA,
|
|
29
|
+
) -> BaseTool:
|
|
26
30
|
"""
|
|
27
31
|
Creates a tool that wraps an OpenGradient model for inference.
|
|
28
32
|
|
|
@@ -32,15 +36,15 @@ def create_og_model_tool(
|
|
|
32
36
|
runs inference using the specified OpenGradient model.
|
|
33
37
|
|
|
34
38
|
Args:
|
|
35
|
-
tool_type (ToolType): Specifies the framework to create the tool for. Use
|
|
39
|
+
tool_type (ToolType): Specifies the framework to create the tool for. Use
|
|
36
40
|
ToolType.LANGCHAIN for LangChain integration or ToolType.SWARM for Swarm
|
|
37
|
-
integration.
|
|
41
|
+
integration.
|
|
38
42
|
model_cid (str): The CID of the OpenGradient model to be executed.
|
|
39
43
|
tool_name (str): The name to assign to the created tool. This will be used to identify
|
|
40
44
|
and invoke the tool within the agent.
|
|
41
45
|
input_getter (Callable): A function that returns the input data required by the model.
|
|
42
46
|
The function should return data in a format compatible with the model's expectations.
|
|
43
|
-
output_formatter (Callable[..., str]): A function that takes the model output and
|
|
47
|
+
output_formatter (Callable[..., str]): A function that takes the model output and
|
|
44
48
|
formats it into a string. This is required to ensure the output is compatible
|
|
45
49
|
with the tool framework.
|
|
46
50
|
input_schema (Type[BaseModel], optional): A Pydantic BaseModel class defining the
|
|
@@ -78,29 +82,18 @@ def create_og_model_tool(
|
|
|
78
82
|
... tool_description="Classifies text into categories"
|
|
79
83
|
... )
|
|
80
84
|
"""
|
|
85
|
+
|
|
81
86
|
# define runnable
|
|
82
87
|
def model_executor(**llm_input):
|
|
83
88
|
# Combine LLM input with input provided by code
|
|
84
|
-
combined_input = {
|
|
85
|
-
**llm_input,
|
|
86
|
-
**input_getter()
|
|
87
|
-
}
|
|
89
|
+
combined_input = {**llm_input, **input_getter()}
|
|
88
90
|
|
|
89
|
-
_, output = og.infer(
|
|
90
|
-
model_cid=model_cid,
|
|
91
|
-
inference_mode=inference_mode,
|
|
92
|
-
model_input=combined_input
|
|
93
|
-
)
|
|
91
|
+
_, output = og.infer(model_cid=model_cid, inference_mode=inference_mode, model_input=combined_input)
|
|
94
92
|
|
|
95
93
|
return output_formatter(output)
|
|
96
94
|
|
|
97
95
|
if tool_type == ToolType.LANGCHAIN:
|
|
98
|
-
return StructuredTool.from_function(
|
|
99
|
-
func=model_executor,
|
|
100
|
-
name=tool_name,
|
|
101
|
-
description=tool_description,
|
|
102
|
-
args_schema=input_schema
|
|
103
|
-
)
|
|
96
|
+
return StructuredTool.from_function(func=model_executor, name=tool_name, description=tool_description, args_schema=input_schema)
|
|
104
97
|
elif tool_type == ToolType.SWARM:
|
|
105
98
|
model_executor.__name__ = tool_name
|
|
106
99
|
model_executor.__doc__ = tool_description
|
|
@@ -111,13 +104,14 @@ def create_og_model_tool(
|
|
|
111
104
|
else:
|
|
112
105
|
raise ValueError(f"Invalid tooltype: {tool_type}")
|
|
113
106
|
|
|
107
|
+
|
|
114
108
|
def _convert_pydantic_to_annotations(model: Type[BaseModel]) -> Dict[str, Any]:
|
|
115
109
|
"""
|
|
116
110
|
Convert a Pydantic model to function annotations format used by Swarm.
|
|
117
|
-
|
|
111
|
+
|
|
118
112
|
Args:
|
|
119
113
|
model: A Pydantic BaseModel class
|
|
120
|
-
|
|
114
|
+
|
|
121
115
|
Returns:
|
|
122
116
|
Dict mapping field names to (type, description) tuples
|
|
123
117
|
"""
|