opengradient 0.4.6__py3-none-any.whl → 0.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
opengradient/client.py CHANGED
@@ -2,12 +2,10 @@ import json
2
2
  import logging
3
3
  import os
4
4
  import time
5
- import uuid
6
5
  from pathlib import Path
7
- from typing import Any, Dict, List, Optional, Tuple, Union
6
+ from typing import Any, Dict, List, Optional, Union
8
7
 
9
8
  import firebase
10
- import grpc
11
9
  import numpy as np
12
10
  import requests
13
11
  from eth_account.account import LocalAccount
@@ -18,8 +16,18 @@ from web3.logs import DISCARD
18
16
  from . import utils
19
17
  from .exceptions import OpenGradientError
20
18
  from .proto import infer_pb2, infer_pb2_grpc
21
- from .types import LLM, TEE_LLM, HistoricalInputQuery, InferenceMode, LlmInferenceMode, ModelOutput, TextGenerationOutput, SchedulerParams
22
- from .defaults import DEFAULT_IMAGE_GEN_HOST, DEFAULT_IMAGE_GEN_PORT
19
+ from .types import (
20
+ LLM,
21
+ TEE_LLM,
22
+ HistoricalInputQuery,
23
+ InferenceMode,
24
+ LlmInferenceMode,
25
+ ModelOutput,
26
+ TextGenerationOutput,
27
+ SchedulerParams,
28
+ InferenceResult,
29
+ )
30
+ from .defaults import DEFAULT_IMAGE_GEN_HOST, DEFAULT_IMAGE_GEN_PORT, DEFAULT_SCHEDULER_ADDRESS
23
31
 
