opengradient 0.3.24__py3-none-any.whl → 0.3.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
opengradient/client.py CHANGED
@@ -1,76 +1,48 @@
1
- import asyncio
2
1
  import json
3
2
  import logging
4
3
  import os
4
+ import time
5
+ import uuid
5
6
  from pathlib import Path
6
- from typing import Dict, List, Optional, Tuple, Union, Any
7
+ from typing import Any, Dict, List, Optional, Tuple, Union
7
8
 
8
9
  import firebase
10
+ import grpc
9
11
  import numpy as np
10
12
  import requests
13
+ from eth_account.account import LocalAccount
11
14
  from web3 import Web3
12
15
  from web3.exceptions import ContractLogicError
13
16
  from web3.logs import DISCARD
14
17
 
15
- from opengradient import utils
16
- from opengradient.exceptions import OpenGradientError
17
- from opengradient.types import (
18
- HistoricalInputQuery,
19
- InferenceMode,
20
- LlmInferenceMode,
21
- LLM,
22
- TEE_LLM,
23
- ModelOutput,
24
- SchedulerParams
25
- )
26
-
27
- import grpc
28
- import time
29
- import uuid
30
- from google.protobuf import timestamp_pb2
31
-
32
- from opengradient.proto import infer_pb2
33
- from opengradient.proto import infer_pb2_grpc
18
+ from . import utils
19
+ from .exceptions import OpenGradientError
20
+ from .proto import infer_pb2, infer_pb2_grpc
21
+ from .types import LLM, TEE_LLM, HistoricalInputQuery, InferenceMode, LlmInferenceMode, ModelOutput, SchedulerParams
34
22
  from .defaults import DEFAULT_IMAGE_GEN_HOST, DEFAULT_IMAGE_GEN_PORT
35
23
 
36
- from functools import wraps
24
+ _FIREBASE_CONFIG = {
25
+ "apiKey": "AIzaSyDUVckVtfl-hiteBzPopy1pDD8Uvfncs7w",
26
+ "authDomain": "vanna-portal-418018.firebaseapp.com",
27
+ "projectId": "vanna-portal-418018",
28
+ "storageBucket": "vanna-portal-418018.appspot.com",
29
+ "appId": "1:487761246229:web:259af6423a504d2316361c",
30
+ "databaseURL": "",
31
+ }
37
32
 
38
- def run_with_retry(txn_function, max_retries=5):
39
- """
40
- Execute a blockchain transaction with retry logic.
41
-
42
- Args:
43
- txn_function: Function that executes the transaction
44
- max_retries (int): Maximum number of retry attempts
45
- """
46
- last_error = None
47
- for attempt in range(max_retries):
48
- try:
49
- return txn_function()
50
- except Exception as e:
51
- last_error = e
52
- if attempt < max_retries - 1:
53
- if "nonce too low" in str(e) or "nonce too high" in str(e):
54
- time.sleep(1) # Wait before retry
55
- continue
56
- # If it's not a nonce error, raise immediately
57
- raise
58
- # If we've exhausted all retries, raise the last error
59
- raise OpenGradientError(f"Transaction failed after {max_retries} attempts: {str(last_error)}")
60
33
 
61
34
  class Client:
62
- FIREBASE_CONFIG = {
63
- "apiKey": "AIzaSyDUVckVtfl-hiteBzPopy1pDD8Uvfncs7w",
64
- "authDomain": "vanna-portal-418018.firebaseapp.com",
65
- "projectId": "vanna-portal-418018",
66
- "storageBucket": "vanna-portal-418018.appspot.com",
67
- "appId": "1:487761246229:web:259af6423a504d2316361c",
68
- "databaseURL": ""
69
- }
70
-
35
+ _inference_hub_contract_address: str
36
+ _blockchain: Web3
37
+ _wallet_account: LocalAccount
38
+
39
+ _hub_user: Dict
40
+ _inference_abi: Dict
41
+
71
42
  def __init__(self, private_key: str, rpc_url: str, contract_address: str, email: str, password: str):
72
43
  """
73
44
  Initialize the Client with private key, RPC URL, and contract address.
45
+
74
46
  Args:
75
47
  private_key (str): The private key for the wallet.
76
48
  rpc_url (str): The RPC URL for the Ethereum node.
@@ -78,60 +50,27 @@ class Client:
78
50
  email (str, optional): Email for authentication. Defaults to "test@test.com".
79
51
  password (str, optional): Password for authentication. Defaults to "Test-123".
80
52
  """
81
- self.email = email
82
- self.password = password
83
- self.private_key = private_key
84
- self.rpc_url = rpc_url
85
- self.contract_address = contract_address
86
- self._w3 = Web3(Web3.HTTPProvider(self.rpc_url))
87
- self.wallet_account = self._w3.eth.account.from_key(private_key)
88
- self.wallet_address = self._w3.to_checksum_address(self.wallet_account.address)
89
-
90
- self.firebase_app = firebase.initialize_app(self.FIREBASE_CONFIG)
91
- self.auth = self.firebase_app.auth()
92
- self.user = None
93
-
94
- abi_path = Path(__file__).parent / 'abi' / 'inference.abi'
53
+ self._inference_hub_contract_address = contract_address
54
+ self._blockchain = Web3(Web3.HTTPProvider(rpc_url))
55
+ self._wallet_account = self._blockchain.eth.account.from_key(private_key)
95
56
 
96
- try:
97
- with open(abi_path, 'r') as abi_file:
98
- inference_abi = json.load(abi_file)
99
- except FileNotFoundError:
100
- raise
101
- except json.JSONDecodeError:
102
- raise
103
- except Exception as e:
104
- raise
105
-
106
- self.abi = inference_abi
57
+ abi_path = Path(__file__).parent / "abi" / "inference.abi"
58
+ with open(abi_path, "r") as abi_file:
59
+ self._inference_abi = json.load(abi_file)
107
60
 
108
61
  if email is not None:
109
- self.login(email, password)
62
+ self._hub_user = self._login_to_hub(email, password)
63
+ else:
64
+ self._hub_user = None
110
65
 
111
- def login(self, email, password):
66
+ def _login_to_hub(self, email, password):
112
67
  try:
