opengradient 0.3.23__py3-none-any.whl → 0.3.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opengradient/__init__.py +125 -98
- opengradient/account.py +6 -4
- opengradient/cli.py +151 -154
- opengradient/client.py +300 -362
- opengradient/defaults.py +7 -7
- opengradient/exceptions.py +25 -0
- opengradient/llm/__init__.py +7 -10
- opengradient/llm/og_langchain.py +34 -51
- opengradient/llm/og_openai.py +54 -61
- opengradient/mltools/__init__.py +2 -7
- opengradient/mltools/model_tool.py +20 -26
- opengradient/proto/infer_pb2.py +24 -29
- opengradient/proto/infer_pb2_grpc.py +95 -86
- opengradient/types.py +39 -35
- opengradient/utils.py +30 -31
- {opengradient-0.3.23.dist-info → opengradient-0.3.25.dist-info}/METADATA +1 -1
- opengradient-0.3.25.dist-info/RECORD +26 -0
- opengradient-0.3.23.dist-info/RECORD +0 -26
- {opengradient-0.3.23.dist-info → opengradient-0.3.25.dist-info}/LICENSE +0 -0
- {opengradient-0.3.23.dist-info → opengradient-0.3.25.dist-info}/WHEEL +0 -0
- {opengradient-0.3.23.dist-info → opengradient-0.3.25.dist-info}/entry_points.txt +0 -0
- {opengradient-0.3.23.dist-info → opengradient-0.3.25.dist-info}/top_level.txt +0 -0
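The headline change in this release is a refactor of `opengradient/client.py`: the Firebase settings become a module-level `_FIREBASE_CONFIG` constant, `run_with_retry` moves from the top of the module to the bottom, and `Client` state is consolidated into typed private attributes (`_inference_hub_contract_address`, `_blockchain`, `_wallet_account`, `_hub_user`, `_inference_abi`), with the old `_initialize_web3` and `refresh_token` helpers removed. A minimal usage sketch of the refactored client, assuming only the signatures visible in the diff below; the key, RPC URL, contract address, model CID, and tensor name are placeholders:

```python
# Sketch of the 0.3.25 Client API, based on the signatures in the client.py
# diff below; every concrete value here is a placeholder.
import numpy as np

from opengradient.client import Client
from opengradient.types import InferenceMode

client = Client(
    private_key="0x<hex-private-key>",     # placeholder wallet key
    rpc_url="https://<rpc-endpoint>",      # placeholder Ethereum RPC URL
    contract_address="0x<inference-hub>",  # placeholder inference hub address
    email="test@test.com",                 # hub login; None skips hub auth
    password="Test-123",
)

# infer() builds, signs, and sends the transaction, decodes the
# InferenceResult event, and retries nonce errors via run_with_retry.
tx_hash, model_output = client.infer(
    model_cid="Qm<model-cid>",             # placeholder model CID
    inference_mode=InferenceMode.VANILLA,  # assumes this enum member exists
    model_input={"input": np.array([1.0, 2.0, 3.0])},  # tensor name is illustrative
)
print(tx_hash, model_output)
```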
opengradient/client.py CHANGED

@@ -1,76 +1,48 @@
-import asyncio
 import json
 import logging
 import os
+import time
+import uuid
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import firebase
+import grpc
 import numpy as np
 import requests
+from eth_account.account import LocalAccount
 from web3 import Web3
 from web3.exceptions import ContractLogicError
 from web3.logs import DISCARD

-from
-from
-from
-
-    InferenceMode,
-    LlmInferenceMode,
-    LLM,
-    TEE_LLM,
-    ModelOutput,
-    SchedulerParams
-)
-
-import grpc
-import time
-import uuid
-from google.protobuf import timestamp_pb2
-
-from opengradient.proto import infer_pb2
-from opengradient.proto import infer_pb2_grpc
+from . import utils
+from .exceptions import OpenGradientError
+from .proto import infer_pb2, infer_pb2_grpc
+from .types import LLM, TEE_LLM, HistoricalInputQuery, InferenceMode, LlmInferenceMode, ModelOutput, SchedulerParams
 from .defaults import DEFAULT_IMAGE_GEN_HOST, DEFAULT_IMAGE_GEN_PORT

-
+_FIREBASE_CONFIG = {
+    "apiKey": "AIzaSyDUVckVtfl-hiteBzPopy1pDD8Uvfncs7w",
+    "authDomain": "vanna-portal-418018.firebaseapp.com",
+    "projectId": "vanna-portal-418018",
+    "storageBucket": "vanna-portal-418018.appspot.com",
+    "appId": "1:487761246229:web:259af6423a504d2316361c",
+    "databaseURL": "",
+}

-def run_with_retry(txn_function, max_retries=5):
-    """
-    Execute a blockchain transaction with retry logic.
-
-    Args:
-        txn_function: Function that executes the transaction
-        max_retries (int): Maximum number of retry attempts
-    """
-    last_error = None
-    for attempt in range(max_retries):
-        try:
-            return txn_function()
-        except Exception as e:
-            last_error = e
-            if attempt < max_retries - 1:
-                if "nonce too low" in str(e) or "nonce too high" in str(e):
-                    time.sleep(1)  # Wait before retry
-                    continue
-            # If it's not a nonce error, raise immediately
-            raise
-    # If we've exhausted all retries, raise the last error
-    raise OpenGradientError(f"Transaction failed after {max_retries} attempts: {str(last_error)}")

 class Client:
-
-
-
-
-
-
-
-    }
-
+    _inference_hub_contract_address: str
+    _blockchain: Web3
+    _wallet_account: LocalAccount
+
+    _hub_user: Dict
+    _inference_abi: Dict
+
     def __init__(self, private_key: str, rpc_url: str, contract_address: str, email: str, password: str):
         """
         Initialize the Client with private key, RPC URL, and contract address.
+
         Args:
             private_key (str): The private key for the wallet.
             rpc_url (str): The RPC URL for the Ethereum node.
@@ -78,60 +50,27 @@ class Client:
             email (str, optional): Email for authentication. Defaults to "test@test.com".
             password (str, optional): Password for authentication. Defaults to "Test-123".
         """
-        self.
-        self.
-        self.
-        self.rpc_url = rpc_url
-        self.contract_address = contract_address
-        self._w3 = Web3(Web3.HTTPProvider(self.rpc_url))
-        self.wallet_account = self._w3.eth.account.from_key(private_key)
-        self.wallet_address = self._w3.to_checksum_address(self.wallet_account.address)
-
-        self.firebase_app = firebase.initialize_app(self.FIREBASE_CONFIG)
-        self.auth = self.firebase_app.auth()
-        self.user = None
-
-        abi_path = Path(__file__).parent / 'abi' / 'inference.abi'
+        self._inference_hub_contract_address = contract_address
+        self._blockchain = Web3(Web3.HTTPProvider(rpc_url))
+        self._wallet_account = self._blockchain.eth.account.from_key(private_key)

-
-
-
-        except FileNotFoundError:
-            raise
-        except json.JSONDecodeError:
-            raise
-        except Exception as e:
-            raise
-
-        self.abi = inference_abi
+        abi_path = Path(__file__).parent / "abi" / "inference.abi"
+        with open(abi_path, "r") as abi_file:
+            self._inference_abi = json.load(abi_file)

         if email is not None:
-            self.
+            self._hub_user = self._login_to_hub(email, password)
+        else:
+            self._hub_user = None

-    def
+    def _login_to_hub(self, email, password):
         try:
-
-            return
+            firebase_app = firebase.initialize_app(_FIREBASE_CONFIG)
+            return firebase_app.auth().sign_in_with_email_and_password(email, password)
         except Exception as e:
             logging.error(f"Authentication failed: {str(e)}")
             raise

-    def _initialize_web3(self):
-        """
-        Initialize the Web3 instance if it is not already initialized.
-        """
-        if self._w3 is None:
-            self._w3 = Web3(Web3.HTTPProvider(self.rpc_url))
-
-    def refresh_token(self) -> None:
-        """
-        Refresh the authentication token for the current user.
-        """
-        if self.user:
-            self.user = self.auth.refresh(self.user['refreshToken'])
-        else:
-            logging.error("No user is currently signed in")
-
     def create_model(self, model_name: str, model_desc: str, version: str = "1.00") -> dict:
         """
         Create a new model with the given model_name and model_desc, and a specified version.
@@ -147,18 +86,12 @@ class Client:
         Raises:
             CreateModelError: If the model creation fails.
         """
-        if not self.
+        if not self._hub_user:
            raise ValueError("User not authenticated")

         url = "https://api.opengradient.ai/api/v0/models/"
-        headers = {
-
-            'Content-Type': 'application/json'
-        }
-        payload = {
-            'name': model_name,
-            'description': model_desc
-        }
+        headers = {"Authorization": f'Bearer {self._hub_user["idToken"]}', "Content-Type": "application/json"}
+        payload = {"name": model_name, "description": model_desc}

         try:
             logging.debug(f"Create Model URL: {url}")
@@ -169,7 +102,7 @@ class Client:
             response.raise_for_status()

             json_response = response.json()
-            model_name = json_response.get(
+            model_name = json_response.get("name")
             if not model_name:
                 raise Exception(f"Model creation response missing 'name'. Full response: {json_response}")
             logging.info(f"Model creation successful. Model name: {model_name}")
@@ -186,7 +119,7 @@ class Client:

         except requests.RequestException as e:
             logging.error(f"Model creation failed: {str(e)}")
-            if hasattr(e,
+            if hasattr(e, "response") and e.response is not None:
                 logging.error(f"Response status code: {e.response.status_code}")
                 logging.error(f"Response headers: {e.response.headers}")
                 logging.error(f"Response content: {e.response.text}")
@@ -210,18 +143,12 @@ class Client:
         Raises:
             Exception: If the version creation fails.
         """
-        if not self.
+        if not self._hub_user:
            raise ValueError("User not authenticated")

         url = f"https://api.opengradient.ai/api/v0/models/{model_name}/versions"
-        headers = {
-
-            'Content-Type': 'application/json'
-        }
-        payload = {
-            "notes": notes,
-            "is_major": is_major
-        }
+        headers = {"Authorization": f'Bearer {self._hub_user["idToken"]}', "Content-Type": "application/json"}
+        payload = {"notes": notes, "is_major": is_major}

         try:
             logging.debug(f"Create Version URL: {url}")
@@ -239,7 +166,7 @@ class Client:
                 logging.info("Server returned an empty list. Assuming version was created successfully.")
                 return {"versionString": "Unknown", "note": "Created based on empty response"}
             elif isinstance(json_response, dict):
-                version_string = json_response.get(
+                version_string = json_response.get("versionString")
                 if not version_string:
                     logging.warning(f"'versionString' not found in response. Response: {json_response}")
                     return {"versionString": "Unknown", "note": "Version ID not provided in response"}
@@ -251,7 +178,7 @@ class Client:

         except requests.RequestException as e:
             logging.error(f"Version creation failed: {str(e)}")
-            if hasattr(e,
+            if hasattr(e, "response") and e.response is not None:
                 logging.error(f"Response status code: {e.response.status_code}")
                 logging.error(f"Response headers: {e.response.headers}")
                 logging.error(f"Response content: {e.response.text}")
@@ -277,16 +204,14 @@ class Client:
         """
         from requests_toolbelt import MultipartEncoder, MultipartEncoderMonitor

-        if not self.
+        if not self._hub_user:
            raise ValueError("User not authenticated")

         if not os.path.exists(model_path):
             raise FileNotFoundError(f"Model file not found: {model_path}")

         url = f"https://api.opengradient.ai/api/v0/models/{model_name}/versions/{version}/files"
-        headers = {
-            'Authorization': f'Bearer {self.user["idToken"]}'
-        }
+        headers = {"Authorization": f'Bearer {self._hub_user["idToken"]}'}

         logging.info(f"Starting upload for file: {model_path}")
         logging.info(f"File size: {os.path.getsize(model_path)} bytes")
@@ -295,27 +220,27 @@ class Client:

         def create_callback(encoder):
             encoder_len = encoder.len
+
             def callback(monitor):
                 progress = (monitor.bytes_read / encoder_len) * 100
                 logging.info(f"Upload progress: {progress:.2f}%")
+
             return callback

         try:
-            with open(model_path,
-                encoder = MultipartEncoder(
-                    fields={'file': (os.path.basename(model_path), file, 'application/octet-stream')}
-                )
+            with open(model_path, "rb") as file:
+                encoder = MultipartEncoder(fields={"file": (os.path.basename(model_path), file, "application/octet-stream")})
                 monitor = MultipartEncoderMonitor(encoder, create_callback(encoder))
-                headers[
+                headers["Content-Type"] = monitor.content_type

                 logging.info("Sending POST request...")
                 response = requests.post(url, data=monitor, headers=headers, timeout=3600)  # 1 hour timeout
-
+
                 logging.info(f"Response received. Status code: {response.status_code}")
                 logging.info(f"Full response content: {response.text}")  # Log the full response content

                 if response.status_code == 201:
-                    if response.content and response.content != b
+                    if response.content and response.content != b"null":
                         json_response = response.json()
                         logging.info(f"JSON response: {json_response}")  # Log the parsed JSON response
                         logging.info(f"Upload successful. CID: {json_response.get('ipfsCid', 'N/A')}")
@@ -328,7 +253,7 @@ class Client:
                         logging.error(error_message)
                         raise OpenGradientError(error_message, status_code=500)
                 else:
-                    error_message = response.json().get(
+                    error_message = response.json().get("detail", "Unknown error occurred")
                     logging.error(f"Upload failed with status code {response.status_code}: {error_message}")
                     raise OpenGradientError(f"Upload failed: {error_message}", status_code=response.status_code)

@@ -336,22 +261,23 @@ class Client:

         except requests.RequestException as e:
             logging.error(f"Request exception during upload: {str(e)}")
-            if hasattr(e,
+            if hasattr(e, "response") and e.response is not None:
                 logging.error(f"Response status code: {e.response.status_code}")
                 logging.error(f"Response content: {e.response.text[:1000]}...")  # Log first 1000 characters
-            raise OpenGradientError(
-
+            raise OpenGradientError(
+                f"Upload failed due to request exception: {str(e)}", status_code=e.response.status_code if hasattr(e, "response") else None
+            )
         except Exception as e:
             logging.error(f"Unexpected error during upload: {str(e)}", exc_info=True)
             raise OpenGradientError(f"Unexpected error during upload: {str(e)}")
-
+
     def infer(
-
-
-
-
-
-
+        self,
+        model_cid: str,
+        inference_mode: InferenceMode,
+        model_input: Dict[str, Union[str, int, float, List, np.ndarray]],
+        max_retries: Optional[int] = None,
+    ) -> Tuple[str, Dict[str, np.ndarray]]:
         """
         Perform inference on a model.

@@ -367,54 +293,54 @@ class Client:
         Raises:
             OpenGradientError: If the inference fails.
         """
+
         def execute_transaction():
-            self.
-
-
+            contract = self._blockchain.eth.contract(address=self._inference_hub_contract_address, abi=self._inference_abi)
+
             inference_mode_uint8 = int(inference_mode)
             converted_model_input = utils.convert_to_model_input(model_input)
-
-            run_function = contract.functions.run(
-                model_cid,
-                inference_mode_uint8,
-                converted_model_input
-            )

-
-
+            run_function = contract.functions.run(model_cid, inference_mode_uint8, converted_model_input)
+
+            nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending")
+            estimated_gas = run_function.estimate_gas({"from": self._wallet_account.address})
             gas_limit = int(estimated_gas * 3)

-            transaction = run_function.build_transaction(
-
-
-
-
-
+            transaction = run_function.build_transaction(
+                {
+                    "from": self._wallet_account.address,
+                    "nonce": nonce,
+                    "gas": gas_limit,
+                    "gasPrice": self._blockchain.eth.gas_price,
+                }
+            )

-            signed_tx = self.
-            tx_hash = self.
-            tx_receipt = self.
+            signed_tx = self._wallet_account.sign_transaction(transaction)
+            tx_hash = self._blockchain.eth.send_raw_transaction(signed_tx.raw_transaction)
+            tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash)

-            if tx_receipt[
+            if tx_receipt["status"] == 0:
                 raise ContractLogicError(f"Transaction failed. Receipt: {tx_receipt}")

             parsed_logs = contract.events.InferenceResult().process_receipt(tx_receipt, errors=DISCARD)
             if len(parsed_logs) < 1:
                 raise OpenGradientError("InferenceResult event not found in transaction logs")

-            model_output = utils.convert_to_model_output(parsed_logs[0][
+            model_output = utils.convert_to_model_output(parsed_logs[0]["args"])
             return tx_hash.hex(), model_output

         return run_with_retry(execute_transaction, max_retries or 5)

-    def llm_completion(
-
-
-
-
-
-
-
+    def llm_completion(
+        self,
+        model_cid: LLM,
+        inference_mode: InferenceMode,
+        prompt: str,
+        max_tokens: int = 100,
+        stop_sequence: Optional[List[str]] = None,
+        temperature: float = 0.0,
+        max_retries: Optional[int] = None,
+    ) -> Tuple[str, str]:
         """
         Perform inference on an LLM model using completions.

@@ -432,16 +358,16 @@ class Client:
         Raises:
             OpenGradientError: If the inference fails.
         """
+
         def execute_transaction():
             # Check inference mode and supported model
             if inference_mode != LlmInferenceMode.VANILLA and inference_mode != LlmInferenceMode.TEE:
                 raise OpenGradientError("Invalid inference mode %s: Inference mode must be VANILLA or TEE" % inference_mode)
-
+
             if inference_mode == LlmInferenceMode.TEE and model_cid not in TEE_LLM:
                 raise OpenGradientError("That model CID is not supported yet supported for TEE inference")

-            self.
-            contract = self._w3.eth.contract(address=self.contract_address, abi=self.abi)
+            contract = self._blockchain.eth.contract(address=self._inference_hub_contract_address, abi=self._inference_abi)

             # Prepare LLM input
             llm_request = {
@@ -450,56 +376,60 @@ class Client:
                 "prompt": prompt,
                 "max_tokens": max_tokens,
                 "stop_sequence": stop_sequence or [],
-                "temperature": int(temperature * 100)  # Scale to 0-100 range
+                "temperature": int(temperature * 100),  # Scale to 0-100 range
             }
             logging.debug(f"Prepared LLM request: {llm_request}")

             run_function = contract.functions.runLLMCompletion(llm_request)

-            nonce = self.
-            estimated_gas = run_function.estimate_gas({
+            nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending")
+            estimated_gas = run_function.estimate_gas({"from": self._wallet_account.address})
             gas_limit = int(estimated_gas * 1.2)

-            transaction = run_function.build_transaction(
-
-
-
-
-
+            transaction = run_function.build_transaction(
+                {
+                    "from": self._wallet_account.address,
+                    "nonce": nonce,
+                    "gas": gas_limit,
+                    "gasPrice": self._blockchain.eth.gas_price,
+                }
+            )

-            signed_tx = self.
-            tx_hash = self.
-            tx_receipt = self.
+            signed_tx = self._wallet_account.sign_transaction(transaction)
+            tx_hash = self._blockchain.eth.send_raw_transaction(signed_tx.raw_transaction)
+            tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash)

-            if tx_receipt[
+            if tx_receipt["status"] == 0:
                 raise ContractLogicError(f"Transaction failed. Receipt: {tx_receipt}")

             parsed_logs = contract.events.LLMCompletionResult().process_receipt(tx_receipt, errors=DISCARD)
             if len(parsed_logs) < 1:
                 raise OpenGradientError("LLM completion result event not found in transaction logs")

-            llm_answer = parsed_logs[0][
+            llm_answer = parsed_logs[0]["args"]["response"]["answer"]
             return tx_hash.hex(), llm_answer

         return run_with_retry(execute_transaction, max_retries or 5)

-    def llm_chat(
-
-
-
-
-
-
-
-
-
+    def llm_chat(
+        self,
+        model_cid: LLM,
+        inference_mode: InferenceMode,
+        messages: List[Dict],
+        max_tokens: int = 100,
+        stop_sequence: Optional[List[str]] = None,
+        temperature: float = 0.0,
+        tools: Optional[List[Dict]] = [],
+        tool_choice: Optional[str] = None,
+        max_retries: Optional[int] = None,
+    ) -> Tuple[str, str]:
         """
         Perform inference on an LLM model using chat.

         Args:
             model_cid (LLM): The unique content identifier for the model.
             inference_mode (InferenceMode): The inference mode.
-            messages (dict): The messages that will be passed into the chat.
+            messages (dict): The messages that will be passed into the chat.
                 This should be in OpenAI API format (https://platform.openai.com/docs/api-reference/chat/create)
                 Example:
                     [
@@ -541,7 +471,7 @@ class Client:
                        }
                    }
                ]
-            tool_choice (str, optional): Sets a specific tool to choose. Default value is "auto".
+            tool_choice (str, optional): Sets a specific tool to choose. Default value is "auto".

         Returns:
             Tuple[str, str, dict]: The transaction hash, finish reason, and a dictionary struct of LLM chat messages.
@@ -549,37 +479,37 @@ class Client:
         Raises:
             OpenGradientError: If the inference fails.
         """
+
         def execute_transaction():
             # Check inference mode and supported model
             if inference_mode != LlmInferenceMode.VANILLA and inference_mode != LlmInferenceMode.TEE:
                 raise OpenGradientError("Invalid inference mode %s: Inference mode must be VANILLA or TEE" % inference_mode)
-
+
             if inference_mode == LlmInferenceMode.TEE and model_cid not in TEE_LLM:
                 raise OpenGradientError("That model CID is not supported yet supported for TEE inference")
-
-            self.
-            contract = self._w3.eth.contract(address=self.contract_address, abi=self.abi)
+
+            contract = self._blockchain.eth.contract(address=self._inference_hub_contract_address, abi=self._inference_abi)

             # For incoming chat messages, tool_calls can be empty. Add an empty array so that it will fit the ABI.
             for message in messages:
-                if
-                    message[
-                if
-                    message[
-                if
-                    message[
+                if "tool_calls" not in message:
+                    message["tool_calls"] = []
+                if "tool_call_id" not in message:
+                    message["tool_call_id"] = ""
+                if "name" not in message:
+                    message["name"] = ""

             # Create simplified tool structure for smart contract
             converted_tools = []
             if tools is not None:
                 for tool in tools:
-                    function = tool[
+                    function = tool["function"]
                     converted_tool = {}
-                    converted_tool[
-                    converted_tool[
-                    if (parameters := function.get(
+                    converted_tool["name"] = function["name"]
+                    converted_tool["description"] = function["description"]
+                    if (parameters := function.get("parameters")) is not None:
                         try:
-                            converted_tool[
+                            converted_tool["parameters"] = json.dumps(parameters)
                         except Exception as e:
                             raise OpenGradientError("Chat LLM failed to convert parameters into JSON: %s", e)
                     converted_tools.append(converted_tool)
@@ -593,40 +523,42 @@ class Client:
                 "stop_sequence": stop_sequence or [],
                 "temperature": int(temperature * 100),  # Scale to 0-100 range
                 "tools": converted_tools or [],
-                "tool_choice": tool_choice if tool_choice else ("" if tools is None else "auto")
+                "tool_choice": tool_choice if tool_choice else ("" if tools is None else "auto"),
             }
             logging.debug(f"Prepared LLM request: {llm_request}")

             run_function = contract.functions.runLLMChat(llm_request)

-            nonce = self.
-            estimated_gas = run_function.estimate_gas({
+            nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending")
+            estimated_gas = run_function.estimate_gas({"from": self._wallet_account.address})
             gas_limit = int(estimated_gas * 1.2)

-            transaction = run_function.build_transaction(
-
-
-
-
-
+            transaction = run_function.build_transaction(
+                {
+                    "from": self._wallet_account.address,
+                    "nonce": nonce,
+                    "gas": gas_limit,
+                    "gasPrice": self._blockchain.eth.gas_price,
+                }
+            )

-            signed_tx = self.
-            tx_hash = self.
-            tx_receipt = self.
+            signed_tx = self._wallet_account.sign_transaction(transaction)
+            tx_hash = self._blockchain.eth.send_raw_transaction(signed_tx.raw_transaction)
+            tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash)

-            if tx_receipt[
+            if tx_receipt["status"] == 0:
                 raise ContractLogicError(f"Transaction failed. Receipt: {tx_receipt}")

             parsed_logs = contract.events.LLMChatResult().process_receipt(tx_receipt, errors=DISCARD)
             if len(parsed_logs) < 1:
                 raise OpenGradientError("LLM chat result event not found in transaction logs")

-            llm_result = parsed_logs[0][
-            message = dict(llm_result[
-            if (tool_calls := message.get(
-                message[
+            llm_result = parsed_logs[0]["args"]["response"]
+            message = dict(llm_result["message"])
+            if (tool_calls := message.get("tool_calls")) is not None:
+                message["tool_calls"] = [dict(tool_call) for tool_call in tool_calls]

-            return tx_hash.hex(), llm_result[
+            return tx_hash.hex(), llm_result["finish_reason"], message

         return run_with_retry(execute_transaction, max_retries or 5)

@@ -644,13 +576,11 @@ class Client:
         Raises:
             OpenGradientError: If the file listing fails.
         """
-        if not self.
+        if not self._hub_user:
            raise ValueError("User not authenticated")

         url = f"https://api.opengradient.ai/api/v0/models/{model_name}/versions/{version}/files"
-        headers = {
-            'Authorization': f'Bearer {self.user["idToken"]}'
-        }
+        headers = {"Authorization": f'Bearer {self._hub_user["idToken"]}'}

         logging.debug(f"List Files URL: {url}")
         logging.debug(f"Headers: {headers}")
@@ -661,31 +591,32 @@ class Client:

             json_response = response.json()
             logging.info(f"File listing successful. Number of files: {len(json_response)}")
-
+
             return json_response

         except requests.RequestException as e:
             logging.error(f"File listing failed: {str(e)}")
-            if hasattr(e,
+            if hasattr(e, "response") and e.response is not None:
                 logging.error(f"Response status code: {e.response.status_code}")
                 logging.error(f"Response content: {e.response.text[:1000]}...")  # Log first 1000 characters
-            raise OpenGradientError(
-
+            raise OpenGradientError(
+                f"File listing failed: {str(e)}", status_code=e.response.status_code if hasattr(e, "response") else None
+            )
         except Exception as e:
             logging.error(f"Unexpected error during file listing: {str(e)}", exc_info=True)
             raise OpenGradientError(f"Unexpected error during file listing: {str(e)}")

     def generate_image(
-
-
-
-
-
-
-
-
-
-
+        self,
+        model_cid: str,
+        prompt: str,
+        host: str = DEFAULT_IMAGE_GEN_HOST,
+        port: int = DEFAULT_IMAGE_GEN_PORT,
+        width: int = 1024,
+        height: int = 1024,
+        timeout: int = 300,  # 5 minute timeout
+        max_retries: int = 3,
+    ) -> bytes:
         """
         Generate an image using a diffusion model through gRPC.

@@ -706,9 +637,10 @@ class Client:
             OpenGradientError: If the image generation fails
             TimeoutError: If the generation exceeds the timeout period
         """
+
         def exponential_backoff(attempt: int, max_delay: float = 30.0) -> None:
             """Calculate and sleep for exponential backoff duration"""
-            delay = min(0.1 * (2
+            delay = min(0.1 * (2**attempt), max_delay)
             time.sleep(delay)

         channel = None
@@ -719,28 +651,20 @@ class Client:
         while retry_count < max_retries:
             try:
                 # Initialize gRPC channel and stub
-                channel = grpc.insecure_channel(f
+                channel = grpc.insecure_channel(f"{host}:{port}")
                 stub = infer_pb2_grpc.InferenceServiceStub(channel)

                 # Create image generation request
-                image_request = infer_pb2.ImageGenerationRequest(
-                    model=model_cid,
-                    prompt=prompt,
-                    height=height,
-                    width=width
-                )
+                image_request = infer_pb2.ImageGenerationRequest(model=model_cid, prompt=prompt, height=height, width=width)

                 # Create inference request with random transaction ID
                 tx_id = str(uuid.uuid4())
-                request = infer_pb2.InferenceRequest(
-                    tx=tx_id,
-                    image_generation=image_request
-                )
+                request = infer_pb2.InferenceRequest(tx=tx_id, image_generation=image_request)

                 # Send request with timeout
                 response_id = stub.RunInferenceAsync(
                     request,
-                    timeout=min(30, timeout)  # Initial request timeout
+                    timeout=min(30, timeout),  # Initial request timeout
                 )

                 # Poll for completion
@@ -754,7 +678,7 @@ class Client:
                    try:
                        status = stub.GetInferenceStatus(
                            status_request,
-                            timeout=min(5, timeout)  # Status check timeout
+                            timeout=min(5, timeout),  # Status check timeout
                        ).status
                    except grpc.RpcError as e:
                        logging.warning(f"Status check failed (attempt {attempt}): {str(e)}")
@@ -775,7 +699,7 @@ class Client:
                 # Get result
                 result = stub.GetInferenceResult(
                     response_id,
-                    timeout=min(30, timeout)  # Result fetch timeout
+                    timeout=min(30, timeout),  # Result fetch timeout
                 )
                 return result.image_generation_result.image_data

@@ -783,7 +707,7 @@ class Client:
                 retry_count += 1
                 if retry_count >= max_retries:
                     raise OpenGradientError(f"Image generation failed after {max_retries} retries: {str(e)}")
-
+
                 logging.warning(f"Attempt {retry_count} failed: {str(e)}. Retrying...")
                 exponential_backoff(retry_count)

@@ -798,61 +722,62 @@ class Client:
             raise OpenGradientError(f"Image generation failed: {str(e)}")
         finally:
             if channel:
-                channel.close()
+                channel.close()

     def _get_model_executor_abi(self) -> List[Dict]:
         """
         Returns the ABI for the ModelExecutorHistorical contract.
         """
-        abi_path = Path(__file__).parent /
-        with open(abi_path,
+        abi_path = Path(__file__).parent / "abi" / "ModelExecutorHistorical.abi"
+        with open(abi_path, "r") as f:
             return json.load(f)

-
     def new_workflow(
         self,
         model_cid: str,
         input_query: Union[Dict[str, Any], HistoricalInputQuery],
         input_tensor_name: str,
-        scheduler_params: Optional[SchedulerParams] = None
+        scheduler_params: Optional[SchedulerParams] = None,
     ) -> str:
         """
         Deploy a new workflow contract with the specified parameters.
         """
         if isinstance(input_query, dict):
             input_query = HistoricalInputQuery.from_dict(input_query)
-
+
         # Get contract ABI and bytecode
         abi = self._get_model_executor_abi()
-        bin_path = Path(__file__).parent /
-
-        with open(bin_path,
+        bin_path = Path(__file__).parent / "contracts" / "templates" / "ModelExecutorHistorical.bin"
+
+        with open(bin_path, "r") as f:
             bytecode = f.read().strip()
-
+
         print("📦 Deploying workflow contract...")
-
+
         # Create contract instance
-        contract = self.
-
+        contract = self._blockchain.eth.contract(abi=abi, bytecode=bytecode)
+
         # Deploy contract with constructor arguments
         transaction = contract.constructor(
             model_cid,
             input_query.to_abi_format(),
             "0x00000000000000000000000000000000000000F5",  # Historical contract address
-            input_tensor_name
-        ).build_transaction(
-
-
-
-
-
-
-
-
-
-
+            input_tensor_name,
+        ).build_transaction(
+            {
+                "from": self._wallet_account.address,
+                "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
+                "gas": 15000000,
+                "gasPrice": self._blockchain.eth.gas_price,
+                "chainId": self._blockchain.eth.chain_id,
+            }
+        )
+
+        signed_txn = self._wallet_account.sign_transaction(transaction)
+        tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction)
+        tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash)
         contract_address = tx_receipt.contractAddress
-
+
         print(f"✅ Workflow contract deployed at: {contract_address}")

         # Register with scheduler if params provided
@@ -861,46 +786,45 @@ class Client:
             print(f" • Frequency: Every {scheduler_params.frequency} seconds")
             print(f" • Duration: {scheduler_params.duration_hours} hours")
             print(f" • End Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(scheduler_params.end_time))}")
-
-            scheduler_abi = [
-
-
-
-
-
-
-
-
-
+
+            scheduler_abi = [
+                {
+                    "inputs": [
+                        {"internalType": "address", "name": "contractAddress", "type": "address"},
+                        {"internalType": "uint256", "name": "endTime", "type": "uint256"},
+                        {"internalType": "uint256", "name": "frequency", "type": "uint256"},
+                    ],
+                    "name": "registerTask",
+                    "outputs": [],
+                    "stateMutability": "nonpayable",
+                    "type": "function",
+                }
+            ]

             scheduler_address = "0xE81a54399CFDf551bB917d0427464fE54537D245"
-            scheduler_contract = self.
-                address=scheduler_address,
-                abi=scheduler_abi
-            )
+            scheduler_contract = self._blockchain.eth.contract(address=scheduler_address, abi=scheduler_abi)

             try:
                 # Register the workflow with the scheduler
                 scheduler_tx = scheduler_contract.functions.registerTask(
-                    contract_address,
-
-
-
-
-
-
-
-
-
-
-                signed_scheduler_tx = self.
-                scheduler_tx_hash = self.
-                self.
-
+                    contract_address, scheduler_params.end_time, scheduler_params.frequency
+                ).build_transaction(
+                    {
+                        "from": self._wallet_account.address,
+                        "gas": 300000,
+                        "gasPrice": self._blockchain.eth.gas_price,
+                        "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
+                        "chainId": self._blockchain.eth.chain_id,
+                    }
+                )
+
+                signed_scheduler_tx = self._wallet_account.sign_transaction(scheduler_tx)
+                scheduler_tx_hash = self._blockchain.eth.send_raw_transaction(signed_scheduler_tx.raw_transaction)
+                self._blockchain.eth.wait_for_transaction_receipt(scheduler_tx_hash)
+
                 print("✅ Automated execution schedule set successfully!")
                 print(f" Transaction hash: {scheduler_tx_hash.hex()}")
-
+
             except Exception as e:
                 print("❌ Failed to set up automated execution schedule")
                 print(f" Error: {str(e)}")
@@ -911,26 +835,20 @@ class Client:
     def read_workflow_result(self, contract_address: str) -> Any:
         """
         Reads the latest inference result from a deployed workflow contract.
-
+
         Args:
             contract_address (str): Address of the deployed workflow contract
-
+
         Returns:
             Any: The inference result from the contract
-
+
         Raises:
             ContractLogicError: If the transaction fails
             Web3Error: If there are issues with the web3 connection or contract interaction
         """
-        if not self._w3:
-            self._initialize_web3()
-
         # Get the contract interface
-        contract = self.
-
-            abi=self._get_model_executor_abi()
-        )
-
+        contract = self._blockchain.eth.contract(address=Web3.to_checksum_address(contract_address), abi=self._get_model_executor_abi())
+
         # Get the result
         result = contract.functions.getInferenceResult().call()
         return result
@@ -938,45 +856,65 @@ class Client:
     def run_workflow(self, contract_address: str) -> ModelOutput:
         """
         Triggers the run() function on a deployed workflow contract and returns the result.
-
+
         Args:
             contract_address (str): Address of the deployed workflow contract
-
+
         Returns:
             ModelOutput: The inference result from the contract
-
+
         Raises:
             ContractLogicError: If the transaction fails
             Web3Error: If there are issues with the web3 connection or contract interaction
         """
-        if not self._w3:
-            self._initialize_web3()
-
         # Get the contract interface
-        contract = self.
-
-            abi=self._get_model_executor_abi()
-        )
-
+        contract = self._blockchain.eth.contract(address=Web3.to_checksum_address(contract_address), abi=self._get_model_executor_abi())
+
         # Call run() function
-        nonce = self.
-
+        nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending")
+
         run_function = contract.functions.run()
-        transaction = run_function.build_transaction(
-
-
-
-
-
-
-
-
-
-
+        transaction = run_function.build_transaction(
+            {
+                "from": self._wallet_account.address,
+                "nonce": nonce,
+                "gas": 30000000,
+                "gasPrice": self._blockchain.eth.gas_price,
+                "chainId": self._blockchain.eth.chain_id,
+            }
+        )
+
+        signed_txn = self._wallet_account.sign_transaction(transaction)
+        tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction)
+        tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash)
+
         if tx_receipt.status == 0:
             raise ContractLogicError(f"Run transaction failed. Receipt: {tx_receipt}")

         # Get the inference result from the contract
         result = contract.functions.getInferenceResult().call()
         return result
+
+
+def run_with_retry(txn_function, max_retries=5):
+    """
+    Execute a blockchain transaction with retry logic.
+
+    Args:
+        txn_function: Function that executes the transaction
+        max_retries (int): Maximum number of retry attempts
+    """
+    last_error = None
+    for attempt in range(max_retries):
+        try:
+            return txn_function()
+        except Exception as e:
+            last_error = e
+            if attempt < max_retries - 1:
+                if "nonce too low" in str(e) or "nonce too high" in str(e):
+                    time.sleep(1)  # Wait before retry
+                    continue
+            # If it's not a nonce error, raise immediately
+            raise
+    # If we've exhausted all retries, raise the last error
+    raise OpenGradientError(f"Transaction failed after {max_retries} attempts: {str(last_error)}")