24
32
  _FIREBASE_CONFIG = {
25
33
  "apiKey": "AIzaSyDUVckVtfl-hiteBzPopy1pDD8Uvfncs7w",
@@ -48,7 +56,7 @@ class Client:
48
56
  _hub_user: Dict
49
57
  _inference_abi: Dict
50
58
 
51
- def __init__(self, private_key: str, rpc_url: str, contract_address: str, email: str, password: str):
59
+ def __init__(self, private_key: str, rpc_url: str, contract_address: str, email: Optional[str], password: Optional[str]):
52
60
  """
53
61
  Initialize the Client with private key, RPC URL, and contract address.
54
62
 
@@ -137,7 +145,7 @@ class Client:
137
145
  logging.error(f"Unexpected error during model creation: {str(e)}")
138
146
  raise
139
147
 
140
- def create_version(self, model_name: str, notes: str = None, is_major: bool = False) -> dict:
148
+ def create_version(self, model_name: str, notes: str = "", is_major: bool = False) -> dict:
141
149
  """
142
150
  Create a new version for the specified model.
143
151
 
@@ -273,9 +281,7 @@ class Client:
273
281
  if hasattr(e, "response") and e.response is not None:
274
282
  logging.error(f"Response status code: {e.response.status_code}")
275
283
  logging.error(f"Response content: {e.response.text[:1000]}...") # Log first 1000 characters
276
- raise OpenGradientError(
277
- f"Upload failed due to request exception: {str(e)}", status_code=e.response.status_code if hasattr(e, "response") else None
278
- )
284
+ raise OpenGradientError(f"Upload failed due to request exception: {str(e)}")
279
285
  except Exception as e:
280
286
  logging.error(f"Unexpected error during upload: {str(e)}", exc_info=True)
281
287
  raise OpenGradientError(f"Unexpected error during upload: {str(e)}")
@@ -286,7 +292,7 @@ class Client:
286
292
  inference_mode: InferenceMode,
287
293
  model_input: Dict[str, Union[str, int, float, List, np.ndarray]],
288
294
  max_retries: Optional[int] = None,
289
- ) -> Tuple[str, Dict[str, np.ndarray]]:
295
+ ) -> InferenceResult:
290
296
  """
291
297
  Perform inference on a model.
292
298
 
@@ -297,7 +303,7 @@ class Client:
297
303
  max_retries (int, optional): Maximum number of retry attempts. Defaults to 5.
298
304
 
299
305
  Returns:
300
- Tuple[str, Dict[str, np.ndarray]]: The transaction hash and the model output.
306
+ InferenceResult: The transaction hash and the model output.
301
307
 
302
308
  Raises:
303
309
  OpenGradientError: If the inference fails.
@@ -306,7 +312,7 @@ class Client:
306
312
  def execute_transaction():
307
313
  contract = self._blockchain.eth.contract(address=self._inference_hub_contract_address, abi=self._inference_abi)
308
314
 
309
- inference_mode_uint8 = int(inference_mode)
315
+ inference_mode_uint8 = inference_mode.value
310
316
  converted_model_input = utils.convert_to_model_input(model_input)
311
317
 
312
318
  run_function = contract.functions.run(model_cid, inference_mode_uint8, converted_model_input)
@@ -337,14 +343,15 @@ class Client:
337
343
 
338
344
  # TODO: This should return a ModelOutput class object
339
345
  model_output = utils.convert_to_model_output(parsed_logs[0]["args"])
340
- return tx_hash.hex(), model_output
346
+
347
+ return InferenceResult(tx_hash.hex(), model_output)
341
348
 
342
349
  return run_with_retry(execute_transaction, max_retries)
343
350
 
344
351
  def llm_completion(
345
352
  self,
346
353
  model_cid: LLM,
347
- inference_mode: InferenceMode,
354
+ inference_mode: LlmInferenceMode,
348
355
  prompt: str,
349
356
  max_tokens: int = 100,
350
357
  stop_sequence: Optional[List[str]] = None,
@@ -383,7 +390,7 @@ class Client:
383
390
 
384
391
  # Prepare LLM input
385
392
  llm_request = {
386
- "mode": inference_mode,
393
+ "mode": inference_mode.value,
387
394
  "modelCID": model_cid,
388
395
  "prompt": prompt,
389
396
  "max_tokens": max_tokens,
@@ -420,18 +427,15 @@ class Client:
420
427
  raise OpenGradientError("LLM completion result event not found in transaction logs")
421
428
 
422
429
  llm_answer = parsed_logs[0]["args"]["response"]["answer"]
423
-
424
- return TextGenerationOutput(
425
- transaction_hash=tx_hash.hex(),
426
- completion_output=llm_answer
427
- )
430
+
431
+ return TextGenerationOutput(transaction_hash=tx_hash.hex(), completion_output=llm_answer)
428
432
 
429
433
  return run_with_retry(execute_transaction, max_retries)
430
434
 
431
435
  def llm_chat(
432
436
  self,
433
437
  model_cid: LLM,
434
- inference_mode: InferenceMode,
438
+ inference_mode: LlmInferenceMode,
435
439
  messages: List[Dict],
436
440
  max_tokens: int = 100,
437
441
  stop_sequence: Optional[List[str]] = None,
@@ -536,7 +540,7 @@ class Client:
536
540
 
537
541
  # Prepare LLM input
538
542
  llm_request = {
539
- "mode": inference_mode,
543
+ "mode": inference_mode.value,
540
544
  "modelCID": model_cid,
541
545
  "messages": messages,
542
546
  "max_tokens": max_tokens,
@@ -624,233 +628,265 @@ class Client:
624
628
  if hasattr(e, "response") and e.response is not None:
625
629
  logging.error(f"Response status code: {e.response.status_code}")
626
630
  logging.error(f"Response content: {e.response.text[:1000]}...") # Log first 1000 characters
627
- raise OpenGradientError(
628
- f"File listing failed: {str(e)}", status_code=e.response.status_code if hasattr(e, "response") else None
629
- )
631
+ raise OpenGradientError(f"File listing failed: {str(e)}")
630
632
  except Exception as e:
631
633
  logging.error(f"Unexpected error during file listing: {str(e)}", exc_info=True)
632
634
  raise OpenGradientError(f"Unexpected error during file listing: {str(e)}")
633
635
 
634
- def generate_image(
635
- self,
636
- model_cid: str,
637
- prompt: str,
638
- host: str = DEFAULT_IMAGE_GEN_HOST,
639
- port: int = DEFAULT_IMAGE_GEN_PORT,
640
- width: int = 1024,
641
- height: int = 1024,
642
- timeout: int = 300, # 5 minute timeout
643
- max_retries: int = 3,
644
- ) -> bytes:
636
+ # def generate_image(
637
+ # self,
638
+ # model_cid: str,
639
+ # prompt: str,
640
+ # host: str = DEFAULT_IMAGE_GEN_HOST,
641
+ # port: int = DEFAULT_IMAGE_GEN_PORT,
642
+ # width: int = 1024,
643
+ # height: int = 1024,
644
+ # timeout: int = 300, # 5 minute timeout
645
+ # max_retries: int = 3,
646
+ # ) -> bytes:
647
+ # """
648
+ # Generate an image using a diffusion model through gRPC.
649
+
650
+ # Args:
651
+ # model_cid (str): The model identifier (e.g. "stabilityai/stable-diffusion-xl-base-1.0")
652
+ # prompt (str): The text prompt to generate the image from
653
+ # host (str, optional): gRPC host address. Defaults to DEFAULT_IMAGE_GEN_HOST.
654
+ # port (int, optional): gRPC port number. Defaults to DEFAULT_IMAGE_GEN_PORT.
655
+ # width (int, optional): Output image width. Defaults to 1024.
656
+ # height (int, optional): Output image height. Defaults to 1024.
657
+ # timeout (int, optional): Maximum time to wait for generation in seconds. Defaults to 300.
658
+ # max_retries (int, optional): Maximum number of retry attempts. Defaults to 3.
659
+
660
+ # Returns:
661
+ # bytes: The raw image data bytes
662
+
663
+ # Raises:
664
+ # OpenGradientError: If the image generation fails
665
+ # TimeoutError: If the generation exceeds the timeout period
666
+ # """
667
+
668
+ # def exponential_backoff(attempt: int, max_delay: float = 30.0) -> None:
669
+ # """Calculate and sleep for exponential backoff duration"""
670
+ # delay = min(0.1 * (2**attempt), max_delay)
671
+ # time.sleep(delay)
672
+
673
+ # channel = None
674
+ # start_time = time.time()
675
+ # retry_count = 0
676
+
677
+ # try:
678
+ # while retry_count < max_retries:
679
+ # try:
680
+ # # Initialize gRPC channel and stub
681
+ # channel = grpc.insecure_channel(f"{host}:{port}")
682
+ # stub = infer_pb2_grpc.InferenceServiceStub(channel)
683
+
684
+ # # Create image generation request
685
+ # image_request = infer_pb2.ImageGenerationRequest(model=model_cid, prompt=prompt, height=height, width=width)
686
+
687
+ # # Create inference request with random transaction ID
688
+ # tx_id = str(uuid.uuid4())
689
+ # request = infer_pb2.InferenceRequest(tx=tx_id, image_generation=image_request)
690
+
691
+ # # Send request with timeout
692
+ # response_id = stub.RunInferenceAsync(
693
+ # request,
694
+ # timeout=min(30, timeout), # Initial request timeout
695
+ # )
696
+
697
+ # # Poll for completion
698
+ # attempt = 0
699
+ # while True:
700
+ # # Check timeout
701
+ # if time.time() - start_time > timeout:
702
+ # raise TimeoutError(f"Image generation timed out after {timeout} seconds")
703
+
704
+ # status_request = infer_pb2.InferenceTxId(id=response_id.id)
705
+ # try:
706
+ # status = stub.GetInferenceStatus(
707
+ # status_request,
708
+ # timeout=min(5, timeout), # Status check timeout
709
+ # ).status
710
+ # except grpc.RpcError as e:
711
+ # logging.warning(f"Status check failed (attempt {attempt}): {str(e)}")
712
+ # exponential_backoff(attempt)
713
+ # attempt += 1
714
+ # continue
715
+
716
+ # if status == infer_pb2.InferenceStatus.STATUS_COMPLETED:
717
+ # break
718
+ # elif status == infer_pb2.InferenceStatus.STATUS_ERROR:
719
+ # raise OpenGradientError("Image generation failed on server")
720
+ # elif status != infer_pb2.InferenceStatus.STATUS_IN_PROGRESS:
721
+ # raise OpenGradientError(f"Unexpected status: {status}")
722
+
723
+ # exponential_backoff(attempt)
724
+ # attempt += 1
725
+
726
+ # # Get result
727
+ # result = stub.GetInferenceResult(
728
+ # response_id,
729
+ # timeout=min(30, timeout), # Result fetch timeout
730
+ # )
731
+ # return result.image_generation_result.image_data
732
+
733
+ # except (grpc.RpcError, TimeoutError) as e:
734
+ # retry_count += 1
735
+ # if retry_count >= max_retries:
736
+ # raise OpenGradientError(f"Image generation failed after {max_retries} retries: {str(e)}")
737
+
738
+ # logging.warning(f"Attempt {retry_count} failed: {str(e)}. Retrying...")
739
+ # exponential_backoff(retry_count)
740
+
741
+ # except grpc.RpcError as e:
742
+ # logging.error(f"gRPC error: {str(e)}")
743
+ # raise OpenGradientError(f"Image generation failed: {str(e)}")
744
+ # except TimeoutError as e:
745
+ # logging.error(f"Timeout error: {str(e)}")
746
+ # raise
747
+ # except Exception as e:
748
+ # logging.error(f"Error in generate image method: {str(e)}", exc_info=True)
749
+ # raise OpenGradientError(f"Image generation failed: {str(e)}")
750
+ # finally:
751
+ # if channel:
752
+ # channel.close()
753
+
754
+ def _get_abi(self, abi_name) -> List[Dict]:
645
755
  """
646
- Generate an image using a diffusion model through gRPC.
647
-
648
- Args:
649
- model_cid (str): The model identifier (e.g. "stabilityai/stable-diffusion-xl-base-1.0")
650
- prompt (str): The text prompt to generate the image from
651
- host (str, optional): gRPC host address. Defaults to DEFAULT_IMAGE_GEN_HOST.
652
- port (int, optional): gRPC port number. Defaults to DEFAULT_IMAGE_GEN_PORT.
653
- width (int, optional): Output image width. Defaults to 1024.
654
- height (int, optional): Output image height. Defaults to 1024.
655
- timeout (int, optional): Maximum time to wait for generation in seconds. Defaults to 300.
656
- max_retries (int, optional): Maximum number of retry attempts. Defaults to 3.
657
-
658
- Returns:
659
- bytes: The raw image data bytes
660
-
661
- Raises:
662
- OpenGradientError: If the image generation fails
663
- TimeoutError: If the generation exceeds the timeout period
756
+ Returns the ABI for the requested contract.
664
757
  """