113
- self.user = self.auth.sign_in_with_email_and_password(email, password)
114
- return self.user
68
+ firebase_app = firebase.initialize_app(_FIREBASE_CONFIG)
69
+ return firebase_app.auth().sign_in_with_email_and_password(email, password)
115
70
  except Exception as e:
116
71
  logging.error(f"Authentication failed: {str(e)}")
117
72
  raise
118
73
 
119
- def _initialize_web3(self):
120
- """
121
- Initialize the Web3 instance if it is not already initialized.
122
- """
123
- if self._w3 is None:
124
- self._w3 = Web3(Web3.HTTPProvider(self.rpc_url))
125
-
126
- def refresh_token(self) -> None:
127
- """
128
- Refresh the authentication token for the current user.
129
- """
130
- if self.user:
131
- self.user = self.auth.refresh(self.user['refreshToken'])
132
- else:
133
- logging.error("No user is currently signed in")
134
-
135
74
  def create_model(self, model_name: str, model_desc: str, version: str = "1.00") -> dict:
136
75
  """
137
76
  Create a new model with the given model_name and model_desc, and a specified version.
@@ -147,18 +86,12 @@ class Client:
147
86
  Raises:
148
87
  CreateModelError: If the model creation fails.
149
88
  """
150
- if not self.user:
89
+ if not self._hub_user:
151
90
  raise ValueError("User not authenticated")
152
91
 
153
92
  url = "https://api.opengradient.ai/api/v0/models/"
154
- headers = {
155
- 'Authorization': f'Bearer {self.user["idToken"]}',
156
- 'Content-Type': 'application/json'
157
- }
158
- payload = {
159
- 'name': model_name,
160
- 'description': model_desc
161
- }
93
+ headers = {"Authorization": f'Bearer {self._hub_user["idToken"]}', "Content-Type": "application/json"}
94
+ payload = {"name": model_name, "description": model_desc}
162
95
 
163
96
  try:
164
97
  logging.debug(f"Create Model URL: {url}")
@@ -169,7 +102,7 @@ class Client:
169
102
  response.raise_for_status()
170
103
 
171
104
  json_response = response.json()
172
- model_name = json_response.get('name')
105
+ model_name = json_response.get("name")
173
106
  if not model_name:
174
107
  raise Exception(f"Model creation response missing 'name'. Full response: {json_response}")
175
108
  logging.info(f"Model creation successful. Model name: {model_name}")
@@ -186,7 +119,7 @@ class Client:
186
119
 
187
120
  except requests.RequestException as e:
188
121
  logging.error(f"Model creation failed: {str(e)}")
189
- if hasattr(e, 'response') and e.response is not None:
122
+ if hasattr(e, "response") and e.response is not None:
190
123
  logging.error(f"Response status code: {e.response.status_code}")
191
124
  logging.error(f"Response headers: {e.response.headers}")
192
125
  logging.error(f"Response content: {e.response.text}")
@@ -210,18 +143,12 @@ class Client:
210
143
  Raises:
211
144
  Exception: If the version creation fails.
212
145
  """
213
- if not self.user:
146
+ if not self._hub_user:
214
147
  raise ValueError("User not authenticated")
215
148
 
216
149
  url = f"https://api.opengradient.ai/api/v0/models/{model_name}/versions"
217
- headers = {
218
- 'Authorization': f'Bearer {self.user["idToken"]}',
219
- 'Content-Type': 'application/json'
220
- }
221
- payload = {
222
- "notes": notes,
223
- "is_major": is_major
224
- }
150
+ headers = {"Authorization": f'Bearer {self._hub_user["idToken"]}', "Content-Type": "application/json"}
151
+ payload = {"notes": notes, "is_major": is_major}
225
152
 
226
153
  try:
227
154
  logging.debug(f"Create Version URL: {url}")
@@ -239,7 +166,7 @@ class Client:
239
166
  logging.info("Server returned an empty list. Assuming version was created successfully.")
240
167
  return {"versionString": "Unknown", "note": "Created based on empty response"}
241
168
  elif isinstance(json_response, dict):
242
- version_string = json_response.get('versionString')
169
+ version_string = json_response.get("versionString")
243
170
  if not version_string:
244
171
  logging.warning(f"'versionString' not found in response. Response: {json_response}")
245
172
  return {"versionString": "Unknown", "note": "Version ID not provided in response"}
@@ -251,7 +178,7 @@ class Client:
251
178
 
252
179
  except requests.RequestException as e:
253
180
  logging.error(f"Version creation failed: {str(e)}")
254
- if hasattr(e, 'response') and e.response is not None:
181
+ if hasattr(e, "response") and e.response is not None:
255
182
  logging.error(f"Response status code: {e.response.status_code}")
256
183
  logging.error(f"Response headers: {e.response.headers}")
257
184
  logging.error(f"Response content: {e.response.text}")
@@ -277,16 +204,14 @@ class Client:
277
204
  """
278
205
  from requests_toolbelt import MultipartEncoder, MultipartEncoderMonitor
279
206
 
280
- if not self.user:
207
+ if not self._hub_user:
281
208
  raise ValueError("User not authenticated")
282
209
 
283
210
  if not os.path.exists(model_path):
284
211
  raise FileNotFoundError(f"Model file not found: {model_path}")
285
212
 
286
213
  url = f"https://api.opengradient.ai/api/v0/models/{model_name}/versions/{version}/files"
287
- headers = {
288
- 'Authorization': f'Bearer {self.user["idToken"]}'
289
- }
214
+ headers = {"Authorization": f'Bearer {self._hub_user["idToken"]}'}
290
215
 
291
216
  logging.info(f"Starting upload for file: {model_path}")
292
217
  logging.info(f"File size: {os.path.getsize(model_path)} bytes")
@@ -295,27 +220,27 @@ class Client:
295
220
 
296
221
  def create_callback(encoder):
297
222
  encoder_len = encoder.len
223
+
298
224
  def callback(monitor):
299
225
  progress = (monitor.bytes_read / encoder_len) * 100
300
226
  logging.info(f"Upload progress: {progress:.2f}%")
227
+
301
228
  return callback
