opengradient 0.4.6__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opengradient/__init__.py +64 -70
- opengradient/abi/PriceHistoryInference.abi +1 -0
- opengradient/abi/WorkflowScheduler.abi +13 -0
- opengradient/alphasense/read_workflow_tool.py +1 -1
- opengradient/alphasense/run_model_tool.py +3 -3
- opengradient/bin/PriceHistoryInference.bin +1 -0
- opengradient/cli.py +8 -4
- opengradient/client.py +303 -259
- opengradient/defaults.py +1 -0
- opengradient/llm/__init__.py +1 -1
- opengradient/llm/og_langchain.py +36 -22
- opengradient/llm/og_openai.py +1 -1
- opengradient/types.py +34 -20
- opengradient/utils.py +2 -0
- opengradient-0.4.8.dist-info/METADATA +159 -0
- opengradient-0.4.8.dist-info/RECORD +29 -0
- {opengradient-0.4.6.dist-info → opengradient-0.4.8.dist-info}/WHEEL +1 -1
- opengradient/abi/ModelExecutorHistorical.abi +0 -1
- opengradient-0.4.6.dist-info/METADATA +0 -189
- opengradient-0.4.6.dist-info/RECORD +0 -27
- {opengradient-0.4.6.dist-info → opengradient-0.4.8.dist-info}/LICENSE +0 -0
- {opengradient-0.4.6.dist-info → opengradient-0.4.8.dist-info}/entry_points.txt +0 -0
- {opengradient-0.4.6.dist-info → opengradient-0.4.8.dist-info}/top_level.txt +0 -0
opengradient/client.py
CHANGED
|
@@ -2,12 +2,10 @@ import json
|
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
4
|
import time
|
|
5
|
-
import uuid
|
|
6
5
|
from pathlib import Path
|
|
7
|
-
from typing import Any, Dict, List, Optional,
|
|
6
|
+
from typing import Any, Dict, List, Optional, Union
|
|
8
7
|
|
|
9
8
|
import firebase
|
|
10
|
-
import grpc
|
|
11
9
|
import numpy as np
|
|
12
10
|
import requests
|
|
13
11
|
from eth_account.account import LocalAccount
|
|
@@ -15,11 +13,23 @@ from web3 import Web3
|
|
|
15
13
|
from web3.exceptions import ContractLogicError
|
|
16
14
|
from web3.logs import DISCARD
|
|
17
15
|
|
|
18
|
-
from . import utils
|
|
19
16
|
from .exceptions import OpenGradientError
|
|
20
17
|
from .proto import infer_pb2, infer_pb2_grpc
|
|
21
|
-
from .types import
|
|
22
|
-
|
|
18
|
+
from .types import (
|
|
19
|
+
LLM,
|
|
20
|
+
TEE_LLM,
|
|
21
|
+
HistoricalInputQuery,
|
|
22
|
+
InferenceMode,
|
|
23
|
+
LlmInferenceMode,
|
|
24
|
+
ModelOutput,
|
|
25
|
+
TextGenerationOutput,
|
|
26
|
+
SchedulerParams,
|
|
27
|
+
InferenceResult,
|
|
28
|
+
ModelRepository,
|
|
29
|
+
FileUploadResult,
|
|
30
|
+
)
|
|
31
|
+
from .defaults import DEFAULT_IMAGE_GEN_HOST, DEFAULT_IMAGE_GEN_PORT, DEFAULT_SCHEDULER_ADDRESS
|
|
32
|
+
from .utils import convert_array_to_model_output, convert_to_model_input, convert_to_model_output
|
|
23
33
|
|
|
24
34
|
_FIREBASE_CONFIG = {
|
|
25
35
|
"apiKey": "AIzaSyDUVckVtfl-hiteBzPopy1pDD8Uvfncs7w",
|
|
@@ -45,10 +55,10 @@ class Client:
|
|
|
45
55
|
_blockchain: Web3
|
|
46
56
|
_wallet_account: LocalAccount
|
|
47
57
|
|
|
48
|
-
_hub_user: Dict
|
|
58
|
+
_hub_user: Optional[Dict]
|
|
49
59
|
_inference_abi: Dict
|
|
50
60
|
|
|
51
|
-
def __init__(self, private_key: str, rpc_url: str, contract_address: str, email: str, password: str):
|
|
61
|
+
def __init__(self, private_key: str, rpc_url: str, contract_address: str, email: Optional[str], password: Optional[str]):
|
|
52
62
|
"""
|
|
53
63
|
Initialize the Client with private key, RPC URL, and contract address.
|
|
54
64
|
|
|
@@ -80,7 +90,7 @@ class Client:
|
|
|
80
90
|
logging.error(f"Authentication failed: {str(e)}")
|
|
81
91
|
raise
|
|
82
92
|
|
|
83
|
-
def create_model(self, model_name: str, model_desc: str, version: str = "1.00") ->
|
|
93
|
+
def create_model(self, model_name: str, model_desc: str, version: str = "1.00") -> ModelRepository:
|
|
84
94
|
"""
|
|
85
95
|
Create a new model with the given model_name and model_desc, and a specified version.
|
|
86
96
|
|
|
@@ -103,41 +113,24 @@ class Client:
|
|
|
103
113
|
payload = {"name": model_name, "description": model_desc}
|
|
104
114
|
|
|
105
115
|
try:
|
|
106
|
-
logging.debug(f"Create Model URL: {url}")
|
|
107
|
-
logging.debug(f"Headers: {headers}")
|
|
108
|
-
logging.debug(f"Payload: {payload}")
|
|
109
|
-
|
|
110
116
|
response = requests.post(url, json=payload, headers=headers)
|
|
111
117
|
response.raise_for_status()
|
|
118
|
+
except requests.HTTPError as e:
|
|
119
|
+
error_details = f"HTTP {e.response.status_code}: {e.response.text}"
|
|
120
|
+
raise OpenGradientError(f"Model creation failed: {error_details}") from e
|
|
112
121
|
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
logging.info(f"Model creation successful. Model name: {model_name}")
|
|
122
|
+
json_response = response.json()
|
|
123
|
+
model_name = json_response.get("name")
|
|
124
|
+
if not model_name:
|
|
125
|
+
raise Exception(f"Model creation response missing 'name'. Full response: {json_response}")
|
|
118
126
|
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
version_response = self.create_version(model_name, version)
|
|
122
|
-
logging.info(f"Version creation successful. Version string: {version_response['versionString']}")
|
|
123
|
-
except Exception as ve:
|
|
124
|
-
logging.error(f"Version creation failed, but model was created. Error: {str(ve)}")
|
|
125
|
-
return {"name": model_name, "versionString": None, "version_error": str(ve)}
|
|
127
|
+
# Create the specified version for the newly created model
|
|
128
|
+
version_response = self.create_version(model_name, version)
|
|
126
129
|
|
|
127
|
-
|
|
130
|
+
return ModelRepository(model_name, version_response["versionString"])
|
|
128
131
|
|
|
129
|
-
except requests.RequestException as e:
|
|
130
|
-
logging.error(f"Model creation failed: {str(e)}")
|
|
131
|
-
if hasattr(e, "response") and e.response is not None:
|
|
132
|
-
logging.error(f"Response status code: {e.response.status_code}")
|
|
133
|
-
logging.error(f"Response headers: {e.response.headers}")
|
|
134
|
-
logging.error(f"Response content: {e.response.text}")
|
|
135
|
-
raise Exception(f"Model creation failed: {str(e)}")
|
|
136
|
-
except Exception as e:
|
|
137
|
-
logging.error(f"Unexpected error during model creation: {str(e)}")
|
|
138
|
-
raise
|
|
139
132
|
|
|
140
|
-
def create_version(self, model_name: str, notes: str =
|
|
133
|
+
def create_version(self, model_name: str, notes: str = "", is_major: bool = False) -> dict:
|
|
141
134
|
"""
|
|
142
135
|
Create a new version for the specified model.
|
|
143
136
|
|
|
@@ -196,7 +189,7 @@ class Client:
|
|
|
196
189
|
logging.error(f"Unexpected error during version creation: {str(e)}")
|
|
197
190
|
raise
|
|
198
191
|
|
|
199
|
-
def upload(self, model_path: str, model_name: str, version: str) ->
|
|
192
|
+
def upload(self, model_path: str, model_name: str, version: str) -> FileUploadResult:
|
|
200
193
|
"""
|
|
201
194
|
Upload a model file to the server.
|
|
202
195
|
|
|
@@ -251,12 +244,9 @@ class Client:
|
|
|
251
244
|
if response.status_code == 201:
|
|
252
245
|
if response.content and response.content != b"null":
|
|
253
246
|
json_response = response.json()
|
|
254
|
-
|
|
255
|
-
logging.info(f"Upload successful. CID: {json_response.get('ipfsCid', 'N/A')}")
|
|
256
|
-
result = {"model_cid": json_response.get("ipfsCid"), "size": json_response.get("size")}
|
|
247
|
+
return FileUploadResult(json_response.get("ipfsCid"), json_response.get("size"))
|
|
257
248
|
else:
|
|
258
|
-
|
|
259
|
-
result = {"model_cid": None, "size": None}
|
|
249
|
+
raise RuntimeError("Empty or null response content received. Assuming upload was successful.")
|
|
260
250
|
elif response.status_code == 500:
|
|
261
251
|
error_message = "Internal server error occurred. Please try again later or contact support."
|
|
262
252
|
logging.error(error_message)
|
|
@@ -266,16 +256,12 @@ class Client:
|
|
|
266
256
|
logging.error(f"Upload failed with status code {response.status_code}: {error_message}")
|
|
267
257
|
raise OpenGradientError(f"Upload failed: {error_message}", status_code=response.status_code)
|
|
268
258
|
|
|
269
|
-
return result
|
|
270
|
-
|
|
271
259
|
except requests.RequestException as e:
|
|
272
260
|
logging.error(f"Request exception during upload: {str(e)}")
|
|
273
261
|
if hasattr(e, "response") and e.response is not None:
|
|
274
262
|
logging.error(f"Response status code: {e.response.status_code}")
|
|
275
263
|
logging.error(f"Response content: {e.response.text[:1000]}...") # Log first 1000 characters
|
|
276
|
-
raise OpenGradientError(
|
|
277
|
-
f"Upload failed due to request exception: {str(e)}", status_code=e.response.status_code if hasattr(e, "response") else None
|
|
278
|
-
)
|
|
264
|
+
raise OpenGradientError(f"Upload failed due to request exception: {str(e)}")
|
|
279
265
|
except Exception as e:
|
|
280
266
|
logging.error(f"Unexpected error during upload: {str(e)}", exc_info=True)
|
|
281
267
|
raise OpenGradientError(f"Unexpected error during upload: {str(e)}")
|
|
@@ -286,7 +272,7 @@ class Client:
|
|
|
286
272
|
inference_mode: InferenceMode,
|
|
287
273
|
model_input: Dict[str, Union[str, int, float, List, np.ndarray]],
|
|
288
274
|
max_retries: Optional[int] = None,
|
|
289
|
-
) ->
|
|
275
|
+
) -> InferenceResult:
|
|
290
276
|
"""
|
|
291
277
|
Perform inference on a model.
|
|
292
278
|
|
|
@@ -297,7 +283,7 @@ class Client:
|
|
|
297
283
|
max_retries (int, optional): Maximum number of retry attempts. Defaults to 5.
|
|
298
284
|
|
|
299
285
|
Returns:
|
|
300
|
-
|
|
286
|
+
InferenceResult: The transaction hash and the model output.
|
|
301
287
|
|
|
302
288
|
Raises:
|
|
303
289
|
OpenGradientError: If the inference fails.
|
|
@@ -306,8 +292,8 @@ class Client:
|
|
|
306
292
|
def execute_transaction():
|
|
307
293
|
contract = self._blockchain.eth.contract(address=self._inference_hub_contract_address, abi=self._inference_abi)
|
|
308
294
|
|
|
309
|
-
inference_mode_uint8 =
|
|
310
|
-
converted_model_input =
|
|
295
|
+
inference_mode_uint8 = inference_mode.value
|
|
296
|
+
converted_model_input = convert_to_model_input(model_input)
|
|
311
297
|
|
|
312
298
|
run_function = contract.functions.run(model_cid, inference_mode_uint8, converted_model_input)
|
|
313
299
|
|
|
@@ -336,15 +322,16 @@ class Client:
|
|
|
336
322
|
raise OpenGradientError("InferenceResult event not found in transaction logs")
|
|
337
323
|
|
|
338
324
|
# TODO: This should return a ModelOutput class object
|
|
339
|
-
model_output =
|
|
340
|
-
|
|
325
|
+
model_output = convert_to_model_output(parsed_logs[0]["args"])
|
|
326
|
+
|
|
327
|
+
return InferenceResult(tx_hash.hex(), model_output)
|
|
341
328
|
|
|
342
329
|
return run_with_retry(execute_transaction, max_retries)
|
|
343
330
|
|
|
344
331
|
def llm_completion(
|
|
345
332
|
self,
|
|
346
333
|
model_cid: LLM,
|
|
347
|
-
inference_mode:
|
|
334
|
+
inference_mode: LlmInferenceMode,
|
|
348
335
|
prompt: str,
|
|
349
336
|
max_tokens: int = 100,
|
|
350
337
|
stop_sequence: Optional[List[str]] = None,
|
|
@@ -383,7 +370,7 @@ class Client:
|
|
|
383
370
|
|
|
384
371
|
# Prepare LLM input
|
|
385
372
|
llm_request = {
|
|
386
|
-
"mode": inference_mode,
|
|
373
|
+
"mode": inference_mode.value,
|
|
387
374
|
"modelCID": model_cid,
|
|
388
375
|
"prompt": prompt,
|
|
389
376
|
"max_tokens": max_tokens,
|
|
@@ -420,18 +407,15 @@ class Client:
|
|
|
420
407
|
raise OpenGradientError("LLM completion result event not found in transaction logs")
|
|
421
408
|
|
|
422
409
|
llm_answer = parsed_logs[0]["args"]["response"]["answer"]
|
|
423
|
-
|
|
424
|
-
return TextGenerationOutput(
|
|
425
|
-
transaction_hash=tx_hash.hex(),
|
|
426
|
-
completion_output=llm_answer
|
|
427
|
-
)
|
|
410
|
+
|
|
411
|
+
return TextGenerationOutput(transaction_hash=tx_hash.hex(), completion_output=llm_answer)
|
|
428
412
|
|
|
429
413
|
return run_with_retry(execute_transaction, max_retries)
|
|
430
414
|
|
|
431
415
|
def llm_chat(
|
|
432
416
|
self,
|
|
433
417
|
model_cid: LLM,
|
|
434
|
-
inference_mode:
|
|
418
|
+
inference_mode: LlmInferenceMode,
|
|
435
419
|
messages: List[Dict],
|
|
436
420
|
max_tokens: int = 100,
|
|
437
421
|
stop_sequence: Optional[List[str]] = None,
|
|
@@ -536,7 +520,7 @@ class Client:
|
|
|
536
520
|
|
|
537
521
|
# Prepare LLM input
|
|
538
522
|
llm_request = {
|
|
539
|
-
"mode": inference_mode,
|
|
523
|
+
"mode": inference_mode.value,
|
|
540
524
|
"modelCID": model_cid,
|
|
541
525
|
"messages": messages,
|
|
542
526
|
"max_tokens": max_tokens,
|
|
@@ -624,234 +608,269 @@ class Client:
|
|
|
624
608
|
if hasattr(e, "response") and e.response is not None:
|
|
625
609
|
logging.error(f"Response status code: {e.response.status_code}")
|
|
626
610
|
logging.error(f"Response content: {e.response.text[:1000]}...") # Log first 1000 characters
|
|
627
|
-
raise OpenGradientError(
|
|
628
|
-
f"File listing failed: {str(e)}", status_code=e.response.status_code if hasattr(e, "response") else None
|
|
629
|
-
)
|
|
611
|
+
raise OpenGradientError(f"File listing failed: {str(e)}")
|
|
630
612
|
except Exception as e:
|
|
631
613
|
logging.error(f"Unexpected error during file listing: {str(e)}", exc_info=True)
|
|
632
614
|
raise OpenGradientError(f"Unexpected error during file listing: {str(e)}")
|
|
633
615
|
|
|
634
|
-
def generate_image(
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
) -> bytes:
|
|
616
|
+
# def generate_image(
|
|
617
|
+
# self,
|
|
618
|
+
# model_cid: str,
|
|
619
|
+
# prompt: str,
|
|
620
|
+
# host: str = DEFAULT_IMAGE_GEN_HOST,
|
|
621
|
+
# port: int = DEFAULT_IMAGE_GEN_PORT,
|
|
622
|
+
# width: int = 1024,
|
|
623
|
+
# height: int = 1024,
|
|
624
|
+
# timeout: int = 300, # 5 minute timeout
|
|
625
|
+
# max_retries: int = 3,
|
|
626
|
+
# ) -> bytes:
|
|
627
|
+
# """
|
|
628
|
+
# Generate an image using a diffusion model through gRPC.
|
|
629
|
+
|
|
630
|
+
# Args:
|
|
631
|
+
# model_cid (str): The model identifier (e.g. "stabilityai/stable-diffusion-xl-base-1.0")
|
|
632
|
+
# prompt (str): The text prompt to generate the image from
|
|
633
|
+
# host (str, optional): gRPC host address. Defaults to DEFAULT_IMAGE_GEN_HOST.
|
|
634
|
+
# port (int, optional): gRPC port number. Defaults to DEFAULT_IMAGE_GEN_PORT.
|
|
635
|
+
# width (int, optional): Output image width. Defaults to 1024.
|
|
636
|
+
# height (int, optional): Output image height. Defaults to 1024.
|
|
637
|
+
# timeout (int, optional): Maximum time to wait for generation in seconds. Defaults to 300.
|
|
638
|
+
# max_retries (int, optional): Maximum number of retry attempts. Defaults to 3.
|
|
639
|
+
|
|
640
|
+
# Returns:
|
|
641
|
+
# bytes: The raw image data bytes
|
|
642
|
+
|
|
643
|
+
# Raises:
|
|
644
|
+
# OpenGradientError: If the image generation fails
|
|
645
|
+
# TimeoutError: If the generation exceeds the timeout period
|
|
646
|
+
# """
|
|
647
|
+
|
|
648
|
+
# def exponential_backoff(attempt: int, max_delay: float = 30.0) -> None:
|
|
649
|
+
# """Calculate and sleep for exponential backoff duration"""
|
|
650
|
+
# delay = min(0.1 * (2**attempt), max_delay)
|
|
651
|
+
# time.sleep(delay)
|
|
652
|
+
|
|
653
|
+
# channel = None
|
|
654
|
+
# start_time = time.time()
|
|
655
|
+
# retry_count = 0
|
|
656
|
+
|
|
657
|
+
# try:
|
|
658
|
+
# while retry_count < max_retries:
|
|
659
|
+
# try:
|
|
660
|
+
# # Initialize gRPC channel and stub
|
|
661
|
+
# channel = grpc.insecure_channel(f"{host}:{port}")
|
|
662
|
+
# stub = infer_pb2_grpc.InferenceServiceStub(channel)
|
|
663
|
+
|
|
664
|
+
# # Create image generation request
|
|
665
|
+
# image_request = infer_pb2.ImageGenerationRequest(model=model_cid, prompt=prompt, height=height, width=width)
|
|
666
|
+
|
|
667
|
+
# # Create inference request with random transaction ID
|
|
668
|
+
# tx_id = str(uuid.uuid4())
|
|
669
|
+
# request = infer_pb2.InferenceRequest(tx=tx_id, image_generation=image_request)
|
|
670
|
+
|
|
671
|
+
# # Send request with timeout
|
|
672
|
+
# response_id = stub.RunInferenceAsync(
|
|
673
|
+
# request,
|
|
674
|
+
# timeout=min(30, timeout), # Initial request timeout
|
|
675
|
+
# )
|
|
676
|
+
|
|
677
|
+
# # Poll for completion
|
|
678
|
+
# attempt = 0
|
|
679
|
+
# while True:
|
|
680
|
+
# # Check timeout
|
|
681
|
+
# if time.time() - start_time > timeout:
|
|
682
|
+
# raise TimeoutError(f"Image generation timed out after {timeout} seconds")
|
|
683
|
+
|
|
684
|
+
# status_request = infer_pb2.InferenceTxId(id=response_id.id)
|
|
685
|
+
# try:
|
|
686
|
+
# status = stub.GetInferenceStatus(
|
|
687
|
+
# status_request,
|
|
688
|
+
# timeout=min(5, timeout), # Status check timeout
|
|
689
|
+
# ).status
|
|
690
|
+
# except grpc.RpcError as e:
|
|
691
|
+
# logging.warning(f"Status check failed (attempt {attempt}): {str(e)}")
|
|
692
|
+
# exponential_backoff(attempt)
|
|
693
|
+
# attempt += 1
|
|
694
|
+
# continue
|
|
695
|
+
|
|
696
|
+
# if status == infer_pb2.InferenceStatus.STATUS_COMPLETED:
|
|
697
|
+
# break
|
|
698
|
+
# elif status == infer_pb2.InferenceStatus.STATUS_ERROR:
|
|
699
|
+
# raise OpenGradientError("Image generation failed on server")
|
|
700
|
+
# elif status != infer_pb2.InferenceStatus.STATUS_IN_PROGRESS:
|
|
701
|
+
# raise OpenGradientError(f"Unexpected status: {status}")
|
|
702
|
+
|
|
703
|
+
# exponential_backoff(attempt)
|
|
704
|
+
# attempt += 1
|
|
705
|
+
|
|
706
|
+
# # Get result
|
|
707
|
+
# result = stub.GetInferenceResult(
|
|
708
|
+
# response_id,
|
|
709
|
+
# timeout=min(30, timeout), # Result fetch timeout
|
|
710
|
+
# )
|
|
711
|
+
# return result.image_generation_result.image_data
|
|
712
|
+
|
|
713
|
+
# except (grpc.RpcError, TimeoutError) as e:
|
|
714
|
+
# retry_count += 1
|
|
715
|
+
# if retry_count >= max_retries:
|
|
716
|
+
# raise OpenGradientError(f"Image generation failed after {max_retries} retries: {str(e)}")
|
|
717
|
+
|
|
718
|
+
# logging.warning(f"Attempt {retry_count} failed: {str(e)}. Retrying...")
|
|
719
|
+
# exponential_backoff(retry_count)
|
|
720
|
+
|
|
721
|
+
# except grpc.RpcError as e:
|
|
722
|
+
# logging.error(f"gRPC error: {str(e)}")
|
|
723
|
+
# raise OpenGradientError(f"Image generation failed: {str(e)}")
|
|
724
|
+
# except TimeoutError as e:
|
|
725
|
+
# logging.error(f"Timeout error: {str(e)}")
|
|
726
|
+
# raise
|
|
727
|
+
# except Exception as e:
|
|
728
|
+
# logging.error(f"Error in generate image method: {str(e)}", exc_info=True)
|
|
729
|
+
# raise OpenGradientError(f"Image generation failed: {str(e)}")
|
|
730
|
+
# finally:
|
|
731
|
+
# if channel:
|
|
732
|
+
# channel.close()
|
|
733
|
+
|
|
734
|
+
def _get_abi(self, abi_name) -> str:
|
|
645
735
|
"""
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
Args:
|
|
649
|
-
model_cid (str): The model identifier (e.g. "stabilityai/stable-diffusion-xl-base-1.0")
|
|
650
|
-
prompt (str): The text prompt to generate the image from
|
|
651
|
-
host (str, optional): gRPC host address. Defaults to DEFAULT_IMAGE_GEN_HOST.
|
|
652
|
-
port (int, optional): gRPC port number. Defaults to DEFAULT_IMAGE_GEN_PORT.
|
|
653
|
-
width (int, optional): Output image width. Defaults to 1024.
|
|
654
|
-
height (int, optional): Output image height. Defaults to 1024.
|
|
655
|
-
timeout (int, optional): Maximum time to wait for generation in seconds. Defaults to 300.
|
|
656
|
-
max_retries (int, optional): Maximum number of retry attempts. Defaults to 3.
|
|
657
|
-
|
|
658
|
-
Returns:
|
|
659
|
-
bytes: The raw image data bytes
|
|
660
|
-
|
|
661
|
-
Raises:
|
|
662
|
-
OpenGradientError: If the image generation fails
|
|
663
|
-
TimeoutError: If the generation exceeds the timeout period
|
|
736
|
+
Returns the ABI for the requested contract.
|
|
664
737
|
"""
|
|
738
|
+
abi_path = Path(__file__).parent / "abi" / abi_name
|
|
739
|
+
with open(abi_path, "r") as f:
|
|
740
|
+
return json.load(f)
|
|
665
741
|
|
|
666
|
-
|
|
667
|
-
"""Calculate and sleep for exponential backoff duration"""
|
|
668
|
-
delay = min(0.1 * (2**attempt), max_delay)
|
|
669
|
-
time.sleep(delay)
|
|
670
|
-
|
|
671
|
-
channel = None
|
|
672
|
-
start_time = time.time()
|
|
673
|
-
retry_count = 0
|
|
674
|
-
|
|
675
|
-
try:
|
|
676
|
-
while retry_count < max_retries:
|
|
677
|
-
try:
|
|
678
|
-
# Initialize gRPC channel and stub
|
|
679
|
-
channel = grpc.insecure_channel(f"{host}:{port}")
|
|
680
|
-
stub = infer_pb2_grpc.InferenceServiceStub(channel)
|
|
681
|
-
|
|
682
|
-
# Create image generation request
|
|
683
|
-
image_request = infer_pb2.ImageGenerationRequest(model=model_cid, prompt=prompt, height=height, width=width)
|
|
684
|
-
|
|
685
|
-
# Create inference request with random transaction ID
|
|
686
|
-
tx_id = str(uuid.uuid4())
|
|
687
|
-
request = infer_pb2.InferenceRequest(tx=tx_id, image_generation=image_request)
|
|
688
|
-
|
|
689
|
-
# Send request with timeout
|
|
690
|
-
response_id = stub.RunInferenceAsync(
|
|
691
|
-
request,
|
|
692
|
-
timeout=min(30, timeout), # Initial request timeout
|
|
693
|
-
)
|
|
694
|
-
|
|
695
|
-
# Poll for completion
|
|
696
|
-
attempt = 0
|
|
697
|
-
while True:
|
|
698
|
-
# Check timeout
|
|
699
|
-
if time.time() - start_time > timeout:
|
|
700
|
-
raise TimeoutError(f"Image generation timed out after {timeout} seconds")
|
|
701
|
-
|
|
702
|
-
status_request = infer_pb2.InferenceTxId(id=response_id.id)
|
|
703
|
-
try:
|
|
704
|
-
status = stub.GetInferenceStatus(
|
|
705
|
-
status_request,
|
|
706
|
-
timeout=min(5, timeout), # Status check timeout
|
|
707
|
-
).status
|
|
708
|
-
except grpc.RpcError as e:
|
|
709
|
-
logging.warning(f"Status check failed (attempt {attempt}): {str(e)}")
|
|
710
|
-
exponential_backoff(attempt)
|
|
711
|
-
attempt += 1
|
|
712
|
-
continue
|
|
713
|
-
|
|
714
|
-
if status == infer_pb2.InferenceStatus.STATUS_COMPLETED:
|
|
715
|
-
break
|
|
716
|
-
elif status == infer_pb2.InferenceStatus.STATUS_ERROR:
|
|
717
|
-
raise OpenGradientError("Image generation failed on server")
|
|
718
|
-
elif status != infer_pb2.InferenceStatus.STATUS_IN_PROGRESS:
|
|
719
|
-
raise OpenGradientError(f"Unexpected status: {status}")
|
|
720
|
-
|
|
721
|
-
exponential_backoff(attempt)
|
|
722
|
-
attempt += 1
|
|
723
|
-
|
|
724
|
-
# Get result
|
|
725
|
-
result = stub.GetInferenceResult(
|
|
726
|
-
response_id,
|
|
727
|
-
timeout=min(30, timeout), # Result fetch timeout
|
|
728
|
-
)
|
|
729
|
-
return result.image_generation_result.image_data
|
|
730
|
-
|
|
731
|
-
except (grpc.RpcError, TimeoutError) as e:
|
|
732
|
-
retry_count += 1
|
|
733
|
-
if retry_count >= max_retries:
|
|
734
|
-
raise OpenGradientError(f"Image generation failed after {max_retries} retries: {str(e)}")
|
|
735
|
-
|
|
736
|
-
logging.warning(f"Attempt {retry_count} failed: {str(e)}. Retrying...")
|
|
737
|
-
exponential_backoff(retry_count)
|
|
738
|
-
|
|
739
|
-
except grpc.RpcError as e:
|
|
740
|
-
logging.error(f"gRPC error: {str(e)}")
|
|
741
|
-
raise OpenGradientError(f"Image generation failed: {str(e)}")
|
|
742
|
-
except TimeoutError as e:
|
|
743
|
-
logging.error(f"Timeout error: {str(e)}")
|
|
744
|
-
raise
|
|
745
|
-
except Exception as e:
|
|
746
|
-
logging.error(f"Error in generate image method: {str(e)}", exc_info=True)
|
|
747
|
-
raise OpenGradientError(f"Image generation failed: {str(e)}")
|
|
748
|
-
finally:
|
|
749
|
-
if channel:
|
|
750
|
-
channel.close()
|
|
751
|
-
|
|
752
|
-
def _get_model_executor_abi(self) -> List[Dict]:
|
|
742
|
+
def _get_bin(self, bin_name) -> str:
|
|
753
743
|
"""
|
|
754
|
-
Returns the
|
|
744
|
+
Returns the bin for the requested contract.
|
|
755
745
|
"""
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
746
|
+
bin_path = Path(__file__).parent / "bin" / bin_name
|
|
747
|
+
# Read bytecode with explicit encoding
|
|
748
|
+
with open(bin_path, "r", encoding="utf-8") as f:
|
|
749
|
+
bytecode = f.read().strip()
|
|
750
|
+
if not bytecode.startswith("0x"):
|
|
751
|
+
bytecode = "0x" + bytecode
|
|
752
|
+
return bytecode
|
|
759
753
|
|
|
760
754
|
def new_workflow(
|
|
761
755
|
self,
|
|
762
756
|
model_cid: str,
|
|
763
|
-
input_query:
|
|
757
|
+
input_query: HistoricalInputQuery,
|
|
764
758
|
input_tensor_name: str,
|
|
765
759
|
scheduler_params: Optional[SchedulerParams] = None,
|
|
766
760
|
) -> str:
|
|
767
761
|
"""
|
|
768
762
|
Deploy a new workflow contract with the specified parameters.
|
|
769
|
-
"""
|
|
770
|
-
if isinstance(input_query, dict):
|
|
771
|
-
input_query = HistoricalInputQuery.from_dict(input_query)
|
|
772
763
|
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
with open(bin_path, "r") as f:
|
|
778
|
-
bytecode = f.read().strip()
|
|
764
|
+
This function deploys a new workflow contract on OpenGradient that connects
|
|
765
|
+
an AI model with its required input data. When executed, the workflow will fetch
|
|
766
|
+
the specified model, evaluate the input query to get data, and perform inference.
|
|
779
767
|
|
|
780
|
-
|
|
768
|
+
The workflow can be set to execute manually or automatically via a scheduler.
|
|
781
769
|
|
|
782
|
-
|
|
783
|
-
|
|
770
|
+
Args:
|
|
771
|
+
model_cid (str): CID of the model to be executed from the Model Hub
|
|
772
|
+
input_query (HistoricalInputQuery): Input definition for the model inference,
|
|
773
|
+
will be evaluated at runtime for each inference
|
|
774
|
+
input_tensor_name (str): Name of the input tensor expected by the model
|
|
775
|
+
scheduler_params (Optional[SchedulerParams]): Scheduler configuration for automated execution:
|
|
776
|
+
- frequency: Execution frequency in seconds
|
|
777
|
+
- duration_hours: How long the schedule should live for
|
|
784
778
|
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
"from": self._wallet_account.address,
|
|
789
|
-
"nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
|
|
790
|
-
"gas": 15000000,
|
|
791
|
-
"gasPrice": self._blockchain.eth.gas_price,
|
|
792
|
-
"chainId": self._blockchain.eth.chain_id,
|
|
793
|
-
}
|
|
794
|
-
)
|
|
779
|
+
Returns:
|
|
780
|
+
str: Deployed contract address. If scheduler_params was provided, the workflow
|
|
781
|
+
will be automatically executed according to the specified schedule.
|
|
795
782
|
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
783
|
+
Raises:
|
|
784
|
+
Exception: If transaction fails or gas estimation fails
|
|
785
|
+
"""
|
|
786
|
+
# Get contract ABI and bytecode
|
|
787
|
+
abi = self._get_abi("PriceHistoryInference.abi")
|
|
788
|
+
bytecode = self._get_bin("PriceHistoryInference.bin")
|
|
800
789
|
|
|
801
|
-
|
|
790
|
+
def deploy_transaction():
|
|
791
|
+
contract = self._blockchain.eth.contract(abi=abi, bytecode=bytecode)
|
|
792
|
+
query_tuple = input_query.to_abi_format()
|
|
793
|
+
constructor_args = [model_cid, input_tensor_name, query_tuple]
|
|
802
794
|
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
795
|
+
try:
|
|
796
|
+
# Estimate gas needed
|
|
797
|
+
estimated_gas = contract.constructor(*constructor_args).estimate_gas({"from": self._wallet_account.address})
|
|
798
|
+
gas_limit = int(estimated_gas * 1.2)
|
|
799
|
+
except Exception as e:
|
|
800
|
+
print(f"⚠️ Gas estimation failed: {str(e)}")
|
|
801
|
+
gas_limit = 5000000 # Conservative fallback
|
|
802
|
+
print(f"📊 Using fallback gas limit: {gas_limit}")
|
|
809
803
|
|
|
810
|
-
|
|
804
|
+
transaction = contract.constructor(*constructor_args).build_transaction(
|
|
811
805
|
{
|
|
812
|
-
"
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
"name": "registerTask",
|
|
818
|
-
"outputs": [],
|
|
819
|
-
"stateMutability": "nonpayable",
|
|
820
|
-
"type": "function",
|
|
806
|
+
"from": self._wallet_account.address,
|
|
807
|
+
"nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
|
|
808
|
+
"gas": gas_limit,
|
|
809
|
+
"gasPrice": self._blockchain.eth.gas_price,
|
|
810
|
+
"chainId": self._blockchain.eth.chain_id,
|
|
821
811
|
}
|
|
822
|
-
|
|
812
|
+
)
|
|
823
813
|
|
|
824
|
-
|
|
825
|
-
|
|
814
|
+
signed_txn = self._wallet_account.sign_transaction(transaction)
|
|
815
|
+
tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction)
|
|
826
816
|
|
|
827
|
-
|
|
828
|
-
# Register the workflow with the scheduler
|
|
829
|
-
scheduler_tx = scheduler_contract.functions.registerTask(
|
|
830
|
-
contract_address, scheduler_params.end_time, scheduler_params.frequency
|
|
831
|
-
).build_transaction(
|
|
832
|
-
{
|
|
833
|
-
"from": self._wallet_account.address,
|
|
834
|
-
"gas": 300000,
|
|
835
|
-
"gasPrice": self._blockchain.eth.gas_price,
|
|
836
|
-
"nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
|
|
837
|
-
"chainId": self._blockchain.eth.chain_id,
|
|
838
|
-
}
|
|
839
|
-
)
|
|
817
|
+
tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=60)
|
|
840
818
|
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
self._blockchain.eth.wait_for_transaction_receipt(scheduler_tx_hash, timeout=REGULAR_TX_TIMEOUT)
|
|
819
|
+
if tx_receipt["status"] == 0:
|
|
820
|
+
raise Exception(f"❌ Contract deployment failed, transaction hash: {tx_hash.hex()}")
|
|
844
821
|
|
|
845
|
-
|
|
846
|
-
print(f" Transaction hash: {scheduler_tx_hash.hex()}")
|
|
822
|
+
return tx_receipt.contractAddress
|
|
847
823
|
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
824
|
+
contract_address = run_with_retry(deploy_transaction)
|
|
825
|
+
|
|
826
|
+
if scheduler_params:
|
|
827
|
+
self._register_with_scheduler(contract_address, scheduler_params)
|
|
852
828
|
|
|
853
829
|
return contract_address
|
|
854
830
|
|
|
831
|
+
def _register_with_scheduler(self, contract_address: str, scheduler_params: SchedulerParams) -> None:
|
|
832
|
+
"""
|
|
833
|
+
Register the deployed workflow contract with the scheduler for automated execution.
|
|
834
|
+
|
|
835
|
+
Args:
|
|
836
|
+
contract_address (str): Address of the deployed workflow contract
|
|
837
|
+
scheduler_params (SchedulerParams): Scheduler configuration containing:
|
|
838
|
+
- frequency: Execution frequency in seconds
|
|
839
|
+
- duration_hours: How long to run in hours
|
|
840
|
+
- end_time: Unix timestamp when scheduling should end
|
|
841
|
+
|
|
842
|
+
Raises:
|
|
843
|
+
Exception: If registration with scheduler fails. The workflow contract will
|
|
844
|
+
still be deployed and can be executed manually.
|
|
845
|
+
"""
|
|
846
|
+
|
|
847
|
+
scheduler_abi = self._get_abi("WorkflowScheduler.abi")
|
|
848
|
+
|
|
849
|
+
# Scheduler contract address
|
|
850
|
+
scheduler_address = DEFAULT_SCHEDULER_ADDRESS
|
|
851
|
+
scheduler_contract = self._blockchain.eth.contract(address=scheduler_address, abi=scheduler_abi)
|
|
852
|
+
|
|
853
|
+
try:
|
|
854
|
+
# Register the workflow with the scheduler
|
|
855
|
+
scheduler_tx = scheduler_contract.functions.registerTask(
|
|
856
|
+
contract_address, scheduler_params.end_time, scheduler_params.frequency
|
|
857
|
+
).build_transaction(
|
|
858
|
+
{
|
|
859
|
+
"from": self._wallet_account.address,
|
|
860
|
+
"gas": 300000,
|
|
861
|
+
"gasPrice": self._blockchain.eth.gas_price,
|
|
862
|
+
"nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
|
|
863
|
+
"chainId": self._blockchain.eth.chain_id,
|
|
864
|
+
}
|
|
865
|
+
)
|
|
866
|
+
|
|
867
|
+
signed_scheduler_tx = self._wallet_account.sign_transaction(scheduler_tx)
|
|
868
|
+
scheduler_tx_hash = self._blockchain.eth.send_raw_transaction(signed_scheduler_tx.raw_transaction)
|
|
869
|
+
self._blockchain.eth.wait_for_transaction_receipt(scheduler_tx_hash, timeout=REGULAR_TX_TIMEOUT)
|
|
870
|
+
except Exception as e:
|
|
871
|
+
print(f"❌ Error registering contract with scheduler: {str(e)}")
|
|
872
|
+
print(" The workflow contract is still deployed and can be executed manually.")
|
|
873
|
+
|
|
855
874
|
def read_workflow_result(self, contract_address: str) -> ModelOutput:
|
|
856
875
|
"""
|
|
857
876
|
Reads the latest inference result from a deployed workflow contract.
|
|
@@ -867,12 +886,14 @@ class Client:
|
|
|
867
886
|
Web3Error: If there are issues with the web3 connection or contract interaction
|
|
868
887
|
"""
|
|
869
888
|
# Get the contract interface
|
|
870
|
-
contract = self._blockchain.eth.contract(
|
|
889
|
+
contract = self._blockchain.eth.contract(
|
|
890
|
+
address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
|
|
891
|
+
)
|
|
871
892
|
|
|
872
893
|
# Get the result
|
|
873
894
|
result = contract.functions.getInferenceResult().call()
|
|
874
895
|
|
|
875
|
-
return
|
|
896
|
+
return convert_array_to_model_output(result)
|
|
876
897
|
|
|
877
898
|
def run_workflow(self, contract_address: str) -> ModelOutput:
|
|
878
899
|
"""
|
|
@@ -889,7 +910,9 @@ class Client:
|
|
|
889
910
|
Web3Error: If there are issues with the web3 connection or contract interaction
|
|
890
911
|
"""
|
|
891
912
|
# Get the contract interface
|
|
892
|
-
contract = self._blockchain.eth.contract(
|
|
913
|
+
contract = self._blockchain.eth.contract(
|
|
914
|
+
address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
|
|
915
|
+
)
|
|
893
916
|
|
|
894
917
|
# Call run() function
|
|
895
918
|
nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending")
|
|
@@ -915,7 +938,28 @@ class Client:
|
|
|
915
938
|
# Get the inference result from the contract
|
|
916
939
|
result = contract.functions.getInferenceResult().call()
|
|
917
940
|
|
|
918
|
-
return
|
|
941
|
+
return convert_array_to_model_output(result)
|
|
942
|
+
|
|
943
|
+
def read_workflow_history(self, contract_address: str, num_results: int) -> List[ModelOutput]:
|
|
944
|
+
"""
|
|
945
|
+
Gets historical inference results from a workflow contract.
|
|
946
|
+
|
|
947
|
+
Retrieves the specified number of most recent inference results from the contract's
|
|
948
|
+
storage, with the most recent result first.
|
|
949
|
+
|
|
950
|
+
Args:
|
|
951
|
+
contract_address (str): Address of the deployed workflow contract
|
|
952
|
+
num_results (int): Number of historical results to retrieve
|
|
953
|
+
|
|
954
|
+
Returns:
|
|
955
|
+
List[ModelOutput]: List of historical inference results
|
|
956
|
+
"""
|
|
957
|
+
contract = self._blockchain.eth.contract(
|
|
958
|
+
address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
|
|
959
|
+
)
|
|
960
|
+
|
|
961
|
+
results = contract.functions.getLastInferenceResults(num_results).call()
|
|
962
|
+
return [convert_array_to_model_output(result) for result in results]
|
|
919
963
|
|
|
920
964
|
|
|
921
965
|
def run_with_retry(txn_function, max_retries=DEFAULT_MAX_RETRY, retry_delay=DEFAULT_RETRY_DELAY_SEC):
|