opengradient 0.4.5__py3-none-any.whl → 0.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
opengradient/client.py CHANGED
@@ -2,12 +2,10 @@ import json
  import logging
  import os
  import time
- import uuid
  from pathlib import Path
- from typing import Any, Dict, List, Optional, Tuple, Union
+ from typing import Any, Dict, List, Optional, Union

  import firebase
- import grpc
  import numpy as np
  import requests
  from eth_account.account import LocalAccount
@@ -18,8 +16,18 @@ from web3.logs import DISCARD
  from . import utils
  from .exceptions import OpenGradientError
  from .proto import infer_pb2, infer_pb2_grpc
- from .types import LLM, TEE_LLM, HistoricalInputQuery, InferenceMode, LlmInferenceMode, ModelOutput, SchedulerParams
- from .defaults import DEFAULT_IMAGE_GEN_HOST, DEFAULT_IMAGE_GEN_PORT
+ from .types import (
+     LLM,
+     TEE_LLM,
+     HistoricalInputQuery,
+     InferenceMode,
+     LlmInferenceMode,
+     ModelOutput,
+     TextGenerationOutput,
+     SchedulerParams,
+     InferenceResult,
+ )
+ from .defaults import DEFAULT_IMAGE_GEN_HOST, DEFAULT_IMAGE_GEN_PORT, DEFAULT_SCHEDULER_ADDRESS

  _FIREBASE_CONFIG = {
      "apiKey": "AIzaSyDUVckVtfl-hiteBzPopy1pDD8Uvfncs7w",
@@ -48,7 +56,7 @@ class Client:
      _hub_user: Dict
      _inference_abi: Dict

-     def __init__(self, private_key: str, rpc_url: str, contract_address: str, email: str, password: str):
+     def __init__(self, private_key: str, rpc_url: str, contract_address: str, email: Optional[str], password: Optional[str]):
          """
          Initialize the Client with private key, RPC URL, and contract address.

@@ -137,7 +145,7 @@ class Client:
              logging.error(f"Unexpected error during model creation: {str(e)}")
              raise

-     def create_version(self, model_name: str, notes: str = None, is_major: bool = False) -> dict:
+     def create_version(self, model_name: str, notes: str = "", is_major: bool = False) -> dict:
          """
          Create a new version for the specified model.

@@ -273,9 +281,7 @@ class Client:
              if hasattr(e, "response") and e.response is not None:
                  logging.error(f"Response status code: {e.response.status_code}")
                  logging.error(f"Response content: {e.response.text[:1000]}...")  # Log first 1000 characters
-             raise OpenGradientError(
-                 f"Upload failed due to request exception: {str(e)}", status_code=e.response.status_code if hasattr(e, "response") else None
-             )
+             raise OpenGradientError(f"Upload failed due to request exception: {str(e)}")
          except Exception as e:
              logging.error(f"Unexpected error during upload: {str(e)}", exc_info=True)
              raise OpenGradientError(f"Unexpected error during upload: {str(e)}")
@@ -286,7 +292,7 @@ class Client:
          inference_mode: InferenceMode,
          model_input: Dict[str, Union[str, int, float, List, np.ndarray]],
          max_retries: Optional[int] = None,
-     ) -> Tuple[str, Dict[str, np.ndarray]]:
+     ) -> InferenceResult:
          """
          Perform inference on a model.

@@ -297,7 +303,7 @@ class Client:
              max_retries (int, optional): Maximum number of retry attempts. Defaults to 5.

          Returns:
-             Tuple[str, Dict[str, np.ndarray]]: The transaction hash and the model output.
+             InferenceResult: The transaction hash and the model output.

          Raises:
              OpenGradientError: If the inference fails.
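For 0.4.5 callers that unpacked a tuple, a minimal migration sketch (assuming a configured Client instance named `client`; the `transaction_hash` and `model_output` attribute names are assumptions inferred from the positional `InferenceResult(tx_hash.hex(), model_output)` constructor visible later in this diff, so check `opengradient.types` for the actual fields):

# 0.4.5 returned a tuple:
#     tx_hash, model_output = client.infer(model_cid, inference_mode, model_input)
# 0.4.7 returns a single InferenceResult object:
result = client.infer(model_cid, inference_mode, model_input)
print(result.transaction_hash)  # assumed field backing InferenceResult(tx_hash.hex(), ...)
print(result.model_output)      # assumed field backing InferenceResult(..., model_output)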
@@ -306,7 +312,7 @@ class Client:
          def execute_transaction():
              contract = self._blockchain.eth.contract(address=self._inference_hub_contract_address, abi=self._inference_abi)

-             inference_mode_uint8 = int(inference_mode)
+             inference_mode_uint8 = inference_mode.value
              converted_model_input = utils.convert_to_model_input(model_input)

              run_function = contract.functions.run(model_cid, inference_mode_uint8, converted_model_input)
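The switch from `int(inference_mode)` to `inference_mode.value` is consistent with `InferenceMode` no longer being an `IntEnum`: `int()` only works on integer-backed enums, as this standalone sketch with hypothetical stand-in enums (not the library's actual definitions) illustrates:

from enum import Enum, IntEnum

class PlainMode(Enum):    # hypothetical stand-in for a plain-Enum InferenceMode
    VANILLA = 0
    ZKML = 1

class IntMode(IntEnum):   # an IntEnum converts to int implicitly
    VANILLA = 0

assert int(IntMode.VANILLA) == 0   # fine: IntEnum supports int()
assert PlainMode.ZKML.value == 1   # a plain Enum requires .value
# int(PlainMode.ZKML) would raise TypeError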
@@ -337,20 +343,21 @@ class Client:

              # TODO: This should return a ModelOutput class object
              model_output = utils.convert_to_model_output(parsed_logs[0]["args"])
-             return tx_hash.hex(), model_output
+
+             return InferenceResult(tx_hash.hex(), model_output)

          return run_with_retry(execute_transaction, max_retries)

      def llm_completion(
          self,
          model_cid: LLM,
-         inference_mode: InferenceMode,
+         inference_mode: LlmInferenceMode,
          prompt: str,
          max_tokens: int = 100,
          stop_sequence: Optional[List[str]] = None,
          temperature: float = 0.0,
          max_retries: Optional[int] = None,
-     ) -> Tuple[str, str]:
+     ) -> TextGenerationOutput:
          """
          Perform inference on an LLM model using completions.

@@ -363,7 +370,9 @@ class Client:
              temperature (float): Temperature for LLM inference, between 0 and 1. Default is 0.0.

          Returns:
-             Tuple[str, str]: The transaction hash and the LLM completion output.
+             TextGenerationOutput: Generated text results including:
+                 - Transaction hash
+                 - String of completion output

          Raises:
              OpenGradientError: If the inference fails.
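A hedged usage sketch of the new completion return type (the `transaction_hash` and `completion_output` attributes mirror the keyword arguments in the constructor call visible later in this diff; the model CID and enum member are placeholders):

output = client.llm_completion(
    model_cid="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model CID
    inference_mode=LlmInferenceMode.VANILLA,          # assumed member name
    prompt="Say hello",
    max_tokens=50,
)
print(output.transaction_hash)
print(output.completion_output)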
@@ -381,7 +390,7 @@ class Client:

              # Prepare LLM input
              llm_request = {
-                 "mode": inference_mode,
+                 "mode": inference_mode.value,
                  "modelCID": model_cid,
                  "prompt": prompt,
                  "max_tokens": max_tokens,
@@ -418,14 +427,15 @@ class Client:
                  raise OpenGradientError("LLM completion result event not found in transaction logs")

              llm_answer = parsed_logs[0]["args"]["response"]["answer"]
-             return tx_hash.hex(), llm_answer
+
+             return TextGenerationOutput(transaction_hash=tx_hash.hex(), completion_output=llm_answer)

          return run_with_retry(execute_transaction, max_retries)

      def llm_chat(
          self,
          model_cid: LLM,
-         inference_mode: InferenceMode,
+         inference_mode: LlmInferenceMode,
          messages: List[Dict],
          max_tokens: int = 100,
          stop_sequence: Optional[List[str]] = None,
@@ -433,7 +443,7 @@ class Client:
          tools: Optional[List[Dict]] = [],
          tool_choice: Optional[str] = None,
          max_retries: Optional[int] = None,
-     ) -> Tuple[str, str]:
+     ) -> TextGenerationOutput:
          """
          Perform inference on an LLM model using chat.

@@ -485,7 +495,10 @@ class Client:
              tool_choice (str, optional): Sets a specific tool to choose. Default value is "auto".

          Returns:
-             Tuple[str, str, dict]: The transaction hash, finish reason, and a dictionary struct of LLM chat messages.
+             TextGenerationOutput: Generated text results including:
+                 - Transaction hash
+                 - Finish reason (tool_call, stop, etc.)
+                 - Dictionary of chat message output (role, content, tool_call, etc.)

          Raises:
              OpenGradientError: If the inference fails.
@@ -527,7 +540,7 @@ class Client:

              # Prepare LLM input
              llm_request = {
-                 "mode": inference_mode,
+                 "mode": inference_mode.value,
                  "modelCID": model_cid,
                  "messages": messages,
                  "max_tokens": max_tokens,
@@ -570,7 +583,11 @@ class Client:
              if (tool_calls := message.get("tool_calls")) is not None:
                  message["tool_calls"] = [dict(tool_call) for tool_call in tool_calls]

-             return tx_hash.hex(), llm_result["finish_reason"], message
+             return TextGenerationOutput(
+                 transaction_hash=tx_hash.hex(),
+                 finish_reason=llm_result["finish_reason"],
+                 chat_output=message,
+             )

          return run_with_retry(execute_transaction, max_retries)

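The chat path now returns the same `TextGenerationOutput` type; a sketch using the `finish_reason` and `chat_output` field names from the constructor call above (message content, CID, and enum member are placeholders):

output = client.llm_chat(
    model_cid="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model CID
    inference_mode=LlmInferenceMode.VANILLA,          # assumed member name
    messages=[{"role": "user", "content": "What's the weather like?"}],
)
if output.finish_reason == "tool_call":
    for tool_call in output.chat_output["tool_calls"]:
        print(tool_call)  # plain dicts, per the conversion above
else:
    print(output.chat_output["content"])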
@@ -611,233 +628,265 @@ class Client:
              if hasattr(e, "response") and e.response is not None:
                  logging.error(f"Response status code: {e.response.status_code}")
                  logging.error(f"Response content: {e.response.text[:1000]}...")  # Log first 1000 characters
-             raise OpenGradientError(
-                 f"File listing failed: {str(e)}", status_code=e.response.status_code if hasattr(e, "response") else None
-             )
+             raise OpenGradientError(f"File listing failed: {str(e)}")
          except Exception as e:
              logging.error(f"Unexpected error during file listing: {str(e)}", exc_info=True)
              raise OpenGradientError(f"Unexpected error during file listing: {str(e)}")

-     def generate_image(
-         self,
-         model_cid: str,
-         prompt: str,
-         host: str = DEFAULT_IMAGE_GEN_HOST,
-         port: int = DEFAULT_IMAGE_GEN_PORT,
-         width: int = 1024,
-         height: int = 1024,
-         timeout: int = 300,  # 5 minute timeout
-         max_retries: int = 3,
-     ) -> bytes:
+     # def generate_image(
+     #     self,
+     #     model_cid: str,
+     #     prompt: str,
+     #     host: str = DEFAULT_IMAGE_GEN_HOST,
+     #     port: int = DEFAULT_IMAGE_GEN_PORT,
+     #     width: int = 1024,
+     #     height: int = 1024,
+     #     timeout: int = 300,  # 5 minute timeout
+     #     max_retries: int = 3,
+     # ) -> bytes:
+     #     """
+     #     Generate an image using a diffusion model through gRPC.
+
+     #     Args:
+     #         model_cid (str): The model identifier (e.g. "stabilityai/stable-diffusion-xl-base-1.0")
+     #         prompt (str): The text prompt to generate the image from
+     #         host (str, optional): gRPC host address. Defaults to DEFAULT_IMAGE_GEN_HOST.
+     #         port (int, optional): gRPC port number. Defaults to DEFAULT_IMAGE_GEN_PORT.
+     #         width (int, optional): Output image width. Defaults to 1024.
+     #         height (int, optional): Output image height. Defaults to 1024.
+     #         timeout (int, optional): Maximum time to wait for generation in seconds. Defaults to 300.
+     #         max_retries (int, optional): Maximum number of retry attempts. Defaults to 3.
+
+     #     Returns:
+     #         bytes: The raw image data bytes
+
+     #     Raises:
+     #         OpenGradientError: If the image generation fails
+     #         TimeoutError: If the generation exceeds the timeout period
+     #     """
+
+     #     def exponential_backoff(attempt: int, max_delay: float = 30.0) -> None:
+     #         """Calculate and sleep for exponential backoff duration"""
+     #         delay = min(0.1 * (2**attempt), max_delay)
+     #         time.sleep(delay)
+
+     #     channel = None
+     #     start_time = time.time()
+     #     retry_count = 0
+
+     #     try:
+     #         while retry_count < max_retries:
+     #             try:
+     #                 # Initialize gRPC channel and stub
+     #                 channel = grpc.insecure_channel(f"{host}:{port}")
+     #                 stub = infer_pb2_grpc.InferenceServiceStub(channel)
+
+     #                 # Create image generation request
+     #                 image_request = infer_pb2.ImageGenerationRequest(model=model_cid, prompt=prompt, height=height, width=width)
+
+     #                 # Create inference request with random transaction ID
+     #                 tx_id = str(uuid.uuid4())
+     #                 request = infer_pb2.InferenceRequest(tx=tx_id, image_generation=image_request)
+
+     #                 # Send request with timeout
+     #                 response_id = stub.RunInferenceAsync(
+     #                     request,
+     #                     timeout=min(30, timeout),  # Initial request timeout
+     #                 )
+
+     #                 # Poll for completion
+     #                 attempt = 0
+     #                 while True:
+     #                     # Check timeout
+     #                     if time.time() - start_time > timeout:
+     #                         raise TimeoutError(f"Image generation timed out after {timeout} seconds")
+
+     #                     status_request = infer_pb2.InferenceTxId(id=response_id.id)
+     #                     try:
+     #                         status = stub.GetInferenceStatus(
+     #                             status_request,
+     #                             timeout=min(5, timeout),  # Status check timeout
+     #                         ).status
+     #                     except grpc.RpcError as e:
+     #                         logging.warning(f"Status check failed (attempt {attempt}): {str(e)}")
+     #                         exponential_backoff(attempt)
+     #                         attempt += 1
+     #                         continue
+
+     #                     if status == infer_pb2.InferenceStatus.STATUS_COMPLETED:
+     #                         break
+     #                     elif status == infer_pb2.InferenceStatus.STATUS_ERROR:
+     #                         raise OpenGradientError("Image generation failed on server")
+     #                     elif status != infer_pb2.InferenceStatus.STATUS_IN_PROGRESS:
+     #                         raise OpenGradientError(f"Unexpected status: {status}")
+
+     #                     exponential_backoff(attempt)
+     #                     attempt += 1
+
+     #                 # Get result
+     #                 result = stub.GetInferenceResult(
+     #                     response_id,
+     #                     timeout=min(30, timeout),  # Result fetch timeout
+     #                 )
+     #                 return result.image_generation_result.image_data
+
+     #             except (grpc.RpcError, TimeoutError) as e:
+     #                 retry_count += 1
+     #                 if retry_count >= max_retries:
+     #                     raise OpenGradientError(f"Image generation failed after {max_retries} retries: {str(e)}")
+
+     #                 logging.warning(f"Attempt {retry_count} failed: {str(e)}. Retrying...")
+     #                 exponential_backoff(retry_count)
+
+     #     except grpc.RpcError as e:
+     #         logging.error(f"gRPC error: {str(e)}")
+     #         raise OpenGradientError(f"Image generation failed: {str(e)}")
+     #     except TimeoutError as e:
+     #         logging.error(f"Timeout error: {str(e)}")
+     #         raise
+     #     except Exception as e:
+     #         logging.error(f"Error in generate image method: {str(e)}", exc_info=True)
+     #         raise OpenGradientError(f"Image generation failed: {str(e)}")
+     #     finally:
+     #         if channel:
+     #             channel.close()
+
+     def _get_abi(self, abi_name) -> List[Dict]:
          """
-         Generate an image using a diffusion model through gRPC.
-
-         Args:
-             model_cid (str): The model identifier (e.g. "stabilityai/stable-diffusion-xl-base-1.0")
-             prompt (str): The text prompt to generate the image from
-             host (str, optional): gRPC host address. Defaults to DEFAULT_IMAGE_GEN_HOST.
-             port (int, optional): gRPC port number. Defaults to DEFAULT_IMAGE_GEN_PORT.
-             width (int, optional): Output image width. Defaults to 1024.
-             height (int, optional): Output image height. Defaults to 1024.
-             timeout (int, optional): Maximum time to wait for generation in seconds. Defaults to 300.
-             max_retries (int, optional): Maximum number of retry attempts. Defaults to 3.
-
-         Returns:
-             bytes: The raw image data bytes
-
-         Raises:
-             OpenGradientError: If the image generation fails
-             TimeoutError: If the generation exceeds the timeout period
+         Returns the ABI for the requested contract.
          """
+         abi_path = Path(__file__).parent / "abi" / abi_name
+         with open(abi_path, "r") as f:
+             return json.load(f)

-         def exponential_backoff(attempt: int, max_delay: float = 30.0) -> None:
-             """Calculate and sleep for exponential backoff duration"""
-             delay = min(0.1 * (2**attempt), max_delay)
-             time.sleep(delay)
-
-         channel = None
-         start_time = time.time()
-         retry_count = 0
-
-         try:
-             while retry_count < max_retries:
-                 try:
-                     # Initialize gRPC channel and stub
-                     channel = grpc.insecure_channel(f"{host}:{port}")
-                     stub = infer_pb2_grpc.InferenceServiceStub(channel)
-
-                     # Create image generation request
-                     image_request = infer_pb2.ImageGenerationRequest(model=model_cid, prompt=prompt, height=height, width=width)
-
-                     # Create inference request with random transaction ID
-                     tx_id = str(uuid.uuid4())
-                     request = infer_pb2.InferenceRequest(tx=tx_id, image_generation=image_request)
-
-                     # Send request with timeout
-                     response_id = stub.RunInferenceAsync(
-                         request,
-                         timeout=min(30, timeout),  # Initial request timeout
-                     )
-
-                     # Poll for completion
-                     attempt = 0
-                     while True:
-                         # Check timeout
-                         if time.time() - start_time > timeout:
-                             raise TimeoutError(f"Image generation timed out after {timeout} seconds")
-
-                         status_request = infer_pb2.InferenceTxId(id=response_id.id)
-                         try:
-                             status = stub.GetInferenceStatus(
-                                 status_request,
-                                 timeout=min(5, timeout),  # Status check timeout
-                             ).status
-                         except grpc.RpcError as e:
-                             logging.warning(f"Status check failed (attempt {attempt}): {str(e)}")
-                             exponential_backoff(attempt)
-                             attempt += 1
-                             continue
-
-                         if status == infer_pb2.InferenceStatus.STATUS_COMPLETED:
-                             break
-                         elif status == infer_pb2.InferenceStatus.STATUS_ERROR:
-                             raise OpenGradientError("Image generation failed on server")
-                         elif status != infer_pb2.InferenceStatus.STATUS_IN_PROGRESS:
-                             raise OpenGradientError(f"Unexpected status: {status}")
-
-                         exponential_backoff(attempt)
-                         attempt += 1
-
-                     # Get result
-                     result = stub.GetInferenceResult(
-                         response_id,
-                         timeout=min(30, timeout),  # Result fetch timeout
-                     )
-                     return result.image_generation_result.image_data
-
-                 except (grpc.RpcError, TimeoutError) as e:
-                     retry_count += 1
-                     if retry_count >= max_retries:
-                         raise OpenGradientError(f"Image generation failed after {max_retries} retries: {str(e)}")
-
-                     logging.warning(f"Attempt {retry_count} failed: {str(e)}. Retrying...")
-                     exponential_backoff(retry_count)
-
-         except grpc.RpcError as e:
-             logging.error(f"gRPC error: {str(e)}")
-             raise OpenGradientError(f"Image generation failed: {str(e)}")
-         except TimeoutError as e:
-             logging.error(f"Timeout error: {str(e)}")
-             raise
-         except Exception as e:
-             logging.error(f"Error in generate image method: {str(e)}", exc_info=True)
-             raise OpenGradientError(f"Image generation failed: {str(e)}")
-         finally:
-             if channel:
-                 channel.close()
-
-     def _get_model_executor_abi(self) -> List[Dict]:
+     def _get_bin(self, bin_name) -> List[Dict]:
          """
-         Returns the ABI for the ModelExecutorHistorical contract.
+         Returns the bin for the requested contract.
          """
-         abi_path = Path(__file__).parent / "abi" / "ModelExecutorHistorical.abi"
-         with open(abi_path, "r") as f:
-             return json.load(f)
+         bin_path = Path(__file__).parent / "bin" / bin_name
+         # Read bytecode with explicit encoding
+         with open(bin_path, "r", encoding="utf-8") as f:
+             bytecode = f.read().strip()
+         if not bytecode.startswith("0x"):
+             bytecode = "0x" + bytecode
+         return bytecode

      def new_workflow(
          self,
          model_cid: str,
-         input_query: Union[Dict[str, Any], HistoricalInputQuery],
+         input_query: HistoricalInputQuery,
          input_tensor_name: str,
          scheduler_params: Optional[SchedulerParams] = None,
      ) -> str:
          """
          Deploy a new workflow contract with the specified parameters.
-         """
-         if isinstance(input_query, dict):
-             input_query = HistoricalInputQuery.from_dict(input_query)

+         This function deploys a new workflow contract and optionally registers it with
+         the scheduler for automated execution. If scheduler_params is not provided,
+         the workflow will be deployed without automated execution scheduling.
+
+         Args:
+             model_cid (str): IPFS CID of the model to be executed
+             input_query (HistoricalInputQuery): Query parameters for data input
+             input_tensor_name (str): Name of the input tensor expected by the model
+             scheduler_params (Optional[SchedulerParams]): Scheduler configuration for automated execution:
+                 - frequency: Execution frequency in seconds
+                 - duration_hours: How long to run in hours
+
+         Returns:
+             str: Deployed contract address. If scheduler_params was provided, the workflow
+                 will be automatically executed according to the specified schedule.
+
+         Raises:
+             Exception: If transaction fails or gas estimation fails
+         """
          # Get contract ABI and bytecode
-         abi = self._get_model_executor_abi()
-         bin_path = Path(__file__).parent / "bin" / "ModelExecutorHistorical.bin"
+         abi = self._get_abi("PriceHistoryInference.abi")
+         bytecode = self._get_bin("PriceHistoryInference.bin")

-         with open(bin_path, "r") as f:
-             bytecode = f.read().strip()
+         def deploy_transaction():
+             contract = self._blockchain.eth.contract(abi=abi, bytecode=bytecode)
+             query_tuple = input_query.to_abi_format()
+             constructor_args = [model_cid, input_tensor_name, query_tuple]

-         print("📦 Deploying workflow contract...")
+             try:
+                 # Estimate gas needed
+                 estimated_gas = contract.constructor(*constructor_args).estimate_gas({"from": self._wallet_account.address})
+                 gas_limit = int(estimated_gas * 1.2)
+             except Exception as e:
+                 print(f"⚠️ Gas estimation failed: {str(e)}")
+                 gas_limit = 5000000  # Conservative fallback
+                 print(f"📊 Using fallback gas limit: {gas_limit}")

-         # Create contract instance
-         contract = self._blockchain.eth.contract(abi=abi, bytecode=bytecode)
+             transaction = contract.constructor(*constructor_args).build_transaction(
+                 {
+                     "from": self._wallet_account.address,
+                     "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
+                     "gas": gas_limit,
+                     "gasPrice": self._blockchain.eth.gas_price,
+                     "chainId": self._blockchain.eth.chain_id,
+                 }
+             )

-         # Deploy contract with constructor arguments
-         transaction = contract.constructor().build_transaction(
-             {
-                 "from": self._wallet_account.address,
-                 "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
-                 "gas": 15000000,
-                 "gasPrice": self._blockchain.eth.gas_price,
-                 "chainId": self._blockchain.eth.chain_id,
-             }
-         )
+             signed_txn = self._wallet_account.sign_transaction(transaction)
+             tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction)

-         signed_txn = self._wallet_account.sign_transaction(transaction)
-         tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction)
-         tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=REGULAR_TX_TIMEOUT)
-         contract_address = tx_receipt.contractAddress
+             tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=60)
+
+             if tx_receipt["status"] == 0:
+                 raise Exception(f"❌ Contract deployment failed, transaction hash: {tx_hash.hex()}")

-         print(f"✅ Workflow contract deployed at: {contract_address}")
+             return tx_receipt.contractAddress
+
+         contract_address = run_with_retry(deploy_transaction)

-         # Register with scheduler if params provided
          if scheduler_params:
-             print("\n⏰ Setting up automated execution schedule...")
-             print(f" • Frequency: Every {scheduler_params.frequency} seconds")
-             print(f" • Duration: {scheduler_params.duration_hours} hours")
-             print(f" • End Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(scheduler_params.end_time))}")
+             self._register_with_scheduler(contract_address, scheduler_params)

-             scheduler_abi = [
-                 {
-                     "inputs": [
-                         {"internalType": "address", "name": "contractAddress", "type": "address"},
-                         {"internalType": "uint256", "name": "endTime", "type": "uint256"},
-                         {"internalType": "uint256", "name": "frequency", "type": "uint256"},
-                     ],
-                     "name": "registerTask",
-                     "outputs": [],
-                     "stateMutability": "nonpayable",
-                     "type": "function",
-                 }
-             ]
+         return contract_address

-             scheduler_address = "0x6F937b9f4Fa7723932827cd73063B70Be2b56748"
-             scheduler_contract = self._blockchain.eth.contract(address=scheduler_address, abi=scheduler_abi)
+     def _register_with_scheduler(self, contract_address: str, scheduler_params: SchedulerParams) -> None:
+         """
+         Register the deployed workflow contract with the scheduler for automated execution.

-             try:
-                 # Register the workflow with the scheduler
-                 scheduler_tx = scheduler_contract.functions.registerTask(
-                     contract_address, scheduler_params.end_time, scheduler_params.frequency
-                 ).build_transaction(
-                     {
-                         "from": self._wallet_account.address,
-                         "gas": 300000,
-                         "gasPrice": self._blockchain.eth.gas_price,
-                         "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
-                         "chainId": self._blockchain.eth.chain_id,
-                     }
-                 )
+         Args:
+             contract_address (str): Address of the deployed workflow contract
+             scheduler_params (SchedulerParams): Scheduler configuration containing:
+                 - frequency: Execution frequency in seconds
+                 - duration_hours: How long to run in hours
+                 - end_time: Unix timestamp when scheduling should end

-                 signed_scheduler_tx = self._wallet_account.sign_transaction(scheduler_tx)
-                 scheduler_tx_hash = self._blockchain.eth.send_raw_transaction(signed_scheduler_tx.raw_transaction)
-                 self._blockchain.eth.wait_for_transaction_receipt(scheduler_tx_hash, timeout=REGULAR_TX_TIMEOUT)
+         Raises:
+             Exception: If registration with scheduler fails. The workflow contract will
+                 still be deployed and can be executed manually.
+         """

-                 print("✅ Automated execution schedule set successfully!")
-                 print(f" Transaction hash: {scheduler_tx_hash.hex()}")
+         scheduler_abi = self._get_abi("WorkflowScheduler.abi")

-             except Exception as e:
-                 print("❌ Failed to set up automated execution schedule")
-                 print(f" Error: {str(e)}")
-                 print(" The workflow contract is still deployed and can be executed manually.")
+         # Scheduler contract address
+         scheduler_address = DEFAULT_SCHEDULER_ADDRESS
+         scheduler_contract = self._blockchain.eth.contract(address=scheduler_address, abi=scheduler_abi)

-         return contract_address
+         try:
+             # Register the workflow with the scheduler
+             scheduler_tx = scheduler_contract.functions.registerTask(
+                 contract_address, scheduler_params.end_time, scheduler_params.frequency
+             ).build_transaction(
+                 {
+                     "from": self._wallet_account.address,
+                     "gas": 300000,
+                     "gasPrice": self._blockchain.eth.gas_price,
+                     "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
+                     "chainId": self._blockchain.eth.chain_id,
+                 }
+             )
+
+             signed_scheduler_tx = self._wallet_account.sign_transaction(scheduler_tx)
+             scheduler_tx_hash = self._blockchain.eth.send_raw_transaction(signed_scheduler_tx.raw_transaction)
+             self._blockchain.eth.wait_for_transaction_receipt(scheduler_tx_hash, timeout=REGULAR_TX_TIMEOUT)
+         except Exception as e:
+             print(f"❌ Error registering contract with scheduler: {str(e)}")
+             print(" The workflow contract is still deployed and can be executed manually.")

      def read_workflow_result(self, contract_address: str) -> ModelOutput:
          """
@@ -854,7 +903,9 @@ class Client:
              Web3Error: If there are issues with the web3 connection or contract interaction
          """
          # Get the contract interface
-         contract = self._blockchain.eth.contract(address=Web3.to_checksum_address(contract_address), abi=self._get_model_executor_abi())
+         contract = self._blockchain.eth.contract(
+             address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
+         )

          # Get the result
          result = contract.functions.getInferenceResult().call()
@@ -876,7 +927,9 @@ class Client:
              Web3Error: If there are issues with the web3 connection or contract interaction
          """
          # Get the contract interface
-         contract = self._blockchain.eth.contract(address=Web3.to_checksum_address(contract_address), abi=self._get_model_executor_abi())
+         contract = self._blockchain.eth.contract(
+             address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
+         )

          # Call run() function
          nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending")
@@ -904,6 +957,31 @@ class Client:

          return utils.convert_array_to_model_output(result)

+     def read_workflow_history(self, contract_address: str, num_results: int) -> List[Dict]:
+         """
+         Gets historical inference results from a workflow contract.
+
+         Retrieves the specified number of most recent inference results from the contract's
+         storage, with the most recent result first.
+
+         Args:
+             contract_address (str): Address of the deployed workflow contract
+             num_results (int): Number of historical results to retrieve
+
+         Returns:
+             List[Dict]: List of historical inference results, each containing:
+                 - prediction values
+                 - timestamps
+                 - any additional metadata stored with the result
+
+         """
+         contract = self._blockchain.eth.contract(
+             address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
+         )
+
+         results = contract.functions.getLastInferenceResults(num_results).call()
+         return [utils.convert_array_to_model_output(result) for result in results]
+

  def run_with_retry(txn_function, max_retries=DEFAULT_MAX_RETRY, retry_delay=DEFAULT_RETRY_DELAY_SEC):
      """