302
229
 
303
230
  try:
304
- with open(model_path, 'rb') as file:
305
- encoder = MultipartEncoder(
306
- fields={'file': (os.path.basename(model_path), file, 'application/octet-stream')}
307
- )
231
+ with open(model_path, "rb") as file:
232
+ encoder = MultipartEncoder(fields={"file": (os.path.basename(model_path), file, "application/octet-stream")})
308
233
  monitor = MultipartEncoderMonitor(encoder, create_callback(encoder))
309
- headers['Content-Type'] = monitor.content_type
234
+ headers["Content-Type"] = monitor.content_type
310
235
 
311
236
  logging.info("Sending POST request...")
312
237
  response = requests.post(url, data=monitor, headers=headers, timeout=3600) # 1 hour timeout
313
-
238
+
314
239
  logging.info(f"Response received. Status code: {response.status_code}")
315
240
  logging.info(f"Full response content: {response.text}") # Log the full response content
316
241
 
317
242
  if response.status_code == 201:
318
- if response.content and response.content != b'null':
243
+ if response.content and response.content != b"null":
319
244
  json_response = response.json()
320
245
  logging.info(f"JSON response: {json_response}") # Log the parsed JSON response
321
246
  logging.info(f"Upload successful. CID: {json_response.get('ipfsCid', 'N/A')}")
@@ -328,7 +253,7 @@ class Client:
328
253
  logging.error(error_message)
329
254
  raise OpenGradientError(error_message, status_code=500)
330
255
  else:
331
- error_message = response.json().get('detail', 'Unknown error occurred')
256
+ error_message = response.json().get("detail", "Unknown error occurred")
332
257
  logging.error(f"Upload failed with status code {response.status_code}: {error_message}")
333
258
  raise OpenGradientError(f"Upload failed: {error_message}", status_code=response.status_code)
334
259
 
@@ -336,22 +261,23 @@ class Client:
336
261
 
337
262
  except requests.RequestException as e:
338
263
  logging.error(f"Request exception during upload: {str(e)}")
339
- if hasattr(e, 'response') and e.response is not None:
264
+ if hasattr(e, "response") and e.response is not None:
340
265
  logging.error(f"Response status code: {e.response.status_code}")
341
266
  logging.error(f"Response content: {e.response.text[:1000]}...") # Log first 1000 characters
342
- raise OpenGradientError(f"Upload failed due to request exception: {str(e)}",
343
- status_code=e.response.status_code if hasattr(e, 'response') else None)
267
+ raise OpenGradientError(
268
+ f"Upload failed due to request exception: {str(e)}", status_code=e.response.status_code if hasattr(e, "response") else None
269
+ )
344
270
  except Exception as e:
345
271
  logging.error(f"Unexpected error during upload: {str(e)}", exc_info=True)
346
272
  raise OpenGradientError(f"Unexpected error during upload: {str(e)}")
347
-
273
+
348
274
  def infer(
349
- self,
350
- model_cid: str,
351
- inference_mode: InferenceMode,
352
- model_input: Dict[str, Union[str, int, float, List, np.ndarray]],
353
- max_retries: Optional[int] = None
354
- ) -> Tuple[str, Dict[str, np.ndarray]]:
275
+ self,
276
+ model_cid: str,
277
+ inference_mode: InferenceMode,
278
+ model_input: Dict[str, Union[str, int, float, List, np.ndarray]],
279
+ max_retries: Optional[int] = None,
280
+ ) -> Tuple[str, Dict[str, np.ndarray]]:
355
281
  """
356
282
  Perform inference on a model.
357
283
 
@@ -367,54 +293,54 @@ class Client:
367
293
  Raises:
368
294
  OpenGradientError: If the inference fails.
369
295
  """
296
+
370
297
  def execute_transaction():
371
- self._initialize_web3()
372
- contract = self._w3.eth.contract(address=self.contract_address, abi=self.abi)
373
-
298
+ contract = self._blockchain.eth.contract(address=self._inference_hub_contract_address, abi=self._inference_abi)
299
+
374
300
  inference_mode_uint8 = int(inference_mode)
375
301
  converted_model_input = utils.convert_to_model_input(model_input)
376
-
377
- run_function = contract.functions.run(
378
- model_cid,
379
- inference_mode_uint8,
380
- converted_model_input
381
- )
382
302
 
383
- nonce = self._w3.eth.get_transaction_count(self.wallet_address, 'pending')
384
- estimated_gas = run_function.estimate_gas({'from': self.wallet_address})
303
+ run_function = contract.functions.run(model_cid, inference_mode_uint8, converted_model_input)
304
+
305
+ nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending")
306
+ estimated_gas = run_function.estimate_gas({"from": self._wallet_account.address})
385
307
  gas_limit = int(estimated_gas * 3)
386
308
 
387
- transaction = run_function.build_transaction({
388
- 'from': self.wallet_address,
389
- 'nonce': nonce,
390
- 'gas': gas_limit,
391
- 'gasPrice': self._w3.eth.gas_price,
392
- })
309
+ transaction = run_function.build_transaction(
310
+ {
311
+ "from": self._wallet_account.address,
312
+ "nonce": nonce,
313
+ "gas": gas_limit,
314
+ "gasPrice": self._blockchain.eth.gas_price,
315
+ }
316
+ )
393
317
 
394
- signed_tx = self._w3.eth.account.sign_transaction(transaction, self.private_key)
395
- tx_hash = self._w3.eth.send_raw_transaction(signed_tx.raw_transaction)
396
- tx_receipt = self._w3.eth.wait_for_transaction_receipt(tx_hash)
318
+ signed_tx = self._wallet_account.sign_transaction(transaction)
319
+ tx_hash = self._blockchain.eth.send_raw_transaction(signed_tx.raw_transaction)
320
+ tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash)
397
321
 
398
- if tx_receipt['status'] == 0:
322
+ if tx_receipt["status"] == 0:
399
323
  raise ContractLogicError(f"Transaction failed. Receipt: {tx_receipt}")
400
324
 
401
325
  parsed_logs = contract.events.InferenceResult().process_receipt(tx_receipt, errors=DISCARD)