758
+ abi_path = Path(__file__).parent / "abi" / abi_name
759
+ with open(abi_path, "r") as f:
760
+ return json.load(f)
665
761
 
666
- def exponential_backoff(attempt: int, max_delay: float = 30.0) -> None:
667
- """Calculate and sleep for exponential backoff duration"""
668
- delay = min(0.1 * (2**attempt), max_delay)
669
- time.sleep(delay)
670
-
671
- channel = None
672
- start_time = time.time()
673
- retry_count = 0
674
-
675
- try:
676
- while retry_count < max_retries:
677
- try:
678
- # Initialize gRPC channel and stub
679
- channel = grpc.insecure_channel(f"{host}:{port}")
680
- stub = infer_pb2_grpc.InferenceServiceStub(channel)
681
-
682
- # Create image generation request
683
- image_request = infer_pb2.ImageGenerationRequest(model=model_cid, prompt=prompt, height=height, width=width)
684
-
685
- # Create inference request with random transaction ID
686
- tx_id = str(uuid.uuid4())
687
- request = infer_pb2.InferenceRequest(tx=tx_id, image_generation=image_request)
688
-
689
- # Send request with timeout
690
- response_id = stub.RunInferenceAsync(
691
- request,
692
- timeout=min(30, timeout), # Initial request timeout
693
- )
694
-
695
- # Poll for completion
696
- attempt = 0
697
- while True:
698
- # Check timeout
699
- if time.time() - start_time > timeout:
700
- raise TimeoutError(f"Image generation timed out after {timeout} seconds")
701
-
702
- status_request = infer_pb2.InferenceTxId(id=response_id.id)
703
- try:
704
- status = stub.GetInferenceStatus(
705
- status_request,
706
- timeout=min(5, timeout), # Status check timeout
707
- ).status
708
- except grpc.RpcError as e:
709
- logging.warning(f"Status check failed (attempt {attempt}): {str(e)}")
710
- exponential_backoff(attempt)
711
- attempt += 1
712
- continue
713
-
714
- if status == infer_pb2.InferenceStatus.STATUS_COMPLETED:
715
- break
716
- elif status == infer_pb2.InferenceStatus.STATUS_ERROR:
717
- raise OpenGradientError("Image generation failed on server")
718
- elif status != infer_pb2.InferenceStatus.STATUS_IN_PROGRESS:
719
- raise OpenGradientError(f"Unexpected status: {status}")
720
-
721
- exponential_backoff(attempt)
722
- attempt += 1
723
-
724
- # Get result
725
- result = stub.GetInferenceResult(
726
- response_id,
727
- timeout=min(30, timeout), # Result fetch timeout
728
- )
729
- return result.image_generation_result.image_data
730
-
731
- except (grpc.RpcError, TimeoutError) as e:
732
- retry_count += 1
733
- if retry_count >= max_retries:
734
- raise OpenGradientError(f"Image generation failed after {max_retries} retries: {str(e)}")
735
-
736
- logging.warning(f"Attempt {retry_count} failed: {str(e)}. Retrying...")
737
- exponential_backoff(retry_count)
738
-
739
- except grpc.RpcError as e:
740
- logging.error(f"gRPC error: {str(e)}")
741
- raise OpenGradientError(f"Image generation failed: {str(e)}")
742
- except TimeoutError as e:
743
- logging.error(f"Timeout error: {str(e)}")
744
- raise
745
- except Exception as e:
746
- logging.error(f"Error in generate image method: {str(e)}", exc_info=True)
747
- raise OpenGradientError(f"Image generation failed: {str(e)}")
748
- finally:
749
- if channel:
750
- channel.close()
751
-
752
- def _get_model_executor_abi(self) -> List[Dict]:
762
+ def _get_bin(self, bin_name) -> List[Dict]:
753
763
  """
754
- Returns the ABI for the ModelExecutorHistorical contract.
764
+ Returns the bin for the requested contract.
755
765
  """
