opengradient 0.4.6__py3-none-any.whl → 0.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opengradient/__init__.py +59 -67
- opengradient/abi/PriceHistoryInference.abi +1 -0
- opengradient/abi/WorkflowScheduler.abi +13 -0
- opengradient/alphasense/read_workflow_tool.py +1 -1
- opengradient/alphasense/run_model_tool.py +3 -3
- opengradient/bin/PriceHistoryInference.bin +1 -0
- opengradient/cli.py +8 -4
- opengradient/client.py +282 -217
- opengradient/defaults.py +1 -0
- opengradient/llm/__init__.py +1 -1
- opengradient/llm/og_langchain.py +36 -22
- opengradient/llm/og_openai.py +1 -1
- opengradient/types.py +22 -20
- opengradient/utils.py +2 -0
- opengradient-0.4.7.dist-info/METADATA +159 -0
- opengradient-0.4.7.dist-info/RECORD +29 -0
- {opengradient-0.4.6.dist-info → opengradient-0.4.7.dist-info}/WHEEL +1 -1
- opengradient/abi/ModelExecutorHistorical.abi +0 -1
- opengradient-0.4.6.dist-info/METADATA +0 -189
- opengradient-0.4.6.dist-info/RECORD +0 -27
- {opengradient-0.4.6.dist-info → opengradient-0.4.7.dist-info}/LICENSE +0 -0
- {opengradient-0.4.6.dist-info → opengradient-0.4.7.dist-info}/entry_points.txt +0 -0
- {opengradient-0.4.6.dist-info → opengradient-0.4.7.dist-info}/top_level.txt +0 -0
opengradient/client.py
CHANGED
@@ -2,12 +2,10 @@ import json
 import logging
 import os
 import time
-import uuid
 from pathlib import Path
-from typing import Any, Dict, List, Optional,
+from typing import Any, Dict, List, Optional, Union
 
 import firebase
-import grpc
 import numpy as np
 import requests
 from eth_account.account import LocalAccount
@@ -18,8 +16,18 @@ from web3.logs import DISCARD
 from . import utils
 from .exceptions import OpenGradientError
 from .proto import infer_pb2, infer_pb2_grpc
-from .types import
-
+from .types import (
+    LLM,
+    TEE_LLM,
+    HistoricalInputQuery,
+    InferenceMode,
+    LlmInferenceMode,
+    ModelOutput,
+    TextGenerationOutput,
+    SchedulerParams,
+    InferenceResult,
+)
+from .defaults import DEFAULT_IMAGE_GEN_HOST, DEFAULT_IMAGE_GEN_PORT, DEFAULT_SCHEDULER_ADDRESS
 
 _FIREBASE_CONFIG = {
     "apiKey": "AIzaSyDUVckVtfl-hiteBzPopy1pDD8Uvfncs7w",
@@ -48,7 +56,7 @@ class Client:
     _hub_user: Dict
     _inference_abi: Dict
 
-    def __init__(self, private_key: str, rpc_url: str, contract_address: str, email: str, password: str):
+    def __init__(self, private_key: str, rpc_url: str, contract_address: str, email: Optional[str], password: Optional[str]):
        """
        Initialize the Client with private key, RPC URL, and contract address.
 
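
With email and password now Optional, a Client can be constructed without Hub credentials. A minimal sketch, assuming Client is imported from opengradient.client and using the defaults shipped in defaults.py; the private key is a placeholder:

    from opengradient.client import Client

    client = Client(
        private_key="0x...",  # placeholder wallet key
        rpc_url="http://18.188.176.119:8545",  # DEFAULT_RPC_URL
        contract_address="0x8383C9bD7462F12Eb996DD02F78234C0421A6FaE",  # DEFAULT_INFERENCE_CONTRACT_ADDRESS
        email=None,  # Hub credentials may now be omitted
        password=None,
    )
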
@@ -137,7 +145,7 @@ class Client:
             logging.error(f"Unexpected error during model creation: {str(e)}")
             raise
 
-    def create_version(self, model_name: str, notes: str =
+    def create_version(self, model_name: str, notes: str = "", is_major: bool = False) -> dict:
         """
         Create a new version for the specified model.
 
@@ -273,9 +281,7 @@ class Client:
             if hasattr(e, "response") and e.response is not None:
                 logging.error(f"Response status code: {e.response.status_code}")
                 logging.error(f"Response content: {e.response.text[:1000]}...")  # Log first 1000 characters
-            raise OpenGradientError(
-                f"Upload failed due to request exception: {str(e)}", status_code=e.response.status_code if hasattr(e, "response") else None
-            )
+            raise OpenGradientError(f"Upload failed due to request exception: {str(e)}")
         except Exception as e:
             logging.error(f"Unexpected error during upload: {str(e)}", exc_info=True)
             raise OpenGradientError(f"Unexpected error during upload: {str(e)}")
@@ -286,7 +292,7 @@ class Client:
         inference_mode: InferenceMode,
         model_input: Dict[str, Union[str, int, float, List, np.ndarray]],
         max_retries: Optional[int] = None,
-    ) ->
+    ) -> InferenceResult:
         """
         Perform inference on a model.
 
@@ -297,7 +303,7 @@ class Client:
             max_retries (int, optional): Maximum number of retry attempts. Defaults to 5.
 
         Returns:
-
+            InferenceResult: The transaction hash and the model output.
 
         Raises:
             OpenGradientError: If the inference fails.
@@ -306,7 +312,7 @@ class Client:
         def execute_transaction():
             contract = self._blockchain.eth.contract(address=self._inference_hub_contract_address, abi=self._inference_abi)
 
-            inference_mode_uint8 =
+            inference_mode_uint8 = inference_mode.value
             converted_model_input = utils.convert_to_model_input(model_input)
 
             run_function = contract.functions.run(model_cid, inference_mode_uint8, converted_model_input)
@@ -337,14 +343,15 @@ class Client:
 
             # TODO: This should return a ModelOutput class object
             model_output = utils.convert_to_model_output(parsed_logs[0]["args"])
-
+
+            return InferenceResult(tx_hash.hex(), model_output)
 
         return run_with_retry(execute_transaction, max_retries)
 
     def llm_completion(
         self,
         model_cid: LLM,
-        inference_mode:
+        inference_mode: LlmInferenceMode,
         prompt: str,
         max_tokens: int = 100,
         stop_sequence: Optional[List[str]] = None,
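
Since infer now returns an InferenceResult, callers get the transaction hash and the decoded output from a single call. A sketch, assuming InferenceResult exposes its two constructor arguments as transaction_hash and model_output, and that InferenceMode has a VANILLA member (neither is shown in this diff):

    from opengradient.types import InferenceMode

    result = client.infer(
        model_cid="Qm...",  # placeholder model CID
        inference_mode=InferenceMode.VANILLA,  # assumed member name
        model_input={"input": [1.0, 2.0, 3.0]},  # hypothetical tensor payload
    )
    print(result.transaction_hash, result.model_output)
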
@@ -383,7 +390,7 @@ class Client:
 
         # Prepare LLM input
         llm_request = {
-            "mode": inference_mode,
+            "mode": inference_mode.value,
             "modelCID": model_cid,
             "prompt": prompt,
             "max_tokens": max_tokens,
@@ -420,18 +427,15 @@ class Client:
                 raise OpenGradientError("LLM completion result event not found in transaction logs")
 
             llm_answer = parsed_logs[0]["args"]["response"]["answer"]
-
-            return TextGenerationOutput(
-                transaction_hash=tx_hash.hex(),
-                completion_output=llm_answer
-            )
+
+            return TextGenerationOutput(transaction_hash=tx_hash.hex(), completion_output=llm_answer)
 
         return run_with_retry(execute_transaction, max_retries)
 
     def llm_chat(
         self,
         model_cid: LLM,
-        inference_mode:
+        inference_mode: LlmInferenceMode,
         messages: List[Dict],
         max_tokens: int = 100,
         stop_sequence: Optional[List[str]] = None,
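
llm_completion now takes a typed LlmInferenceMode, and its TextGenerationOutput fields (transaction_hash, completion_output) are confirmed by the return statement above. A sketch, assuming LlmInferenceMode has a VANILLA member and that the model CID can be passed as a plain string:

    from opengradient.types import LlmInferenceMode

    out = client.llm_completion(
        model_cid="meta-llama/Meta-Llama-3-8B-Instruct",  # hypothetical model identifier
        inference_mode=LlmInferenceMode.VANILLA,  # assumed member name
        prompt="Say hello in one sentence.",
        max_tokens=50,
    )
    print(out.transaction_hash, out.completion_output)
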
@@ -536,7 +540,7 @@ class Client:
 
         # Prepare LLM input
         llm_request = {
-            "mode": inference_mode,
+            "mode": inference_mode.value,
             "modelCID": model_cid,
             "messages": messages,
             "max_tokens": max_tokens,
@@ -624,233 +628,265 @@ class Client:
             if hasattr(e, "response") and e.response is not None:
                 logging.error(f"Response status code: {e.response.status_code}")
                 logging.error(f"Response content: {e.response.text[:1000]}...")  # Log first 1000 characters
-            raise OpenGradientError(
-                f"File listing failed: {str(e)}", status_code=e.response.status_code if hasattr(e, "response") else None
-            )
+            raise OpenGradientError(f"File listing failed: {str(e)}")
         except Exception as e:
             logging.error(f"Unexpected error during file listing: {str(e)}", exc_info=True)
             raise OpenGradientError(f"Unexpected error during file listing: {str(e)}")
 
-    def generate_image(
-        self,
-        model_cid: str,
-        prompt: str,
-        host: str = DEFAULT_IMAGE_GEN_HOST,
-        port: int = DEFAULT_IMAGE_GEN_PORT,
-        width: int = 1024,
-        height: int = 1024,
-        timeout: int = 300,  # 5 minute timeout
-        max_retries: int = 3,
-    ) -> bytes:
+    # def generate_image(
+    #     self,
+    #     model_cid: str,
+    #     prompt: str,
+    #     host: str = DEFAULT_IMAGE_GEN_HOST,
+    #     port: int = DEFAULT_IMAGE_GEN_PORT,
+    #     width: int = 1024,
+    #     height: int = 1024,
+    #     timeout: int = 300, # 5 minute timeout
+    #     max_retries: int = 3,
+    # ) -> bytes:
+    #     """
+    #     Generate an image using a diffusion model through gRPC.
+
+    #     Args:
+    #         model_cid (str): The model identifier (e.g. "stabilityai/stable-diffusion-xl-base-1.0")
+    #         prompt (str): The text prompt to generate the image from
+    #         host (str, optional): gRPC host address. Defaults to DEFAULT_IMAGE_GEN_HOST.
+    #         port (int, optional): gRPC port number. Defaults to DEFAULT_IMAGE_GEN_PORT.
+    #         width (int, optional): Output image width. Defaults to 1024.
+    #         height (int, optional): Output image height. Defaults to 1024.
+    #         timeout (int, optional): Maximum time to wait for generation in seconds. Defaults to 300.
+    #         max_retries (int, optional): Maximum number of retry attempts. Defaults to 3.
+
+    #     Returns:
+    #         bytes: The raw image data bytes
+
+    #     Raises:
+    #         OpenGradientError: If the image generation fails
+    #         TimeoutError: If the generation exceeds the timeout period
+    #     """
+
+    #     def exponential_backoff(attempt: int, max_delay: float = 30.0) -> None:
+    #         """Calculate and sleep for exponential backoff duration"""
+    #         delay = min(0.1 * (2**attempt), max_delay)
+    #         time.sleep(delay)
+
+    #     channel = None
+    #     start_time = time.time()
+    #     retry_count = 0
+
+    #     try:
+    #         while retry_count < max_retries:
+    #             try:
+    #                 # Initialize gRPC channel and stub
+    #                 channel = grpc.insecure_channel(f"{host}:{port}")
+    #                 stub = infer_pb2_grpc.InferenceServiceStub(channel)
+
+    #                 # Create image generation request
+    #                 image_request = infer_pb2.ImageGenerationRequest(model=model_cid, prompt=prompt, height=height, width=width)
+
+    #                 # Create inference request with random transaction ID
+    #                 tx_id = str(uuid.uuid4())
+    #                 request = infer_pb2.InferenceRequest(tx=tx_id, image_generation=image_request)
+
+    #                 # Send request with timeout
+    #                 response_id = stub.RunInferenceAsync(
+    #                     request,
+    #                     timeout=min(30, timeout),  # Initial request timeout
+    #                 )
+
+    #                 # Poll for completion
+    #                 attempt = 0
+    #                 while True:
+    #                     # Check timeout
+    #                     if time.time() - start_time > timeout:
+    #                         raise TimeoutError(f"Image generation timed out after {timeout} seconds")
+
+    #                     status_request = infer_pb2.InferenceTxId(id=response_id.id)
+    #                     try:
+    #                         status = stub.GetInferenceStatus(
+    #                             status_request,
+    #                             timeout=min(5, timeout),  # Status check timeout
+    #                         ).status
+    #                     except grpc.RpcError as e:
+    #                         logging.warning(f"Status check failed (attempt {attempt}): {str(e)}")
+    #                         exponential_backoff(attempt)
+    #                         attempt += 1
+    #                         continue
+
+    #                     if status == infer_pb2.InferenceStatus.STATUS_COMPLETED:
+    #                         break
+    #                     elif status == infer_pb2.InferenceStatus.STATUS_ERROR:
+    #                         raise OpenGradientError("Image generation failed on server")
+    #                     elif status != infer_pb2.InferenceStatus.STATUS_IN_PROGRESS:
+    #                         raise OpenGradientError(f"Unexpected status: {status}")
+
+    #                     exponential_backoff(attempt)
+    #                     attempt += 1
+
+    #                 # Get result
+    #                 result = stub.GetInferenceResult(
+    #                     response_id,
+    #                     timeout=min(30, timeout),  # Result fetch timeout
+    #                 )
+    #                 return result.image_generation_result.image_data
+
+    #             except (grpc.RpcError, TimeoutError) as e:
+    #                 retry_count += 1
+    #                 if retry_count >= max_retries:
+    #                     raise OpenGradientError(f"Image generation failed after {max_retries} retries: {str(e)}")
+
+    #                 logging.warning(f"Attempt {retry_count} failed: {str(e)}. Retrying...")
+    #                 exponential_backoff(retry_count)
+
+    #     except grpc.RpcError as e:
+    #         logging.error(f"gRPC error: {str(e)}")
+    #         raise OpenGradientError(f"Image generation failed: {str(e)}")
+    #     except TimeoutError as e:
+    #         logging.error(f"Timeout error: {str(e)}")
+    #         raise
+    #     except Exception as e:
+    #         logging.error(f"Error in generate image method: {str(e)}", exc_info=True)
+    #         raise OpenGradientError(f"Image generation failed: {str(e)}")
+    #     finally:
+    #         if channel:
+    #             channel.close()
+
+    def _get_abi(self, abi_name) -> List[Dict]:
         """
-
-
-        Args:
-            model_cid (str): The model identifier (e.g. "stabilityai/stable-diffusion-xl-base-1.0")
-            prompt (str): The text prompt to generate the image from
-            host (str, optional): gRPC host address. Defaults to DEFAULT_IMAGE_GEN_HOST.
-            port (int, optional): gRPC port number. Defaults to DEFAULT_IMAGE_GEN_PORT.
-            width (int, optional): Output image width. Defaults to 1024.
-            height (int, optional): Output image height. Defaults to 1024.
-            timeout (int, optional): Maximum time to wait for generation in seconds. Defaults to 300.
-            max_retries (int, optional): Maximum number of retry attempts. Defaults to 3.
-
-        Returns:
-            bytes: The raw image data bytes
-
-        Raises:
-            OpenGradientError: If the image generation fails
-            TimeoutError: If the generation exceeds the timeout period
+        Returns the ABI for the requested contract.
         """
+        abi_path = Path(__file__).parent / "abi" / abi_name
+        with open(abi_path, "r") as f:
+            return json.load(f)
 
-        def exponential_backoff(attempt: int, max_delay: float = 30.0) -> None:
-            """Calculate and sleep for exponential backoff duration"""
-            delay = min(0.1 * (2**attempt), max_delay)
-            time.sleep(delay)
-
-        channel = None
-        start_time = time.time()
-        retry_count = 0
-
-        try:
-            while retry_count < max_retries:
-                try:
-                    # Initialize gRPC channel and stub
-                    channel = grpc.insecure_channel(f"{host}:{port}")
-                    stub = infer_pb2_grpc.InferenceServiceStub(channel)
-
-                    # Create image generation request
-                    image_request = infer_pb2.ImageGenerationRequest(model=model_cid, prompt=prompt, height=height, width=width)
-
-                    # Create inference request with random transaction ID
-                    tx_id = str(uuid.uuid4())
-                    request = infer_pb2.InferenceRequest(tx=tx_id, image_generation=image_request)
-
-                    # Send request with timeout
-                    response_id = stub.RunInferenceAsync(
-                        request,
-                        timeout=min(30, timeout),  # Initial request timeout
-                    )
-
-                    # Poll for completion
-                    attempt = 0
-                    while True:
-                        # Check timeout
-                        if time.time() - start_time > timeout:
-                            raise TimeoutError(f"Image generation timed out after {timeout} seconds")
-
-                        status_request = infer_pb2.InferenceTxId(id=response_id.id)
-                        try:
-                            status = stub.GetInferenceStatus(
-                                status_request,
-                                timeout=min(5, timeout),  # Status check timeout
-                            ).status
-                        except grpc.RpcError as e:
-                            logging.warning(f"Status check failed (attempt {attempt}): {str(e)}")
-                            exponential_backoff(attempt)
-                            attempt += 1
-                            continue
-
-                        if status == infer_pb2.InferenceStatus.STATUS_COMPLETED:
-                            break
-                        elif status == infer_pb2.InferenceStatus.STATUS_ERROR:
-                            raise OpenGradientError("Image generation failed on server")
-                        elif status != infer_pb2.InferenceStatus.STATUS_IN_PROGRESS:
-                            raise OpenGradientError(f"Unexpected status: {status}")
-
-                        exponential_backoff(attempt)
-                        attempt += 1
-
-                    # Get result
-                    result = stub.GetInferenceResult(
-                        response_id,
-                        timeout=min(30, timeout),  # Result fetch timeout
-                    )
-                    return result.image_generation_result.image_data
-
-                except (grpc.RpcError, TimeoutError) as e:
-                    retry_count += 1
-                    if retry_count >= max_retries:
-                        raise OpenGradientError(f"Image generation failed after {max_retries} retries: {str(e)}")
-
-                    logging.warning(f"Attempt {retry_count} failed: {str(e)}. Retrying...")
-                    exponential_backoff(retry_count)
-
-        except grpc.RpcError as e:
-            logging.error(f"gRPC error: {str(e)}")
-            raise OpenGradientError(f"Image generation failed: {str(e)}")
-        except TimeoutError as e:
-            logging.error(f"Timeout error: {str(e)}")
-            raise
-        except Exception as e:
-            logging.error(f"Error in generate image method: {str(e)}", exc_info=True)
-            raise OpenGradientError(f"Image generation failed: {str(e)}")
-        finally:
-            if channel:
-                channel.close()
-
-    def _get_model_executor_abi(self) -> List[Dict]:
+    def _get_bin(self, bin_name) -> List[Dict]:
         """
-        Returns the
+        Returns the bin for the requested contract.
         """
-
-
-
+        bin_path = Path(__file__).parent / "bin" / bin_name
+        # Read bytecode with explicit encoding
+        with open(bin_path, "r", encoding="utf-8") as f:
+            bytecode = f.read().strip()
+            if not bytecode.startswith("0x"):
+                bytecode = "0x" + bytecode
+            return bytecode
 
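
The two helpers above resolve contract artifacts relative to the installed package, which is why the wheel now ships PriceHistoryInference.abi and PriceHistoryInference.bin. A standalone sketch of the same resolution logic:

    import json
    from pathlib import Path

    import opengradient.client as og_client

    pkg_dir = Path(og_client.__file__).parent  # install location of client.py
    abi = json.loads((pkg_dir / "abi" / "PriceHistoryInference.abi").read_text())
    bytecode = (pkg_dir / "bin" / "PriceHistoryInference.bin").read_text(encoding="utf-8").strip()
    if not bytecode.startswith("0x"):
        bytecode = "0x" + bytecode  # web3 expects 0x-prefixed bytecode
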
     def new_workflow(
         self,
         model_cid: str,
-        input_query:
+        input_query: HistoricalInputQuery,
         input_tensor_name: str,
         scheduler_params: Optional[SchedulerParams] = None,
     ) -> str:
         """
         Deploy a new workflow contract with the specified parameters.
-        """
-        if isinstance(input_query, dict):
-            input_query = HistoricalInputQuery.from_dict(input_query)
 
+        This function deploys a new workflow contract and optionally registers it with
+        the scheduler for automated execution. If scheduler_params is not provided,
+        the workflow will be deployed without automated execution scheduling.
+
+        Args:
+            model_cid (str): IPFS CID of the model to be executed
+            input_query (HistoricalInputQuery): Query parameters for data input
+            input_tensor_name (str): Name of the input tensor expected by the model
+            scheduler_params (Optional[SchedulerParams]): Scheduler configuration for automated execution:
+                - frequency: Execution frequency in seconds
+                - duration_hours: How long to run in hours
+
+        Returns:
+            str: Deployed contract address. If scheduler_params was provided, the workflow
+                will be automatically executed according to the specified schedule.
+
+        Raises:
+            Exception: If transaction fails or gas estimation fails
+        """
         # Get contract ABI and bytecode
-        abi = self.
-
+        abi = self._get_abi("PriceHistoryInference.abi")
+        bytecode = self._get_bin("PriceHistoryInference.bin")
 
-
-
+        def deploy_transaction():
+            contract = self._blockchain.eth.contract(abi=abi, bytecode=bytecode)
+            query_tuple = input_query.to_abi_format()
+            constructor_args = [model_cid, input_tensor_name, query_tuple]
 
-
+            try:
+                # Estimate gas needed
+                estimated_gas = contract.constructor(*constructor_args).estimate_gas({"from": self._wallet_account.address})
+                gas_limit = int(estimated_gas * 1.2)
+            except Exception as e:
+                print(f"⚠️ Gas estimation failed: {str(e)}")
+                gas_limit = 5000000  # Conservative fallback
+                print(f"📊 Using fallback gas limit: {gas_limit}")
 
-
-
+            transaction = contract.constructor(*constructor_args).build_transaction(
+                {
+                    "from": self._wallet_account.address,
+                    "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
+                    "gas": gas_limit,
+                    "gasPrice": self._blockchain.eth.gas_price,
+                    "chainId": self._blockchain.eth.chain_id,
+                }
+            )
 
-
-
-            {
-                "from": self._wallet_account.address,
-                "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
-                "gas": 15000000,
-                "gasPrice": self._blockchain.eth.gas_price,
-                "chainId": self._blockchain.eth.chain_id,
-            }
-        )
+            signed_txn = self._wallet_account.sign_transaction(transaction)
+            tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction)
 
-
-
-
-
+            tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=60)
+
+            if tx_receipt["status"] == 0:
+                raise Exception(f"❌ Contract deployment failed, transaction hash: {tx_hash.hex()}")
+
+            return tx_receipt.contractAddress
 
-
+        contract_address = run_with_retry(deploy_transaction)
 
-        # Register with scheduler if params provided
         if scheduler_params:
-
-            print(f" • Frequency: Every {scheduler_params.frequency} seconds")
-            print(f" • Duration: {scheduler_params.duration_hours} hours")
-            print(f" • End Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(scheduler_params.end_time))}")
+            self._register_with_scheduler(contract_address, scheduler_params)
 
-
-        {
-            "inputs": [
-                {"internalType": "address", "name": "contractAddress", "type": "address"},
-                {"internalType": "uint256", "name": "endTime", "type": "uint256"},
-                {"internalType": "uint256", "name": "frequency", "type": "uint256"},
-            ],
-            "name": "registerTask",
-            "outputs": [],
-            "stateMutability": "nonpayable",
-            "type": "function",
-        }
-        ]
+        return contract_address
 
-
-
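
Deployment sketch for the reworked new_workflow, which now requires a HistoricalInputQuery rather than accepting a dict. The query fields and SchedulerParams arguments below are hypothetical, since types.py is not shown here:

    from opengradient.types import HistoricalInputQuery, SchedulerParams

    query = HistoricalInputQuery(...)  # constructor fields not shown in this diff
    address = client.new_workflow(
        model_cid="Qm...",  # placeholder model CID
        input_query=query,
        input_tensor_name="price_history",  # hypothetical tensor name
        scheduler_params=SchedulerParams(frequency=3600, duration_hours=24),  # assumed constructor
    )
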
+    def _register_with_scheduler(self, contract_address: str, scheduler_params: SchedulerParams) -> None:
+        """
+        Register the deployed workflow contract with the scheduler for automated execution.
 
-
-
-
-
-
-
-                "from": self._wallet_account.address,
-                "gas": 300000,
-                "gasPrice": self._blockchain.eth.gas_price,
-                "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
-                "chainId": self._blockchain.eth.chain_id,
-            }
-        )
+        Args:
+            contract_address (str): Address of the deployed workflow contract
+            scheduler_params (SchedulerParams): Scheduler configuration containing:
+                - frequency: Execution frequency in seconds
+                - duration_hours: How long to run in hours
+                - end_time: Unix timestamp when scheduling should end
 
-
-
-
+        Raises:
+            Exception: If registration with scheduler fails. The workflow contract will
+                still be deployed and can be executed manually.
+        """
 
-
-        print(f" Transaction hash: {scheduler_tx_hash.hex()}")
+        scheduler_abi = self._get_abi("WorkflowScheduler.abi")
 
-
-
-
-        print(" The workflow contract is still deployed and can be executed manually.")
+        # Scheduler contract address
+        scheduler_address = DEFAULT_SCHEDULER_ADDRESS
+        scheduler_contract = self._blockchain.eth.contract(address=scheduler_address, abi=scheduler_abi)
 
-
+        try:
+            # Register the workflow with the scheduler
+            scheduler_tx = scheduler_contract.functions.registerTask(
+                contract_address, scheduler_params.end_time, scheduler_params.frequency
+            ).build_transaction(
+                {
+                    "from": self._wallet_account.address,
+                    "gas": 300000,
+                    "gasPrice": self._blockchain.eth.gas_price,
+                    "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
+                    "chainId": self._blockchain.eth.chain_id,
+                }
+            )
+
+            signed_scheduler_tx = self._wallet_account.sign_transaction(scheduler_tx)
+            scheduler_tx_hash = self._blockchain.eth.send_raw_transaction(signed_scheduler_tx.raw_transaction)
+            self._blockchain.eth.wait_for_transaction_receipt(scheduler_tx_hash, timeout=REGULAR_TX_TIMEOUT)
+        except Exception as e:
+            print(f"❌ Error registering contract with scheduler: {str(e)}")
+            print("   The workflow contract is still deployed and can be executed manually.")
 
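
The scheduler protocol itself is one contract call, registerTask(contractAddress, endTime, frequency), per the ABI fragment removed above. An equivalent manual registration sketch using a raw web3 handle; scheduler_abi, workflow_address, end_time, frequency, and account are placeholders:

    from web3 import Web3

    w3 = Web3(Web3.HTTPProvider("http://18.188.176.119:8545"))  # DEFAULT_RPC_URL
    scheduler = w3.eth.contract(
        address="0x7179724De4e7FF9271FA40C0337c7f90C0508eF6",  # DEFAULT_SCHEDULER_ADDRESS
        abi=scheduler_abi,  # parsed contents of opengradient/abi/WorkflowScheduler.abi
    )
    tx = scheduler.functions.registerTask(workflow_address, end_time, frequency).build_transaction({
        "from": account.address,
        "nonce": w3.eth.get_transaction_count(account.address),
    })
    signed = account.sign_transaction(tx)
    tx_hash = w3.eth.send_raw_transaction(signed.raw_transaction)
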
@@ -867,7 +903,9 @@ class Client:
             Web3Error: If there are issues with the web3 connection or contract interaction
         """
         # Get the contract interface
-        contract = self._blockchain.eth.contract(
+        contract = self._blockchain.eth.contract(
+            address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
+        )
 
         # Get the result
         result = contract.functions.getInferenceResult().call()
@@ -889,7 +927,9 @@ class Client:
             Web3Error: If there are issues with the web3 connection or contract interaction
         """
         # Get the contract interface
-        contract = self._blockchain.eth.contract(
+        contract = self._blockchain.eth.contract(
+            address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
+        )
 
         # Call run() function
         nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending")
@@ -917,6 +957,31 @@ class Client:
 
         return utils.convert_array_to_model_output(result)
 
+    def read_workflow_history(self, contract_address: str, num_results: int) -> List[Dict]:
+        """
+        Gets historical inference results from a workflow contract.
+
+        Retrieves the specified number of most recent inference results from the contract's
+        storage, with the most recent result first.
+
+        Args:
+            contract_address (str): Address of the deployed workflow contract
+            num_results (int): Number of historical results to retrieve
+
+        Returns:
+            List[Dict]: List of historical inference results, each containing:
+                - prediction values
+                - timestamps
+                - any additional metadata stored with the result
+        """
+        contract = self._blockchain.eth.contract(
+            address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
+        )
+
+        results = contract.functions.getLastInferenceResults(num_results).call()
+        return [utils.convert_array_to_model_output(result) for result in results]
+
 
 def run_with_retry(txn_function, max_retries=DEFAULT_MAX_RETRY, retry_delay=DEFAULT_RETRY_DELAY_SEC):
     """
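
Reading results back from a deployed workflow, per the two methods above; address is assumed to hold the contract address returned by new_workflow:

    latest = client.read_workflow_result(address)       # most recent result as a ModelOutput
    history = client.read_workflow_history(address, 5)  # five most recent results, newest first
    for entry in history:
        print(entry)
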
opengradient/defaults.py
CHANGED
@@ -3,6 +3,7 @@ DEFAULT_RPC_URL = "http://18.188.176.119:8545"
 DEFAULT_OG_FAUCET_URL = "https://faucet.opengradient.ai/?address="
 DEFAULT_HUB_SIGNUP_URL = "https://hub.opengradient.ai/signup"
 DEFAULT_INFERENCE_CONTRACT_ADDRESS = "0x8383C9bD7462F12Eb996DD02F78234C0421A6FaE"
+DEFAULT_SCHEDULER_ADDRESS = "0x7179724De4e7FF9271FA40C0337c7f90C0508eF6"
 DEFAULT_BLOCKCHAIN_EXPLORER = "https://explorer.opengradient.ai/tx/"
 DEFAULT_IMAGE_GEN_HOST = "18.217.25.69"
 DEFAULT_IMAGE_GEN_PORT = 5125