402
326
  if len(parsed_logs) < 1:
403
327
  raise OpenGradientError("InferenceResult event not found in transaction logs")
404
328
 
405
- model_output = utils.convert_to_model_output(parsed_logs[0]['args'])
329
+ model_output = utils.convert_to_model_output(parsed_logs[0]["args"])
406
330
  return tx_hash.hex(), model_output
407
331
 
408
332
  return run_with_retry(execute_transaction, max_retries or 5)
409
333
 
410
- def llm_completion(self,
411
- model_cid: LLM,
412
- inference_mode: InferenceMode,
413
- prompt: str,
414
- max_tokens: int = 100,
415
- stop_sequence: Optional[List[str]] = None,
416
- temperature: float = 0.0,
417
- max_retries: Optional[int] = None) -> Tuple[str, str]:
334
+ def llm_completion(
335
+ self,
336
+ model_cid: LLM,
337
+ inference_mode: InferenceMode,
338
+ prompt: str,
339
+ max_tokens: int = 100,
340
+ stop_sequence: Optional[List[str]] = None,
341
+ temperature: float = 0.0,
342
+ max_retries: Optional[int] = None,
343
+ ) -> Tuple[str, str]:
418
344
  """
419
345
  Perform inference on an LLM model using completions.
420
346
 
@@ -432,16 +358,16 @@ class Client:
432
358
  Raises:
433
359
  OpenGradientError: If the inference fails.
434
360
  """
361
+
435
362
  def execute_transaction():
436
363
  # Check inference mode and supported model
437
364
  if inference_mode != LlmInferenceMode.VANILLA and inference_mode != LlmInferenceMode.TEE:
438
365
  raise OpenGradientError("Invalid inference mode %s: Inference mode must be VANILLA or TEE" % inference_mode)
439
-
366
+
440
367
  if inference_mode == LlmInferenceMode.TEE and model_cid not in TEE_LLM:
441
368
  raise OpenGradientError("That model CID is not supported yet supported for TEE inference")
442
369
 
443
- self._initialize_web3()
444
- contract = self._w3.eth.contract(address=self.contract_address, abi=self.abi)
370
+ contract = self._blockchain.eth.contract(address=self._inference_hub_contract_address, abi=self._inference_abi)
445
371
 
446
372
  # Prepare LLM input
447
373
  llm_request = {
@@ -450,56 +376,60 @@ class Client:
450
376
  "prompt": prompt,
451
377
  "max_tokens": max_tokens,
452
378
  "stop_sequence": stop_sequence or [],
453
- "temperature": int(temperature * 100) # Scale to 0-100 range
379
+ "temperature": int(temperature * 100), # Scale to 0-100 range
454
380
  }
455
381
  logging.debug(f"Prepared LLM request: {llm_request}")
456
382
 
457
383
  run_function = contract.functions.runLLMCompletion(llm_request)
458
384
 
459
- nonce = self._w3.eth.get_transaction_count(self.wallet_address, 'pending')
460
- estimated_gas = run_function.estimate_gas({'from': self.wallet_address})
385
+ nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending")
386
+ estimated_gas = run_function.estimate_gas({"from": self._wallet_account.address})
461
387
  gas_limit = int(estimated_gas * 1.2)
462
388
 
463
- transaction = run_function.build_transaction({
464
- 'from': self.wallet_address,
465
- 'nonce': nonce,
466
- 'gas': gas_limit,
467
- 'gasPrice': self._w3.eth.gas_price,
468
- })
389
+ transaction = run_function.build_transaction(
390
+ {
391
+ "from": self._wallet_account.address,
392
+ "nonce": nonce,
393
+ "gas": gas_limit,
394
+ "gasPrice": self._blockchain.eth.gas_price,
395
+ }
396
+ )
469
397
 
470
- signed_tx = self._w3.eth.account.sign_transaction(transaction, self.private_key)
471
- tx_hash = self._w3.eth.send_raw_transaction(signed_tx.raw_transaction)
472
- tx_receipt = self._w3.eth.wait_for_transaction_receipt(tx_hash)
398
+ signed_tx = self._wallet_account.sign_transaction(transaction)
399
+ tx_hash = self._blockchain.eth.send_raw_transaction(signed_tx.raw_transaction)
400
+ tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash)
473
401
 
474
- if tx_receipt['status'] == 0:
402
+ if tx_receipt["status"] == 0:
475
403
  raise ContractLogicError(f"Transaction failed. Receipt: {tx_receipt}")
476
404
 
477
405
  parsed_logs = contract.events.LLMCompletionResult().process_receipt(tx_receipt, errors=DISCARD)
478
406
  if len(parsed_logs) < 1:
479
407
  raise OpenGradientError("LLM completion result event not found in transaction logs")
480
408
 
481
- llm_answer = parsed_logs[0]['args']['response']['answer']
409
+ llm_answer = parsed_logs[0]["args"]["response"]["answer"]
482
410
  return tx_hash.hex(), llm_answer
483
411
 
484
412
  return run_with_retry(execute_transaction, max_retries or 5)
485
413
 
486
- def llm_chat(self,
487
- model_cid: str,
488
- inference_mode: InferenceMode,
489
- messages: List[Dict],
490
- max_tokens: int = 100,
491
- stop_sequence: Optional[List[str]] = None,
492
- temperature: float = 0.0,
493
- tools: Optional[List[Dict]] = [],
494
- tool_choice: Optional[str] = None,
495
- max_retries: Optional[int] = None) -> Tuple[str, str]:
414
+ def llm_chat(
415
+ self,
416
+ model_cid: LLM,
417
+ inference_mode: InferenceMode,
418
+ messages: List[Dict],
419
+ max_tokens: int = 100,
420
+ stop_sequence: Optional[List[str]] = None,
421
+ temperature: float = 0.0,
422
+ tools: Optional[List[Dict]] = [],
423
+ tool_choice: Optional[str] = None,
424
+ max_retries: Optional[int] = None,
425
+ ) -> Tuple[str, str]:
496
426
  """
497
427
  Perform inference on an LLM model using chat.
498
428
 
499
429
  Args:
500
430
  model_cid (LLM): The unique content identifier for the model.
501
431
  inference_mode (InferenceMode): The inference mode.
502
- messages (dict): The messages that will be passed into the chat.
432
+ messages (dict): The messages that will be passed into the chat.
503
433
  This should be in OpenAI API format (https://platform.openai.com/docs/api-reference/chat/create)
504
434
  Example:
505
435
  [
@@ -541,7 +471,7 @@ class Client:
541
471
  }
542
472
  }
543
473
  ]
544
- tool_choice (str, optional): Sets a specific tool to choose. Default value is "auto".
474
+ tool_choice (str, optional): Sets a specific tool to choose. Default value is "auto".
545
475
 
546
476
  Returns:
547
477
  Tuple[str, str, dict]: The transaction hash, finish reason, and a dictionary struct of LLM chat messages.
@@ -549,37 +479,37 @@ class Client:
549
479
  Raises:
550
480
  OpenGradientError: If the inference fails.
551
481
  """