756
- abi_path = Path(__file__).parent / "abi" / "ModelExecutorHistorical.abi"
757
- with open(abi_path, "r") as f:
758
- return json.load(f)
766
+ bin_path = Path(__file__).parent / "bin" / bin_name
767
+ # Read bytecode with explicit encoding
768
+ with open(bin_path, "r", encoding="utf-8") as f:
769
+ bytecode = f.read().strip()
770
+ if not bytecode.startswith("0x"):
771
+ bytecode = "0x" + bytecode
772
+ return bytecode
759
773
 
760
774
  def new_workflow(
761
775
  self,
762
776
  model_cid: str,
763
- input_query: Union[Dict[str, Any], HistoricalInputQuery],
777
+ input_query: HistoricalInputQuery,
764
778
  input_tensor_name: str,
765
779
  scheduler_params: Optional[SchedulerParams] = None,
766
780
  ) -> str:
767
781
  """
768
782
  Deploy a new workflow contract with the specified parameters.
769
- """
770
- if isinstance(input_query, dict):
771
- input_query = HistoricalInputQuery.from_dict(input_query)
772
783
 
784
+ This function deploys a new workflow contract and optionally registers it with
785
+ the scheduler for automated execution. If scheduler_params is not provided,
786
+ the workflow will be deployed without automated execution scheduling.
787
+
788
+ Args:
789
+ model_cid (str): IPFS CID of the model to be executed
790
+ input_query (HistoricalInputQuery): Query parameters for data input
791
+ input_tensor_name (str): Name of the input tensor expected by the model
792
+ scheduler_params (Optional[SchedulerParams]): Scheduler configuration for automated execution:
793
+ - frequency: Execution frequency in seconds
794
+ - duration_hours: How long to run in hours
795
+
796
+ Returns:
797
+ str: Deployed contract address. If scheduler_params was provided, the workflow
798
+ will be automatically executed according to the specified schedule.
799
+
800
+ Raises:
801
+ Exception: If transaction fails or gas estimation fails
802
+ """
773
803
  # Get contract ABI and bytecode
774
- abi = self._get_model_executor_abi()
775
- bin_path = Path(__file__).parent / "bin" / "ModelExecutorHistorical.bin"
804
+ abi = self._get_abi("PriceHistoryInference.abi")
805
+ bytecode = self._get_bin("PriceHistoryInference.bin")
776
806
 
777
- with open(bin_path, "r") as f:
778
- bytecode = f.read().strip()
807
+ def deploy_transaction():
808
+ contract = self._blockchain.eth.contract(abi=abi, bytecode=bytecode)
809
+ query_tuple = input_query.to_abi_format()
810
+ constructor_args = [model_cid, input_tensor_name, query_tuple]
779
811
 
780
- print("📦 Deploying workflow contract...")
812
+ try:
813
+ # Estimate gas needed
814
+ estimated_gas = contract.constructor(*constructor_args).estimate_gas({"from": self._wallet_account.address})
815
+ gas_limit = int(estimated_gas * 1.2)
816
+ except Exception as e:
817
+ print(f"⚠️ Gas estimation failed: {str(e)}")
818
+ gas_limit = 5000000 # Conservative fallback
819
+ print(f"📊 Using fallback gas limit: {gas_limit}")
781
820
 
782
- # Create contract instance
783
- contract = self._blockchain.eth.contract(abi=abi, bytecode=bytecode)
821
+ transaction = contract.constructor(*constructor_args).build_transaction(
822
+ {
823
+ "from": self._wallet_account.address,
824
+ "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
825
+ "gas": gas_limit,
826
+ "gasPrice": self._blockchain.eth.gas_price,
827
+ "chainId": self._blockchain.eth.chain_id,
828
+ }
829
+ )
784
830
 
785
- # Deploy contract with constructor arguments
786
- transaction = contract.constructor().build_transaction(
787
- {
788
- "from": self._wallet_account.address,
789
- "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
790
- "gas": 15000000,
791
- "gasPrice": self._blockchain.eth.gas_price,
792
- "chainId": self._blockchain.eth.chain_id,
793
- }
794
- )
831
+ signed_txn = self._wallet_account.sign_transaction(transaction)
832
+ tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction)
795
833
 
796
- signed_txn = self._wallet_account.sign_transaction(transaction)
797
- tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction)
798
- tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=REGULAR_TX_TIMEOUT)
799
- contract_address = tx_receipt.contractAddress
834
+ tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=60)
835
+
836
+ if tx_receipt["status"] == 0:
837
+ raise Exception(f"❌ Contract deployment failed, transaction hash: {tx_hash.hex()}")
838
+
839
+ return tx_receipt.contractAddress
800
840
 
801
- print(f"✅ Workflow contract deployed at: {contract_address}")
841
+ contract_address = run_with_retry(deploy_transaction)
802
842
 
803
- # Register with scheduler if params provided
804
843
  if scheduler_params:
805
- print("\n⏰ Setting up automated execution schedule...")
806
- print(f" • Frequency: Every {scheduler_params.frequency} seconds")
807
- print(f" • Duration: {scheduler_params.duration_hours} hours")
808
- print(f" • End Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(scheduler_params.end_time))}")
844
+ self._register_with_scheduler(contract_address, scheduler_params)
809
845
 
810
- scheduler_abi = [
811
- {
812
- "inputs": [
813
- {"internalType": "address", "name": "contractAddress", "type": "address"},
814
- {"internalType": "uint256", "name": "endTime", "type": "uint256"},
815
- {"internalType": "uint256", "name": "frequency", "type": "uint256"},
816
- ],
817
- "name": "registerTask",
818
- "outputs": [],
819
- "stateMutability": "nonpayable",
820
- "type": "function",
821
- }
822
- ]
846
+ return contract_address
823
847
 
824
- scheduler_address = "0x6F937b9f4Fa7723932827cd73063B70Be2b56748"
825
- scheduler_contract = self._blockchain.eth.contract(address=scheduler_address, abi=scheduler_abi)
848
+ def _register_with_scheduler(self, contract_address: str, scheduler_params: SchedulerParams) -> None:
849
+ """
850
+ Register the deployed workflow contract with the scheduler for automated execution.
826
851
 
