opengradient 0.4.7__tar.gz → 0.4.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {opengradient-0.4.7/src/opengradient.egg-info → opengradient-0.4.9}/PKG-INFO +1 -1
- {opengradient-0.4.7 → opengradient-0.4.9}/pyproject.toml +1 -1
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/__init__.py +12 -7
- opengradient-0.4.9/src/opengradient/alphasense/run_model_tool.py +152 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/client.py +39 -59
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/types.py +50 -4
- opengradient-0.4.9/src/opengradient/workflow_models/__init__.py +28 -0
- opengradient-0.4.9/src/opengradient/workflow_models/constants.py +13 -0
- opengradient-0.4.9/src/opengradient/workflow_models/types.py +16 -0
- opengradient-0.4.9/src/opengradient/workflow_models/utils.py +39 -0
- opengradient-0.4.9/src/opengradient/workflow_models/workflow_models.py +97 -0
- {opengradient-0.4.7 → opengradient-0.4.9/src/opengradient.egg-info}/PKG-INFO +1 -1
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient.egg-info/SOURCES.txt +6 -1
- opengradient-0.4.7/src/opengradient/alphasense/run_model_tool.py +0 -114
- {opengradient-0.4.7 → opengradient-0.4.9}/LICENSE +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/README.md +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/setup.cfg +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/abi/PriceHistoryInference.abi +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/abi/WorkflowScheduler.abi +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/abi/inference.abi +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/account.py +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/alphasense/__init__.py +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/alphasense/read_workflow_tool.py +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/alphasense/types.py +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/bin/PriceHistoryInference.bin +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/cli.py +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/defaults.py +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/exceptions.py +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/llm/__init__.py +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/llm/og_langchain.py +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/llm/og_openai.py +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/proto/__init__.py +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/proto/infer.proto +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/proto/infer_pb2.py +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/proto/infer_pb2_grpc.py +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/utils.py +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient.egg-info/dependency_links.txt +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient.egg-info/entry_points.txt +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient.egg-info/requires.txt +0 -0
- {opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient.egg-info/top_level.txt +0 -0
{opengradient-0.4.7 → opengradient-0.4.9}/pyproject.toml (+1 -1)

```diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "opengradient"
-version = "0.4.7"
+version = "0.4.9"
 description = "Python SDK for OpenGradient decentralized model management & inference services"
 authors = [{name = "OpenGradient", email = "oliver@opengradient.ai"}]
 license = {file = "LICENSE"}
```
{opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/__init__.py (+12 -7)

```diff
@@ -14,9 +14,12 @@ from .types import (
     CandleType,
     CandleOrder,
     InferenceMode,
+    InferenceResult,
     LlmInferenceMode,
     TextGenerationOutput,
     ModelOutput,
+    ModelRepository,
+    FileUploadResult,
 )
 
 from . import llm, alphasense
@@ -61,7 +64,7 @@ def init(email: str, password: str, private_key: str, rpc_url=DEFAULT_RPC_URL, c
     return _client
 
 
-def upload(model_path, model_name, version):
+def upload(model_path, model_name, version) -> FileUploadResult:
     """Upload a model file to OpenGradient.
 
     Args:
@@ -70,7 +73,7 @@ def upload(model_path, model_name, version):
         version: Version string for this model upload
 
     Returns:
-
+        FileUploadResult: Upload response containing file metadata
 
     Raises:
         RuntimeError: If SDK is not initialized
@@ -80,7 +83,7 @@ def upload(model_path, model_name, version):
     return _client.upload(model_path, model_name, version)
 
 
-def create_model(model_name: str, model_desc: str, model_path: Optional[str] = None):
+def create_model(model_name: str, model_desc: str, model_path: Optional[str] = None) -> ModelRepository:
     """Create a new model repository.
 
     Args:
@@ -89,7 +92,7 @@ def create_model(model_name: str, model_desc: str, model_path: Optional[str] = N
         model_path: Optional path to model file to upload immediately
 
     Returns:
-
+        ModelRepository: Creation response with model metadata and optional upload results
 
     Raises:
         RuntimeError: If SDK is not initialized
@@ -126,7 +129,7 @@ def create_version(model_name, notes=None, is_major=False):
     return _client.create_version(model_name, notes, is_major)
 
 
-def infer(model_cid, inference_mode, model_input, max_retries: Optional[int] = None):
+def infer(model_cid, inference_mode, model_input, max_retries: Optional[int] = None) -> InferenceResult:
    """Run inference on a model.
 
    Args:
@@ -136,7 +139,9 @@ def infer(model_cid, inference_mode, model_input, max_retries: Optional[int] = N
        max_retries: Maximum number of retries for failed transactions
 
    Returns:
-        InferenceResult:
+        InferenceResult (InferenceResult): A dataclass object containing the transaction hash and model output.
+            * transaction_hash (str): Blockchain hash for the transaction
+            * model_output (Dict[str, np.ndarray]): Output of the ONNX model
 
    Raises:
        RuntimeError: If SDK is not initialized
@@ -319,7 +324,7 @@ def run_workflow(contract_address: str) -> ModelOutput:
     return _client.run_workflow(contract_address)
 
 
-def read_workflow_history(contract_address: str, num_results: int) -> List[
+def read_workflow_history(contract_address: str, num_results: int) -> List[ModelOutput]:
     """
     Gets historical inference results from a workflow contract.
```
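Taken together, the `__init__.py` changes give the module-level helpers concrete return types (`FileUploadResult`, `ModelRepository`, `InferenceResult`). A minimal sketch of how a caller can lean on the new annotations; the credentials, file name, and input values below are placeholders, not values from this diff:

```python
import opengradient as og
from opengradient import FileUploadResult, InferenceResult

# Placeholder credentials -- substitute real account details.
og.init(email="user@example.com", password="...", private_key="0x...")

# upload() now returns a FileUploadResult dataclass instead of an untyped dict.
upload_result: FileUploadResult = og.upload("model.onnx", "my-model", "0.01")
print(upload_result.modelCid, upload_result.size)

# infer() now returns an InferenceResult dataclass.
inference: InferenceResult = og.infer(
    model_cid=upload_result.modelCid,
    inference_mode=og.InferenceMode.VANILLA,
    model_input={"price_series": [1.0, 2.0, 3.0]},  # made-up input tensor
)
print(inference.transaction_hash)  # blockchain transaction hash
print(inference.model_output)      # Dict[str, np.ndarray] of ONNX outputs
```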
opengradient-0.4.9/src/opengradient/alphasense/run_model_tool.py (new file, +152 lines)

```python
from enum import Enum
from typing import Any, Callable, List, Dict, Type, Optional, Union

from langchain_core.tools import BaseTool, StructuredTool
from pydantic import BaseModel

import opengradient as og
from .types import ToolType
from opengradient import InferenceResult
import numpy as np


def create_run_model_tool(
    tool_type: ToolType,
    model_cid: str,
    tool_name: str,
    model_input_provider: Callable[..., Dict[str, Union[str, int, float, List, np.ndarray]]],
    model_output_formatter: Callable[[InferenceResult], str],
    tool_input_schema: Optional[Type[BaseModel]] = None,
    tool_description: str = "Executes the given ML model",
    inference_mode: og.InferenceMode = og.InferenceMode.VANILLA,
) -> BaseTool | Callable:
    """
    Creates a tool that wraps an OpenGradient model for inference.

    This function generates a tool that can be integrated into either a LangChain pipeline
    or a Swarm system, allowing the model to be executed as part of a chain of operations.
    The tool uses the provided model_input_provider function to obtain the necessary input
    data and runs inference using the specified OpenGradient model.

    Args:
        tool_type (ToolType): Specifies the framework to create the tool for. Use
            ToolType.LANGCHAIN for LangChain integration or ToolType.SWARM for Swarm
            integration.
        model_cid (str): The CID of the OpenGradient model to be executed.
        tool_name (str): The name to assign to the created tool. This will be used to identify
            and invoke the tool within the agent.
        model_input_provider (Callable): A function that takes in the tool_input_schema with arguments
            filled by the agent and returns input data required by the model.

            The function should return data in a format compatible with the model's expectations.
        model_output_formatter (Callable[..., str]): A function that takes the output of
            the OpenGradient infer method (with type InferenceResult) and formats it into a string.

            This is required to ensure the output is compatible with the tool framework.

            Default returns the InferenceResult object.

            InferenceResult has attributes:
            * transaction_hash (str): Blockchain hash for the transaction
            * model_output (Dict[str, np.ndarray]): Output of the ONNX model
        tool_input_schema (Type[BaseModel], optional): A Pydantic BaseModel class defining the
            input schema.

            For LangChain tools the schema will be used directly. The defined schema will be used as
            input keyword arguments for the `model_input_provider` function. If no arguments are required
            for the `model_input_provider` function then this schema can be unspecified.

            For Swarm tools the schema will be converted to appropriate annotations.

            Default is None -- an empty schema will be provided for LangChain.
        tool_description (str, optional): A description of what the tool does. Defaults to
            "Executes the given ML model".
        inference_mode (og.InferenceMode, optional): The inference mode to use when running
            the model. Defaults to VANILLA.

    Returns:
        BaseTool: For ToolType.LANGCHAIN, returns a LangChain StructuredTool.
        Callable: For ToolType.SWARM, returns a decorated function with appropriate metadata.

    Raises:
        ValueError: If an invalid tool_type is provided.

    Examples:
        >>> from pydantic import BaseModel, Field
        >>> from enum import Enum
        >>> from opengradient.alphasense import create_run_model_tool
        >>> class Token(str, Enum):
        ...     ETH = "ethereum"
        ...     BTC = "bitcoin"
        ...
        >>> class InputSchema(BaseModel):
        ...     token: Token = Field(default=Token.ETH, description="Token name specified by user.")
        ...
        >>> eth_model_input = {"price_series": [2010.1, 2012.3, 2020.1, 2019.2]}  # Example data
        >>> btc_model_input = {"price_series": [100001.1, 100013.2, 100149.2, 99998.1]}  # Example data
        >>> def model_input_provider(**llm_input):
        ...     token = llm_input.get("token")
        ...     if token == Token.BTC:
        ...         return btc_model_input
        ...     elif token == Token.ETH:
        ...         return eth_model_input
        ...     else:
        ...         raise ValueError("Unexpected token found")
        ...
        >>> def output_formatter(inference_result):
        ...     return format(float(inference_result.model_output["std"].item()), ".3%")
        ...
        >>> run_model_tool = create_run_model_tool(
        ...     tool_type=ToolType.LANGCHAIN,
        ...     model_cid="QmZdSfHWGJyzBiB2K98egzu3MypPcv4R1ASypUxwZ1MFUG",
        ...     tool_name="Return_volatility_tool",
        ...     model_input_provider=model_input_provider,
        ...     model_output_formatter=output_formatter,
        ...     tool_input_schema=InputSchema,
        ...     tool_description="This tool takes a token and measures the return volatility (standard deviation of returns).",
        ...     inference_mode=og.InferenceMode.VANILLA,
        ... )
    """

    def model_executor(**llm_input):
        # Pass LLM input arguments (formatted based on tool_input_schema) as parameters into model_input_provider
        model_input = model_input_provider(**llm_input)

        inference_result = og.infer(model_cid=model_cid, inference_mode=inference_mode, model_input=model_input)

        return model_output_formatter(inference_result)

    if tool_type == ToolType.LANGCHAIN:
        if not tool_input_schema:
            tool_input_schema = type("EmptyInputSchema", (BaseModel,), {})

        return StructuredTool.from_function(
            func=model_executor, name=tool_name, description=tool_description, args_schema=tool_input_schema
        )
    elif tool_type == ToolType.SWARM:
        model_executor.__name__ = tool_name
        model_executor.__doc__ = tool_description
        # Convert Pydantic model to Swarm annotations if provided
        if tool_input_schema:
            model_executor.__annotations__ = _convert_pydantic_to_annotations(tool_input_schema)
        return model_executor
    else:
        raise ValueError(f"Invalid tooltype: {tool_type}")


def _convert_pydantic_to_annotations(model: Type[BaseModel]) -> Dict[str, Any]:
    """
    Convert a Pydantic model to function annotations format used by Swarm.

    Args:
        model: A Pydantic BaseModel class

    Returns:
        Dict mapping field names to (type, description) tuples
    """
    annotations = {}
    for field_name, field in model.model_fields.items():
        field_type = field.annotation
        description = field.description or ""
        annotations[field_name] = (field_type, description)
    return annotations
```
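A tool built this way plugs into an agent like any other LangChain `StructuredTool`. As a rough sketch of direct invocation (assuming `og.init()` has been called and reusing the `run_model_tool` from the docstring example above, so a live network call is implied):

```python
# Direct invocation of the StructuredTool from the docstring example.
# "bitcoin" is coerced into Token.BTC by InputSchema and routed
# through model_input_provider before og.infer() runs.
volatility = run_model_tool.invoke({"token": "bitcoin"})
print(volatility)  # whatever output_formatter produced, e.g. "0.512%"
```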
{opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/client.py (+39 -59)

```diff
@@ -13,7 +13,6 @@ from web3 import Web3
 from web3.exceptions import ContractLogicError
 from web3.logs import DISCARD
 
-from . import utils
 from .exceptions import OpenGradientError
 from .proto import infer_pb2, infer_pb2_grpc
 from .types import (
@@ -26,8 +25,11 @@ from .types import (
     TextGenerationOutput,
     SchedulerParams,
     InferenceResult,
+    ModelRepository,
+    FileUploadResult,
 )
 from .defaults import DEFAULT_IMAGE_GEN_HOST, DEFAULT_IMAGE_GEN_PORT, DEFAULT_SCHEDULER_ADDRESS
+from .utils import convert_array_to_model_output, convert_to_model_input, convert_to_model_output
 
 _FIREBASE_CONFIG = {
     "apiKey": "AIzaSyDUVckVtfl-hiteBzPopy1pDD8Uvfncs7w",
@@ -53,7 +55,7 @@ class Client:
     _blockchain: Web3
     _wallet_account: LocalAccount
 
-    _hub_user: Dict
+    _hub_user: Optional[Dict]
     _inference_abi: Dict
 
     def __init__(self, private_key: str, rpc_url: str, contract_address: str, email: Optional[str], password: Optional[str]):
@@ -88,7 +90,7 @@ class Client:
             logging.error(f"Authentication failed: {str(e)}")
             raise
 
-    def create_model(self, model_name: str, model_desc: str, version: str = "1.00") ->
+    def create_model(self, model_name: str, model_desc: str, version: str = "1.00") -> ModelRepository:
         """
         Create a new model with the given model_name and model_desc, and a specified version.
 
@@ -111,39 +113,21 @@ class Client:
         payload = {"name": model_name, "description": model_desc}
 
         try:
-            logging.debug(f"Create Model URL: {url}")
-            logging.debug(f"Headers: {headers}")
-            logging.debug(f"Payload: {payload}")
-
             response = requests.post(url, json=payload, headers=headers)
             response.raise_for_status()
+        except requests.HTTPError as e:
+            error_details = f"HTTP {e.response.status_code}: {e.response.text}"
+            raise OpenGradientError(f"Model creation failed: {error_details}") from e
 
-
-
-
-            logging.info(f"Model creation successful. Model name: {model_name}")
+        json_response = response.json()
+        model_name = json_response.get("name")
+        if not model_name:
+            raise Exception(f"Model creation response missing 'name'. Full response: {json_response}")
 
-
-
-            version_response = self.create_version(model_name, version)
-            logging.info(f"Version creation successful. Version string: {version_response['versionString']}")
-        except Exception as ve:
-            logging.error(f"Version creation failed, but model was created. Error: {str(ve)}")
-            return {"name": model_name, "versionString": None, "version_error": str(ve)}
+        # Create the specified version for the newly created model
+        version_response = self.create_version(model_name, version)
 
-
-
-        except requests.RequestException as e:
-            logging.error(f"Model creation failed: {str(e)}")
-            if hasattr(e, "response") and e.response is not None:
-                logging.error(f"Response status code: {e.response.status_code}")
-                logging.error(f"Response headers: {e.response.headers}")
-                logging.error(f"Response content: {e.response.text}")
-            raise Exception(f"Model creation failed: {str(e)}")
-        except Exception as e:
-            logging.error(f"Unexpected error during model creation: {str(e)}")
-            raise
+        return ModelRepository(model_name, version_response["versionString"])
 
     def create_version(self, model_name: str, notes: str = "", is_major: bool = False) -> dict:
         """
@@ -204,7 +188,7 @@ class Client:
             logging.error(f"Unexpected error during version creation: {str(e)}")
             raise
 
-    def upload(self, model_path: str, model_name: str, version: str) ->
+    def upload(self, model_path: str, model_name: str, version: str) -> FileUploadResult:
         """
         Upload a model file to the server.
 
@@ -259,12 +243,9 @@ class Client:
             if response.status_code == 201:
                 if response.content and response.content != b"null":
                     json_response = response.json()
-
-                    logging.info(f"Upload successful. CID: {json_response.get('ipfsCid', 'N/A')}")
-                    result = {"model_cid": json_response.get("ipfsCid"), "size": json_response.get("size")}
+                    return FileUploadResult(json_response.get("ipfsCid"), json_response.get("size"))
                 else:
-
-                    result = {"model_cid": None, "size": None}
+                    raise RuntimeError("Empty or null response content received. Assuming upload was successful.")
             elif response.status_code == 500:
                 error_message = "Internal server error occurred. Please try again later or contact support."
                 logging.error(error_message)
@@ -274,8 +255,6 @@ class Client:
                 logging.error(f"Upload failed with status code {response.status_code}: {error_message}")
                 raise OpenGradientError(f"Upload failed: {error_message}", status_code=response.status_code)
 
-            return result
-
         except requests.RequestException as e:
             logging.error(f"Request exception during upload: {str(e)}")
             if hasattr(e, "response") and e.response is not None:
@@ -303,7 +282,9 @@ class Client:
             max_retries (int, optional): Maximum number of retry attempts. Defaults to 5.
 
         Returns:
-            InferenceResult:
+            InferenceResult (InferenceResult): A dataclass object containing the transaction hash and model output.
+                transaction_hash (str): Blockchain hash for the transaction
+                model_output (Dict[str, np.ndarray]): Output of the ONNX model
 
         Raises:
             OpenGradientError: If the inference fails.
@@ -313,7 +294,7 @@ class Client:
         contract = self._blockchain.eth.contract(address=self._inference_hub_contract_address, abi=self._inference_abi)
 
         inference_mode_uint8 = inference_mode.value
-        converted_model_input =
+        converted_model_input = convert_to_model_input(model_input)
 
         run_function = contract.functions.run(model_cid, inference_mode_uint8, converted_model_input)
@@ -342,7 +323,7 @@ class Client:
            raise OpenGradientError("InferenceResult event not found in transaction logs")
 
        # TODO: This should return a ModelOutput class object
-        model_output =
+        model_output = convert_to_model_output(parsed_logs[0]["args"])
 
        return InferenceResult(tx_hash.hex(), model_output)
 
@@ -751,7 +732,7 @@ class Client:
         # if channel:
         #     channel.close()
 
-    def _get_abi(self, abi_name) ->
+    def _get_abi(self, abi_name) -> str:
         """
         Returns the ABI for the requested contract.
         """
@@ -759,7 +740,7 @@ class Client:
         with open(abi_path, "r") as f:
             return json.load(f)
 
-    def _get_bin(self, bin_name) ->
+    def _get_bin(self, bin_name) -> str:
         """
         Returns the bin for the requested contract.
         """
@@ -781,17 +762,20 @@ class Client:
         """
         Deploy a new workflow contract with the specified parameters.
 
-        This function deploys a new workflow contract
-
-        the
+        This function deploys a new workflow contract on OpenGradient that connects
+        an AI model with its required input data. When executed, the workflow will fetch
+        the specified model, evaluate the input query to get data, and perform inference.
+
+        The workflow can be set to execute manually or automatically via a scheduler.
 
         Args:
-            model_cid (str):
-            input_query (HistoricalInputQuery):
+            model_cid (str): CID of the model to be executed from the Model Hub
+            input_query (HistoricalInputQuery): Input definition for the model inference,
+                will be evaluated at runtime for each inference
             input_tensor_name (str): Name of the input tensor expected by the model
             scheduler_params (Optional[SchedulerParams]): Scheduler configuration for automated execution:
                 - frequency: Execution frequency in seconds
-                - duration_hours: How long
+                - duration_hours: How long the schedule should live for
 
         Returns:
             str: Deployed contract address. If scheduler_params was provided, the workflow
@@ -910,7 +894,7 @@ class Client:
         # Get the result
         result = contract.functions.getInferenceResult().call()
 
-        return
+        return convert_array_to_model_output(result)
 
     def run_workflow(self, contract_address: str) -> ModelOutput:
         """
@@ -955,9 +939,9 @@ class Client:
         # Get the inference result from the contract
         result = contract.functions.getInferenceResult().call()
 
-        return
+        return convert_array_to_model_output(result)
 
-    def read_workflow_history(self, contract_address: str, num_results: int) -> List[
+    def read_workflow_history(self, contract_address: str, num_results: int) -> List[ModelOutput]:
         """
         Gets historical inference results from a workflow contract.
 
@@ -969,18 +953,14 @@ class Client:
             num_results (int): Number of historical results to retrieve
 
         Returns:
-            List[
-                - prediction values
-                - timestamps
-                - any additional metadata stored with the result
-
+            List[ModelOutput]: List of historical inference results
         """
         contract = self._blockchain.eth.contract(
             address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
         )
 
         results = contract.functions.getLastInferenceResults(num_results).call()
-        return [
+        return [convert_array_to_model_output(result) for result in results]
 
 
 def run_with_retry(txn_function, max_retries=DEFAULT_MAX_RETRY, retry_delay=DEFAULT_RETRY_DELAY_SEC):
```
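The net effect in `client.py` is that `create_model` and `upload` now return typed dataclasses and surface HTTP failures as exceptions instead of dicts carrying error keys. A sketch of the new calling convention; the constructor arguments are placeholders (normal usage goes through `og.init()` rather than constructing `Client` directly):

```python
from opengradient.client import Client
from opengradient.exceptions import OpenGradientError

# Placeholder credentials -- og.init() normally builds this object.
client = Client(
    private_key="0x...", rpc_url="...", contract_address="0x...",
    email="user@example.com", password="...",
)

try:
    repo = client.create_model("my-model", "demo model")  # -> ModelRepository
    print(repo.name, repo.initialVersion)
except OpenGradientError as e:
    # HTTP errors now raise instead of returning {"version_error": ...}.
    print(f"model creation failed: {e}")
```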
{opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient/types.py (+50 -4)

```diff
@@ -47,29 +47,63 @@ class Number:
 
 @dataclass
 class NumberTensor:
+    """
+    A container for numeric tensor data used as input for ONNX models.
+
+    Attributes:
+
+        name: Identifier for this tensor in the model.
+
+        values: List of integer tuples representing the tensor data.
+    """
+
     name: str
     values: List[Tuple[int, int]]
 
 
 @dataclass
 class StringTensor:
+    """
+    A container for string tensor data used as input for ONNX models.
+
+    Attributes:
+
+        name: Identifier for this tensor in the model.
+
+        values: List of strings representing the tensor data.
+    """
+
     name: str
     values: List[str]
 
 
 @dataclass
 class ModelInput:
+    """
+    A collection of tensor inputs required for ONNX model inference.
+
+    Attributes:
+
+        numbers: Collection of numeric tensors for the model.
+
+        strings: Collection of string tensors for the model.
+    """
+
     numbers: List[NumberTensor]
     strings: List[StringTensor]
 
 
 class InferenceMode(Enum):
+    """Enum for the different inference modes available for inference (VANILLA, ZKML, TEE)"""
+
     VANILLA = 0
     ZKML = 1
     TEE = 2
 
 
 class LlmInferenceMode(Enum):
+    """Enum for different inference modes available for LLM inferences (VANILLA, TEE)"""
+
     VANILLA = 0
     TEE = 1
@@ -89,14 +123,14 @@ class ModelOutput:
 @dataclass
 class InferenceResult:
     """
-    Output for ML inference requests
+    Output for ML inference requests.
+    This class has two fields:
+        transaction_hash (str): Blockchain hash for the transaction
+        model_output (Dict[str, np.ndarray]): Output of the ONNX model
     """
 
     transaction_hash: str
-    """Blockchain hash for the transaction."""
-
     model_output: Dict[str, np.ndarray]
-    """Output of ONNX model"""
 
 
 @dataclass
@@ -184,3 +218,15 @@ class SchedulerParams:
         if data is None:
             return None
         return SchedulerParams(frequency=data.get("frequency", 600), duration_hours=data.get("duration_hours", 2))
+
+
+@dataclass
+class ModelRepository:
+    name: str
+    initialVersion: str
+
+
+@dataclass
+class FileUploadResult:
+    modelCid: str
+    size: int
```
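The new docstrings spell out that a `ModelInput` is just a pair of tensor lists, while the new `ModelRepository` and `FileUploadResult` dataclasses carry the values `create_model` and `upload` now return. A small illustration with made-up values:

```python
from opengradient.types import (
    FileUploadResult,
    ModelInput,
    ModelRepository,
    NumberTensor,
    StringTensor,
)

# Made-up tensors, purely to show the shapes the dataclasses expect.
model_input = ModelInput(
    numbers=[NumberTensor(name="prices", values=[(20101, 2), (20123, 2)])],
    strings=[StringTensor(name="symbols", values=["ETH", "BTC"])],
)

repo = ModelRepository(name="my-model", initialVersion="0.01")
upload = FileUploadResult(modelCid="Qm...", size=1024)  # placeholder CID
```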
opengradient-0.4.9/src/opengradient/workflow_models/__init__.py (new file, +28 lines)

```python
"""
OpenGradient Hardcoded Models
"""

from .workflow_models import *
from .types import WorkflowModelOutput

__all__ = [
    "read_eth_usdt_one_hour_volatility_forecast",
    "read_btc_1_hour_price_forecast",
    "read_eth_1_hour_price_forecast",
    "read_sol_1_hour_price_forecast",
    "read_sui_1_hour_price_forecast",
    "read_sui_usdt_30_min_price_forecast",
    "read_sui_usdt_6_hour_price_forecast",
    "WorkflowModelOutput",
]

__pdoc__ = {
    "read_eth_usdt_one_hour_volatility_forecast": False,
    "read_btc_1_hour_price_forecast": False,
    "read_eth_1_hour_price_forecast": False,
    "read_sol_1_hour_price_forecast": False,
    "read_sui_1_hour_price_forecast": False,
    "read_sui_usdt_30_min_price_forecast": False,
    "read_sui_usdt_6_hour_price_forecast": False,
    "WorkflowModelOutput": False,
}
```
opengradient-0.4.9/src/opengradient/workflow_models/constants.py (new file, +13 lines)

```python
"""Constants used by the models module"""

# URL for OpenGradient block explorer
BLOCK_EXPLORER_URL = "https://explorer.opengradient.ai/"

# Workflow Contract Addresses
ETH_USDT_1_HOUR_VOLATILITY_ADDRESS = "0xD5629A5b95dde11e4B5772B5Ad8a13B933e33845"
BTC_1_HOUR_PRICE_FORECAST_ADDRESS = "0xb4146E095c2CD2a7aA497EdfB513F1cB868Dcc3D"
ETH_1_HOUR_PRICE_FORECAST_ADDRESS = "0x58826c6dc9A608238d9d57a65bDd50EcaE27FE99"
SOL_1_HOUR_PRICE_FORECAST_ADDRESS = "0xaE12EC7314e91A612CF4Fa4DC07dC02cfCc4dF95"
SUI_1_HOUR_PRICE_FORECAST_ADDRESS = "0x90FE6434C46838c96E9ACb8d394C0fdfe38C6182"
SUI_30_MINUTE_PRICE_FORECAST_ADDRESS = "0xD85BA71f5701dc4C5BDf9780189Db49C6F3708D2"
SUI_6_HOUR_PRICE_FORECAST_ADDRESS = "0x3C2E4DbD653Bd30F1333d456480c1b7aB122e946"
```
opengradient-0.4.9/src/opengradient/workflow_models/types.py (new file, +16 lines)

```python
"""Type definitions for models module."""

from dataclasses import dataclass, field


@dataclass
class WorkflowModelOutput:
    """
    Output definition for reading from a workflow model.
    """

    result: str
    """Result of the workflow formatted as a string."""

    block_explorer_link: str = field(default="")
    """(Optional) Block explorer link for the smart contract address of the workflow."""
```
opengradient-0.4.9/src/opengradient/workflow_models/utils.py (new file, +39 lines)

```python
"""Utility functions for the models module."""

from .constants import BLOCK_EXPLORER_URL
from typing import Callable, Any
from .types import WorkflowModelOutput
import opengradient as og


def create_block_explorer_link_smart_contract(transaction_hash: str) -> str:
    """Create block explorer link for smart contract."""
    block_explorer_url = BLOCK_EXPLORER_URL + "address/" + transaction_hash
    return block_explorer_url


def create_block_explorer_link_transaction(transaction_hash: str) -> str:
    """Create block explorer link for transaction."""
    block_explorer_url = BLOCK_EXPLORER_URL + "tx/" + transaction_hash
    return block_explorer_url


def read_workflow_wrapper(contract_address: str, format_function: Callable[..., str]) -> WorkflowModelOutput:
    """
    Wrapper function for reading from models through workflows.
    Args:
        contract_address (str): Smart contract address of the workflow
        format_function (Callable): Function for formatting the result returned by read_workflow
    """
    try:
        result = og.read_workflow_result(contract_address)

        formatted_result = format_function(result)
        block_explorer_link = create_block_explorer_link_smart_contract(contract_address)

        return WorkflowModelOutput(
            result=formatted_result,
            block_explorer_link=block_explorer_link,
        )
    except Exception as e:
        raise RuntimeError(f"Error reading from workflow with address {contract_address}: {e!s}")
```
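The two link helpers simply append an address or transaction hash to the explorer base URL. For example (the address here is one of the workflow addresses from `constants.py`):

```python
from opengradient.workflow_models.utils import (
    create_block_explorer_link_smart_contract,
)

link = create_block_explorer_link_smart_contract(
    "0xD5629A5b95dde11e4B5772B5Ad8a13B933e33845"  # ETH/USDT volatility workflow
)
# -> "https://explorer.opengradient.ai/address/0xD5629A5b95dde11e4B5772B5Ad8a13B933e33845"
```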
opengradient-0.4.9/src/opengradient/workflow_models/workflow_models.py (new file, +97 lines)

```python
"""Repository of OpenGradient quantitative workflow models."""

import opengradient as og
from .constants import (
    ETH_USDT_1_HOUR_VOLATILITY_ADDRESS,
    BTC_1_HOUR_PRICE_FORECAST_ADDRESS,
    ETH_1_HOUR_PRICE_FORECAST_ADDRESS,
    SOL_1_HOUR_PRICE_FORECAST_ADDRESS,
    SUI_1_HOUR_PRICE_FORECAST_ADDRESS,
    SUI_30_MINUTE_PRICE_FORECAST_ADDRESS,
    SUI_6_HOUR_PRICE_FORECAST_ADDRESS,
)
from .utils import read_workflow_wrapper
from .types import WorkflowModelOutput


def read_eth_usdt_one_hour_volatility_forecast() -> WorkflowModelOutput:
    """
    Read from the ETH/USDT one hour volatility forecast model workflow on the OpenGradient network.

    More information on this model can be found at https://hub.opengradient.ai/models/OpenGradient/og-1hr-volatility-ethusdt.
    """
    return read_workflow_wrapper(
        contract_address=ETH_USDT_1_HOUR_VOLATILITY_ADDRESS, format_function=lambda x: format(float(x.numbers["Y"].item()), ".10%")
    )


def read_btc_1_hour_price_forecast() -> WorkflowModelOutput:
    """
    Read from the BTC one hour return forecast workflow on the OpenGradient network.

    More information on this model can be found at https://hub.opengradient.ai/models/OpenGradient/og-btc-1hr-forecast.
    """
    return read_workflow_wrapper(
        contract_address=BTC_1_HOUR_PRICE_FORECAST_ADDRESS,
        format_function=lambda x: format(float(x.numbers["regression_output"].item()), ".10%"),
    )


def read_eth_1_hour_price_forecast() -> WorkflowModelOutput:
    """
    Read from the ETH one hour return forecast workflow on the OpenGradient network.

    More information on this model can be found at https://hub.opengradient.ai/models/OpenGradient/og-eth-1hr-forecast.
    """
    return read_workflow_wrapper(
        contract_address=ETH_1_HOUR_PRICE_FORECAST_ADDRESS,
        format_function=lambda x: format(float(x.numbers["regression_output"].item()), ".10%"),
    )


def read_sol_1_hour_price_forecast() -> WorkflowModelOutput:
    """
    Read from the SOL one hour return forecast workflow on the OpenGradient network.

    More information on this model can be found at https://hub.opengradient.ai/models/OpenGradient/og-sol-1hr-forecast.
    """
    return read_workflow_wrapper(
        contract_address=SOL_1_HOUR_PRICE_FORECAST_ADDRESS,
        format_function=lambda x: format(float(x.numbers["regression_output"].item()), ".10%"),
    )


def read_sui_1_hour_price_forecast() -> WorkflowModelOutput:
    """
    Read from the SUI one hour return forecast workflow on the OpenGradient network.

    More information on this model can be found at https://hub.opengradient.ai/models/OpenGradient/og-sui-1hr-forecast.
    """
    return read_workflow_wrapper(
        contract_address=SUI_1_HOUR_PRICE_FORECAST_ADDRESS,
        format_function=lambda x: format(float(x.numbers["regression_output"].item()), ".10%"),
    )


def read_sui_usdt_30_min_price_forecast() -> WorkflowModelOutput:
    """
    Read from the SUI/USDT pair 30 min return forecast workflow on the OpenGradient network.

    More information on this model can be found at https://hub.opengradient.ai/models/OpenGradient/og-30min-return-suiusdt.
    """
    return read_workflow_wrapper(
        contract_address=SUI_30_MINUTE_PRICE_FORECAST_ADDRESS,
        format_function=lambda x: format(float(x.numbers["destandardized_prediction"].item()), ".10%"),
    )


def read_sui_usdt_6_hour_price_forecast() -> WorkflowModelOutput:
    """
    Read from the SUI/USDT pair 6 hour return forecast workflow on the OpenGradient network.

    More information on this model can be found at https://hub.opengradient.ai/models/OpenGradient/og-6h-return-suiusdt.
    """
    return read_workflow_wrapper(
        contract_address=SUI_6_HOUR_PRICE_FORECAST_ADDRESS,
        format_function=lambda x: format(float(x.numbers["destandardized_prediction"].item()), ".10%"),
    )
```
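Each reader returns a `WorkflowModelOutput` whose `result` is the forecast pre-formatted as a percentage string, plus an explorer link for the workflow contract. A short usage sketch (placeholder credentials; actual values depend on live workflow state):

```python
import opengradient as og
from opengradient.workflow_models import read_eth_1_hour_price_forecast

# The SDK must be initialized first; these credentials are placeholders.
og.init(email="user@example.com", password="...", private_key="0x...")

forecast = read_eth_1_hour_price_forecast()
print(forecast.result)               # e.g. "0.0123456789%" (formatted by the reader)
print(forecast.block_explorer_link)  # explorer link for the workflow contract
```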
{opengradient-0.4.7 → opengradient-0.4.9}/src/opengradient.egg-info/SOURCES.txt (+6 -1)

```diff
@@ -29,4 +29,9 @@ src/opengradient/llm/og_openai.py
 src/opengradient/proto/__init__.py
 src/opengradient/proto/infer.proto
 src/opengradient/proto/infer_pb2.py
-src/opengradient/proto/infer_pb2_grpc.py
+src/opengradient/proto/infer_pb2_grpc.py
+src/opengradient/workflow_models/__init__.py
+src/opengradient/workflow_models/constants.py
+src/opengradient/workflow_models/types.py
+src/opengradient/workflow_models/utils.py
+src/opengradient/workflow_models/workflow_models.py
```
opengradient-0.4.7/src/opengradient/alphasense/run_model_tool.py (removed in 0.4.9, -114 lines; superseded by the new run_model_tool.py above)

```python
from enum import Enum
from typing import Any, Callable, Dict, Type, Optional

from langchain_core.tools import BaseTool, StructuredTool
from pydantic import BaseModel

import opengradient as og
from .types import ToolType


def create_run_model_tool(
    tool_type: ToolType,
    model_cid: str,
    tool_name: str,
    input_getter: Callable,
    output_formatter: Callable[..., str] = lambda x: x,
    input_schema: Optional[Type[BaseModel]] = None,
    tool_description: str = "Executes the given ML model",
    inference_mode: og.InferenceMode = og.InferenceMode.VANILLA,
) -> BaseTool | Callable:
    """
    Creates a tool that wraps an OpenGradient model for inference.

    This function generates a tool that can be integrated into either a LangChain pipeline
    or a Swarm system, allowing the model to be executed as part of a chain of operations.
    The tool uses the provided input_getter function to obtain the necessary input data and
    runs inference using the specified OpenGradient model.

    Args:
        tool_type (ToolType): Specifies the framework to create the tool for. Use
            ToolType.LANGCHAIN for LangChain integration or ToolType.SWARM for Swarm
            integration.
        model_cid (str): The CID of the OpenGradient model to be executed.
        tool_name (str): The name to assign to the created tool. This will be used to identify
            and invoke the tool within the agent.
        input_getter (Callable): A function that returns the input data required by the model.
            The function should return data in a format compatible with the model's expectations.
        output_formatter (Callable[..., str], optional): A function that takes the model output and
            formats it into a string. This is required to ensure the output is compatible
            with the tool framework. Default returns string as is.
        input_schema (Type[BaseModel], optional): A Pydantic BaseModel class defining the
            input schema. This will be used directly for LangChain tools and converted
            to appropriate annotations for Swarm tools. Default is None.
        tool_description (str, optional): A description of what the tool does. Defaults to
            "Executes the given ML model".
        inference_mode (og.InferenceMode, optional): The inference mode to use when running
            the model. Defaults to VANILLA.

    Returns:
        BaseTool: For ToolType.LANGCHAIN, returns a LangChain StructuredTool.
        Callable: For ToolType.SWARM, returns a decorated function with appropriate metadata.

    Raises:
        ValueError: If an invalid tool_type is provided.

    Examples:
        >>> from pydantic import BaseModel, Field
        >>> class ClassifierInput(BaseModel):
        ...     query: str = Field(description="User query to analyze")
        ...     parameters: dict = Field(description="Additional parameters")
        >>> def get_input():
        ...     return {"text": "Sample input text"}
        >>> def format_output(output):
        ...     return str(output.get("class", "Unknown"))
        >>> # Create a LangChain tool
        >>> langchain_tool = create_og_model_tool(
        ...     tool_type=ToolType.LANGCHAIN,
        ...     model_cid="Qm...",
        ...     tool_name="text_classifier",
        ...     input_getter=get_input,
        ...     output_formatter=format_output,
        ...     input_schema=ClassifierInput
        ...     tool_description="Classifies text into categories"
        ... )
    """

    # define runnable
    def model_executor(**llm_input):
        # Combine LLM input with input provided by code
        combined_input = {**llm_input, **input_getter()}

        _, output = og.infer(model_cid=model_cid, inference_mode=inference_mode, model_input=combined_input)

        return output_formatter(output)

    if tool_type == ToolType.LANGCHAIN:
        return StructuredTool.from_function(func=model_executor, name=tool_name, description=tool_description, args_schema=input_schema)
    elif tool_type == ToolType.SWARM:
        model_executor.__name__ = tool_name
        model_executor.__doc__ = tool_description
        # Convert Pydantic model to Swarm annotations if provided
        if input_schema:
            model_executor.__annotations__ = _convert_pydantic_to_annotations(input_schema)
        return model_executor
    else:
        raise ValueError(f"Invalid tooltype: {tool_type}")


def _convert_pydantic_to_annotations(model: Type[BaseModel]) -> Dict[str, Any]:
    """
    Convert a Pydantic model to function annotations format used by Swarm.

    Args:
        model: A Pydantic BaseModel class

    Returns:
        Dict mapping field names to (type, description) tuples
    """
    annotations = {}
    for field_name, field in model.model_fields.items():
        field_type = field.annotation
        description = field.description or ""
        annotations[field_name] = (field_type, description)
    return annotations
```