482
+
552
483
  def execute_transaction():
553
484
  # Check inference mode and supported model
554
485
  if inference_mode != LlmInferenceMode.VANILLA and inference_mode != LlmInferenceMode.TEE:
555
486
  raise OpenGradientError("Invalid inference mode %s: Inference mode must be VANILLA or TEE" % inference_mode)
556
-
487
+
557
488
  if inference_mode == LlmInferenceMode.TEE and model_cid not in TEE_LLM:
558
489
  raise OpenGradientError("That model CID is not supported yet supported for TEE inference")
559
-
560
- self._initialize_web3()
561
- contract = self._w3.eth.contract(address=self.contract_address, abi=self.abi)
490
+
491
+ contract = self._blockchain.eth.contract(address=self._inference_hub_contract_address, abi=self._inference_abi)
562
492
 
563
493
  # For incoming chat messages, tool_calls can be empty. Add an empty array so that it will fit the ABI.
564
494
  for message in messages:
565
- if 'tool_calls' not in message:
566
- message['tool_calls'] = []
567
- if 'tool_call_id' not in message:
568
- message['tool_call_id'] = ""
569
- if 'name' not in message:
570
- message['name'] = ""
495
+ if "tool_calls" not in message:
496
+ message["tool_calls"] = []
497
+ if "tool_call_id" not in message:
498
+ message["tool_call_id"] = ""
499
+ if "name" not in message:
500
+ message["name"] = ""
571
501
 
572
502
  # Create simplified tool structure for smart contract
573
503
  converted_tools = []
574
504
  if tools is not None:
575
505
  for tool in tools:
576
- function = tool['function']
506
+ function = tool["function"]
577
507
  converted_tool = {}
578
- converted_tool['name'] = function['name']
579
- converted_tool['description'] = function['description']
580
- if (parameters := function.get('parameters')) is not None:
508
+ converted_tool["name"] = function["name"]
509
+ converted_tool["description"] = function["description"]
510
+ if (parameters := function.get("parameters")) is not None:
581
511
  try:
582
- converted_tool['parameters'] = json.dumps(parameters)
512
+ converted_tool["parameters"] = json.dumps(parameters)
583
513
  except Exception as e:
584
514
  raise OpenGradientError("Chat LLM failed to convert parameters into JSON: %s", e)
585
515
  converted_tools.append(converted_tool)
@@ -593,40 +523,42 @@ class Client:
593
523
  "stop_sequence": stop_sequence or [],
594
524
  "temperature": int(temperature * 100), # Scale to 0-100 range
595
525
  "tools": converted_tools or [],
596
- "tool_choice": tool_choice if tool_choice else ("" if tools is None else "auto")
526
+ "tool_choice": tool_choice if tool_choice else ("" if tools is None else "auto"),
597
527
  }
598
528
  logging.debug(f"Prepared LLM request: {llm_request}")
599
529
 
600
530
  run_function = contract.functions.runLLMChat(llm_request)
601
531
 
602
- nonce = self._w3.eth.get_transaction_count(self.wallet_address, 'pending')
603
- estimated_gas = run_function.estimate_gas({'from': self.wallet_address})
532
+ nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending")
533
+ estimated_gas = run_function.estimate_gas({"from": self._wallet_account.address})
604
534
  gas_limit = int(estimated_gas * 1.2)
605
535
 
606
- transaction = run_function.build_transaction({
607
- 'from': self.wallet_address,
608
- 'nonce': nonce,
609
- 'gas': gas_limit,
610
- 'gasPrice': self._w3.eth.gas_price,
611
- })
536
+ transaction = run_function.build_transaction(
537
+ {
538
+ "from": self._wallet_account.address,
539
+ "nonce": nonce,
540
+ "gas": gas_limit,
541
+ "gasPrice": self._blockchain.eth.gas_price,
542
+ }
543
+ )
612
544
 
613
- signed_tx = self._w3.eth.account.sign_transaction(transaction, self.private_key)
614
- tx_hash = self._w3.eth.send_raw_transaction(signed_tx.raw_transaction)
615
- tx_receipt = self._w3.eth.wait_for_transaction_receipt(tx_hash)
545
+ signed_tx = self._wallet_account.sign_transaction(transaction)
546
+ tx_hash = self._blockchain.eth.send_raw_transaction(signed_tx.raw_transaction)
547
+ tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash)
616
548
 
617
- if tx_receipt['status'] == 0:
549
+ if tx_receipt["status"] == 0:
618
550
  raise ContractLogicError(f"Transaction failed. Receipt: {tx_receipt}")
619
551
 
620
552
  parsed_logs = contract.events.LLMChatResult().process_receipt(tx_receipt, errors=DISCARD)
621
553
  if len(parsed_logs) < 1:
622
554
  raise OpenGradientError("LLM chat result event not found in transaction logs")
623
555
 
624
- llm_result = parsed_logs[0]['args']['response']
625
- message = dict(llm_result['message'])
626
- if (tool_calls := message.get('tool_calls')) is not None:
627
- message['tool_calls'] = [dict(tool_call) for tool_call in tool_calls]
556
+ llm_result = parsed_logs[0]["args"]["response"]
557
+ message = dict(llm_result["message"])
558
+ if (tool_calls := message.get("tool_calls")) is not None:
559
+ message["tool_calls"] = [dict(tool_call) for tool_call in tool_calls]
628
560
 