827
- try:
828
- # Register the workflow with the scheduler
829
- scheduler_tx = scheduler_contract.functions.registerTask(
830
- contract_address, scheduler_params.end_time, scheduler_params.frequency
831
- ).build_transaction(
832
- {
833
- "from": self._wallet_account.address,
834
- "gas": 300000,
835
- "gasPrice": self._blockchain.eth.gas_price,
836
- "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
837
- "chainId": self._blockchain.eth.chain_id,
838
- }
839
- )
852
+ Args:
853
+ contract_address (str): Address of the deployed workflow contract
854
+ scheduler_params (SchedulerParams): Scheduler configuration containing:
855
+ - frequency: Execution frequency in seconds
856
+ - duration_hours: How long to run in hours
857
+ - end_time: Unix timestamp when scheduling should end
840
858
 
841
- signed_scheduler_tx = self._wallet_account.sign_transaction(scheduler_tx)
842
- scheduler_tx_hash = self._blockchain.eth.send_raw_transaction(signed_scheduler_tx.raw_transaction)
843
- self._blockchain.eth.wait_for_transaction_receipt(scheduler_tx_hash, timeout=REGULAR_TX_TIMEOUT)
859
+ Raises:
860
+ Exception: If registration with scheduler fails. The workflow contract will
861
+ still be deployed and can be executed manually.
862
+ """
844
863
 
845
- print("✅ Automated execution schedule set successfully!")
846
- print(f" Transaction hash: {scheduler_tx_hash.hex()}")
864
+ scheduler_abi = self._get_abi("WorkflowScheduler.abi")
847
865
 
848
- except Exception as e:
849
- print("❌ Failed to set up automated execution schedule")
850
- print(f" Error: {str(e)}")
851
- print(" The workflow contract is still deployed and can be executed manually.")
866
+ # Scheduler contract address
867
+ scheduler_address = DEFAULT_SCHEDULER_ADDRESS
868
+ scheduler_contract = self._blockchain.eth.contract(address=scheduler_address, abi=scheduler_abi)
852
869
 
853
- return contract_address
870
+ try:
871
+ # Register the workflow with the scheduler
872
+ scheduler_tx = scheduler_contract.functions.registerTask(
873
+ contract_address, scheduler_params.end_time, scheduler_params.frequency
874
+ ).build_transaction(
875
+ {
876
+ "from": self._wallet_account.address,
877
+ "gas": 300000,
878
+ "gasPrice": self._blockchain.eth.gas_price,
879
+ "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
880
+ "chainId": self._blockchain.eth.chain_id,
881
+ }
882
+ )
883
+
884
+ signed_scheduler_tx = self._wallet_account.sign_transaction(scheduler_tx)
885
+ scheduler_tx_hash = self._blockchain.eth.send_raw_transaction(signed_scheduler_tx.raw_transaction)
886
+ self._blockchain.eth.wait_for_transaction_receipt(scheduler_tx_hash, timeout=REGULAR_TX_TIMEOUT)
887
+ except Exception as e:
888
+ print(f"❌ Error registering contract with scheduler: {str(e)}")
889
+ print(" The workflow contract is still deployed and can be executed manually.")
854
890
 
855
891
  def read_workflow_result(self, contract_address: str) -> ModelOutput:
856
892
  """