629
- return tx_hash.hex(), llm_result['finish_reason'], message
561
+ return tx_hash.hex(), llm_result["finish_reason"], message
630
562
 
631
563
  return run_with_retry(execute_transaction, max_retries or 5)
632
564
 
@@ -644,13 +576,11 @@ class Client:
644
576
  Raises:
645
577
  OpenGradientError: If the file listing fails.
646
578
  """
647
- if not self.user:
579
+ if not self._hub_user:
648
580
  raise ValueError("User not authenticated")
649
581
 
650
582
  url = f"https://api.opengradient.ai/api/v0/models/{model_name}/versions/{version}/files"
651
- headers = {
652
- 'Authorization': f'Bearer {self.user["idToken"]}'
653
- }
583
+ headers = {"Authorization": f'Bearer {self._hub_user["idToken"]}'}
654
584
 
655
585
  logging.debug(f"List Files URL: {url}")
656
586
  logging.debug(f"Headers: {headers}")
@@ -661,31 +591,32 @@ class Client:
661
591
 
662
592
  json_response = response.json()
663
593
  logging.info(f"File listing successful. Number of files: {len(json_response)}")
664
-
594
+
665
595
  return json_response
666
596
 
667
597
  except requests.RequestException as e:
668
598
  logging.error(f"File listing failed: {str(e)}")
669
- if hasattr(e, 'response') and e.response is not None:
599
+ if hasattr(e, "response") and e.response is not None:
670
600
  logging.error(f"Response status code: {e.response.status_code}")
671
601
  logging.error(f"Response content: {e.response.text[:1000]}...") # Log first 1000 characters
672
- raise OpenGradientError(f"File listing failed: {str(e)}",
673
- status_code=e.response.status_code if hasattr(e, 'response') else None)
602
+ raise OpenGradientError(
603
+ f"File listing failed: {str(e)}", status_code=e.response.status_code if hasattr(e, "response") else None
604
+ )
674
605
  except Exception as e:
675
606
  logging.error(f"Unexpected error during file listing: {str(e)}", exc_info=True)
676
607
  raise OpenGradientError(f"Unexpected error during file listing: {str(e)}")
677
608
 
678
609
  def generate_image(
679
- self,
680
- model_cid: str,
681
- prompt: str,
682
- host: str = DEFAULT_IMAGE_GEN_HOST,
683
- port: int = DEFAULT_IMAGE_GEN_PORT,
684
- width: int = 1024,
685
- height: int = 1024,
686
- timeout: int = 300, # 5 minute timeout
687
- max_retries: int = 3
688
- ) -> bytes:
610
+ self,
611
+ model_cid: str,
612
+ prompt: str,
613
+ host: str = DEFAULT_IMAGE_GEN_HOST,
614
+ port: int = DEFAULT_IMAGE_GEN_PORT,
615
+ width: int = 1024,
616
+ height: int = 1024,
617
+ timeout: int = 300, # 5 minute timeout
618
+ max_retries: int = 3,
619
+ ) -> bytes:
689
620
  """
690
621
  Generate an image using a diffusion model through gRPC.
691
622
 
@@ -706,9 +637,10 @@ class Client:
706
637
  OpenGradientError: If the image generation fails
707
638
  TimeoutError: If the generation exceeds the timeout period
708
639
  """
640
+
709
641
  def exponential_backoff(attempt: int, max_delay: float = 30.0) -> None:
710
642
  """Calculate and sleep for exponential backoff duration"""
711
- delay = min(0.1 * (2 ** attempt), max_delay)
643
+ delay = min(0.1 * (2**attempt), max_delay)
712
644
  time.sleep(delay)
713
645
 
714
646
  channel = None
@@ -719,28 +651,20 @@ class Client:
719
651
  while retry_count < max_retries:
720
652
  try:
721
653
  # Initialize gRPC channel and stub
722
- channel = grpc.insecure_channel(f'{host}:{port}')
654
+ channel = grpc.insecure_channel(f"{host}:{port}")
723
655
  stub = infer_pb2_grpc.InferenceServiceStub(channel)
724
656
 
725
657
  # Create image generation request
726
- image_request = infer_pb2.ImageGenerationRequest(
727
- model=model_cid,
728
- prompt=prompt,
729
- height=height,
730
- width=width
731
- )
658
+ image_request = infer_pb2.ImageGenerationRequest(model=model_cid, prompt=prompt, height=height, width=width)
732
659
 
733
660
  # Create inference request with random transaction ID
734
661
  tx_id = str(uuid.uuid4())
735
- request = infer_pb2.InferenceRequest(
736
- tx=tx_id,
737
- image_generation=image_request
738
- )
662
+ request = infer_pb2.InferenceRequest(tx=tx_id, image_generation=image_request)
739
663
 
740
664
  # Send request with timeout
741
665
  response_id = stub.RunInferenceAsync(
742
666
  request,
743
- timeout=min(30, timeout) # Initial request timeout
667
+ timeout=min(30, timeout), # Initial request timeout
744
668
  )
745
669
 
746
670
  # Poll for completion
@@ -754,7 +678,7 @@ class Client:
754
678
  try:
755
679
  status = stub.GetInferenceStatus(
756
680
  status_request,
757
- timeout=min(5, timeout) # Status check timeout
681
+ timeout=min(5, timeout), # Status check timeout
758
682
  ).status
759
683
  except grpc.RpcError as e:
760
684
  logging.warning(f"Status check failed (attempt {attempt}): {str(e)}")
@@ -775,7 +699,7 @@ class Client:
775
699
  # Get result
776
700
  result = stub.GetInferenceResult(
777
701
  response_id,
778
- timeout=min(30, timeout) # Result fetch timeout
702
+ timeout=min(30, timeout), # Result fetch timeout
779
703
  )
780
704
  return result.image_generation_result.image_data
781
705
 
@@ -783,7 +707,7 @@ class Client:
783
707
  retry_count += 1
784
708
  if retry_count >= max_retries:
785
709
  raise OpenGradientError(f"Image generation failed after {max_retries} retries: {str(e)}")
786
-
710
+
787
711
  logging.warning(f"Attempt {retry_count} failed: {str(e)}. Retrying...")
788
712
  exponential_backoff(retry_count)
789
713
 
@@ -798,61 +722,62 @@ class Client:
798
722
  raise OpenGradientError(f"Image generation failed: {str(e)}")
799
723
  finally:
800
724
  if channel:
801
- channel.close()
725
+ channel.close()
802
726
 
803
727
  def _get_model_executor_abi(self) -> List[Dict]:
804
728
  """
805
729
  Returns the ABI for the ModelExecutorHistorical contract.
806
730
  """
807
- abi_path = Path(__file__).parent / 'abi' / 'ModelExecutorHistorical.abi'
808
- with open(abi_path, 'r') as f:
731
+ abi_path = Path(__file__).parent / "abi" / "ModelExecutorHistorical.abi"
732
+ with open(abi_path, "r") as f:
809
733
  return json.load(f)
810
734
 
811
-
812
735
  def new_workflow(
813
736
  self,
814
737
  model_cid: str,
815
738
  input_query: Union[Dict[str, Any], HistoricalInputQuery],
816
739
  input_tensor_name: str,
817
- scheduler_params: Optional[SchedulerParams] = None
740
+ scheduler_params: Optional[SchedulerParams] = None,
818
741
  ) -> str:
819
742
  """
820
743
  Deploy a new workflow contract with the specified parameters.
821
744
  """
822
745
  if isinstance(input_query, dict):
823
746
  input_query = HistoricalInputQuery.from_dict(input_query)
824
-
747
+
825
748
  # Get contract ABI and bytecode
826
749
  abi = self._get_model_executor_abi()
827
- bin_path = Path(__file__).parent / 'contracts' / 'templates' / 'ModelExecutorHistorical.bin'
828
-
829
- with open(bin_path, 'r') as f:
750
+ bin_path = Path(__file__).parent / "contracts" / "templates" / "ModelExecutorHistorical.bin"
751
+
752
+ with open(bin_path, "r") as f:
830
753
  bytecode = f.read().strip()
831
-
754
+
832
755
  print("📦 Deploying workflow contract...")
833
-
756
+
834
757
  # Create contract instance
835
- contract = self._w3.eth.contract(abi=abi, bytecode=bytecode)
836
-
758
+ contract = self._blockchain.eth.contract(abi=abi, bytecode=bytecode)
759
+
837
760
  # Deploy contract with constructor arguments
838
761
  transaction = contract.constructor(
839
762
  model_cid,
840
763
  input_query.to_abi_format(),
841
764
  "0x00000000000000000000000000000000000000F5", # Historical contract address
842
- input_tensor_name
843
- ).build_transaction({
844
- 'from': self.wallet_address,
845
- 'nonce': self._w3.eth.get_transaction_count(self.wallet_address, 'pending'),
846
- 'gas': 15000000,
847
- 'gasPrice': self._w3.eth.gas_price,
848
- 'chainId': self._w3.eth.chain_id
849
- })
850
-
851
- signed_txn = self._w3.eth.account.sign_transaction(transaction, self.private_key)
852
- tx_hash = self._w3.eth.send_raw_transaction(signed_txn.raw_transaction)
853
- tx_receipt = self._w3.eth.wait_for_transaction_receipt(tx_hash)
765
+ input_tensor_name,
766
+ ).build_transaction(
767
+ {
768
+ "from": self._wallet_account.address,
769
+ "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
770
+ "gas": 15000000,
771
+ "gasPrice": self._blockchain.eth.gas_price,
772
+ "chainId": self._blockchain.eth.chain_id,
773
+ }
774
+ )
775
+
776
+ signed_txn = self._wallet_account.sign_transaction(transaction)
777
+ tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction)
778
+ tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash)
854
779
  contract_address = tx_receipt.contractAddress
855
-
780
+
856
781
  print(f"✅ Workflow contract deployed at: {contract_address}")
857
782
 
858
783
  # Register with scheduler if params provided
@@ -861,46 +786,45 @@ class Client:
861
786
  print(f" • Frequency: Every {scheduler_params.frequency} seconds")
862
787
  print(f" • Duration: {scheduler_params.duration_hours} hours")
863
788
  print(f" • End Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(scheduler_params.end_time))}")
864
-
865
- scheduler_abi = [{
866
- "inputs": [
867
- {"internalType": "address", "name": "contractAddress", "type": "address"},
868
- {"internalType": "uint256", "name": "endTime", "type": "uint256"},
869
- {"internalType": "uint256", "name": "frequency", "type": "uint256"}
870
- ],
871
- "name": "registerTask",
872
- "outputs": [],
873
- "stateMutability": "nonpayable",
874
- "type": "function"
875
- }]
789
+
790
+ scheduler_abi = [
791
+ {
792
+ "inputs": [
793
+ {"internalType": "address", "name": "contractAddress", "type": "address"},
794
+ {"internalType": "uint256", "name": "endTime", "type": "uint256"},
795
+ {"internalType": "uint256", "name": "frequency", "type": "uint256"},
796
+ ],
797
+ "name": "registerTask",
798
+ "outputs": [],
799
+ "stateMutability": "nonpayable",
800
+ "type": "function",
801
+ }
802
+ ]
876
803
 
877
804
  scheduler_address = "0xE81a54399CFDf551bB917d0427464fE54537D245"
878
- scheduler_contract = self._w3.eth.contract(
879
- address=scheduler_address,
880
- abi=scheduler_abi
881
- )
805
+ scheduler_contract = self._blockchain.eth.contract(address=scheduler_address, abi=scheduler_abi)
882
806
 
883
807
  try:
884
808
  # Register the workflow with the scheduler
885
809
  scheduler_tx = scheduler_contract.functions.registerTask(
886
- contract_address,
887
- scheduler_params.end_time,
888
- scheduler_params.frequency
889
- ).build_transaction({
890
- 'from': self.wallet_address,
891
- 'gas': 300000,
892
- 'gasPrice': self._w3.eth.gas_price,
893
- 'nonce': self._w3.eth.get_transaction_count(self.wallet_address, 'pending'),
894
- 'chainId': self._w3.eth.chain_id
895
- })
896
-
897
- signed_scheduler_tx = self._w3.eth.account.sign_transaction(scheduler_tx, self.private_key)
898
- scheduler_tx_hash = self._w3.eth.send_raw_transaction(signed_scheduler_tx.raw_transaction)
899
- self._w3.eth.wait_for_transaction_receipt(scheduler_tx_hash)
900
-
810
+ contract_address, scheduler_params.end_time, scheduler_params.frequency
811
+ ).build_transaction(
812
+ {
813
+ "from": self._wallet_account.address,
814
+ "gas": 300000,
815
+ "gasPrice": self._blockchain.eth.gas_price,
816
+ "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
817
+ "chainId": self._blockchain.eth.chain_id,
818
+ }
819
+ )
820
+
821
+ signed_scheduler_tx = self._wallet_account.sign_transaction(scheduler_tx)
822
+ scheduler_tx_hash = self._blockchain.eth.send_raw_transaction(signed_scheduler_tx.raw_transaction)
823
+ self._blockchain.eth.wait_for_transaction_receipt(scheduler_tx_hash)
824
+
901
825
  print("✅ Automated execution schedule set successfully!")
902
826
  print(f" Transaction hash: {scheduler_tx_hash.hex()}")
903
-
827
+
904
828
  except Exception as e:
905
829
  print("❌ Failed to set up automated execution schedule")
906
830
  print(f" Error: {str(e)}")
@@ -911,26 +835,20 @@ class Client:
911
835
  def read_workflow_result(self, contract_address: str) -> Any:
912
836
  """