@@ -867,7 +903,9 @@ class Client:
867
903
  Web3Error: If there are issues with the web3 connection or contract interaction
868
904
  """
869
905
  # Get the contract interface
870
- contract = self._blockchain.eth.contract(address=Web3.to_checksum_address(contract_address), abi=self._get_model_executor_abi())
906
+ contract = self._blockchain.eth.contract(
907
+ address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
908
+ )
871
909
 
872
910
  # Get the result
873
911
  result = contract.functions.getInferenceResult().call()
@@ -889,7 +927,9 @@ class Client:
889
927
  Web3Error: If there are issues with the web3 connection or contract interaction
890
928
  """
891
929
  # Get the contract interface
892
- contract = self._blockchain.eth.contract(address=Web3.to_checksum_address(contract_address), abi=self._get_model_executor_abi())
930
+ contract = self._blockchain.eth.contract(
931
+ address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
932
+ )
893
933
 
894
934
  # Call run() function
895
935
  nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending")
@@ -917,6 +957,31 @@ class Client:
917
957
 
918
958
  return utils.convert_array_to_model_output(result)
919
959
 
960
+ def read_workflow_history(self, contract_address: str, num_results: int) -> List[Dict]:
961
+ """
962
+ Gets historical inference results from a workflow contract.
963
+
964
+ Retrieves the specified number of most recent inference results from the contract's
965
+ storage, with the most recent result first.
966
+
967
+ Args:
968
+ contract_address (str): Address of the deployed workflow contract
969
+ num_results (int): Number of historical results to retrieve
970
+
971
+ Returns:
972
+ List[Dict]: List of historical inference results, each containing:
973
+ - prediction values
974
+ - timestamps
975
+ - any additional metadata stored with the result
976
+
977
+ """
978
+ contract = self._blockchain.eth.contract(
979
+ address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
980
+ )
981
+
982
+ results = contract.functions.getLastInferenceResults(num_results).call()
983
+ return [utils.convert_array_to_model_output(result) for result in results]
984
+
920
985
 
921
986
  def run_with_retry(txn_function, max_retries=DEFAULT_MAX_RETRY, retry_delay=DEFAULT_RETRY_DELAY_SEC):
922
987
  """
opengradient/defaults.py CHANGED
@@ -3,6 +3,7 @@ DEFAULT_RPC_URL = "http://18.188.176.119:8545"
3
3
  DEFAULT_OG_FAUCET_URL = "https://faucet.opengradient.ai/?address="
4
4
  DEFAULT_HUB_SIGNUP_URL = "https://hub.opengradient.ai/signup"
5
5
  DEFAULT_INFERENCE_CONTRACT_ADDRESS = "0x8383C9bD7462F12Eb996DD02F78234C0421A6FaE"
6
+ DEFAULT_SCHEDULER_ADDRESS = "0x7179724De4e7FF9271FA40C0337c7f90C0508eF6"
6
7
  DEFAULT_BLOCKCHAIN_EXPLORER = "https://explorer.opengradient.ai/tx/"
7
8
  DEFAULT_IMAGE_GEN_HOST = "18.217.25.69"
8
9
  DEFAULT_IMAGE_GEN_PORT = 5125