913
837
  Reads the latest inference result from a deployed workflow contract.
914
-
838
+
915
839
  Args:
916
840
  contract_address (str): Address of the deployed workflow contract
917
-
841
+
918
842
  Returns:
919
843
  Any: The inference result from the contract
920
-
844
+
921
845
  Raises:
922
846
  ContractLogicError: If the transaction fails
923
847
  Web3Error: If there are issues with the web3 connection or contract interaction
924
848
  """
925
- if not self._w3:
926
- self._initialize_web3()
927
-
928
849
  # Get the contract interface
929
- contract = self._w3.eth.contract(
930
- address=Web3.to_checksum_address(contract_address),
931
- abi=self._get_model_executor_abi()
932
- )
933
-
850
+ contract = self._blockchain.eth.contract(address=Web3.to_checksum_address(contract_address), abi=self._get_model_executor_abi())
851
+
934
852
  # Get the result
935
853
  result = contract.functions.getInferenceResult().call()
936
854
  return result
@@ -938,45 +856,65 @@ class Client:
938
856
  def run_workflow(self, contract_address: str) -> ModelOutput:
939
857
  """
940
858
  Triggers the run() function on a deployed workflow contract and returns the result.
941
-
859
+
942
860
  Args:
943
861
  contract_address (str): Address of the deployed workflow contract
944
-
862
+
945
863
  Returns:
946
864
  ModelOutput: The inference result from the contract
947
-
865
+
948
866
  Raises:
949
867
  ContractLogicError: If the transaction fails
950
868
  Web3Error: If there are issues with the web3 connection or contract interaction
951
869
  """
952
- if not self._w3:
953
- self._initialize_web3()
954
-
955
870
  # Get the contract interface
956
- contract = self._w3.eth.contract(
957
- address=Web3.to_checksum_address(contract_address),
958
- abi=self._get_model_executor_abi()
959
- )
960
-
871
+ contract = self._blockchain.eth.contract(address=Web3.to_checksum_address(contract_address), abi=self._get_model_executor_abi())
872
+
961
873
  # Call run() function
962
- nonce = self._w3.eth.get_transaction_count(self.wallet_address, 'pending')
963
-
874
+ nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending")
875
+
964
876
  run_function = contract.functions.run()
965
- transaction = run_function.build_transaction({
966
- 'from': self.wallet_address,
967
- 'nonce': nonce,
968
- 'gas': 30000000,
969
- 'gasPrice': self._w3.eth.gas_price,
970
- 'chainId': self._w3.eth.chain_id
971
- })
972
-
973
- signed_txn = self._w3.eth.account.sign_transaction(transaction, self.private_key)
974
- tx_hash = self._w3.eth.send_raw_transaction(signed_txn.raw_transaction)
975
- tx_receipt = self._w3.eth.wait_for_transaction_receipt(tx_hash)
976
-
877
+ transaction = run_function.build_transaction(
878
+ {
879
+ "from": self._wallet_account.address,
880
+ "nonce": nonce,
881
+ "gas": 30000000,
882
+ "gasPrice": self._blockchain.eth.gas_price,
883
+ "chainId": self._blockchain.eth.chain_id,
884
+ }
885
+ )
886
+
887
+ signed_txn = self._wallet_account.sign_transaction(transaction)
888
+ tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction)
889
+ tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash)
890
+
977
891
  if tx_receipt.status == 0:
978
892
  raise ContractLogicError(f"Run transaction failed. Receipt: {tx_receipt}")
979
893
 
980
894
  # Get the inference result from the contract
981
895
  result = contract.functions.getInferenceResult().call()
982
896
  return result
897
+
898
+
899
+ def run_with_retry(txn_function, max_retries=5):
900
+ """
901
+ Execute a blockchain transaction with retry logic.
902
+
903
+ Args:
904
+ txn_function: Function that executes the transaction
905
+ max_retries (int): Maximum number of retry attempts
906
+ """
907
+ last_error = None
908
+ for attempt in range(max_retries):
909
+ try:
910
+ return txn_function()
911
+ except Exception as e:
912
+ last_error = e
913
+ if attempt < max_retries - 1:
914
+ if "nonce too low" in str(e) or "nonce too high" in str(e):
915
+ time.sleep(1) # Wait before retry
916
+ continue
917
+ # If it's not a nonce error, raise immediately
918
+ raise
919
+ # If we've exhausted all retries, raise the last error
920
+ raise OpenGradientError(f"Transaction failed after {max_retries} attempts: {str(last_error)}")