opengradient 0.4.6__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
opengradient/client.py CHANGED
@@ -2,12 +2,10 @@ import json
  import logging
  import os
  import time
- import uuid
  from pathlib import Path
- from typing import Any, Dict, List, Optional, Tuple, Union
+ from typing import Any, Dict, List, Optional, Union

  import firebase
- import grpc
  import numpy as np
  import requests
  from eth_account.account import LocalAccount
@@ -15,11 +13,23 @@ from web3 import Web3
  from web3.exceptions import ContractLogicError
  from web3.logs import DISCARD

- from . import utils
  from .exceptions import OpenGradientError
  from .proto import infer_pb2, infer_pb2_grpc
- from .types import LLM, TEE_LLM, HistoricalInputQuery, InferenceMode, LlmInferenceMode, ModelOutput, TextGenerationOutput, SchedulerParams
- from .defaults import DEFAULT_IMAGE_GEN_HOST, DEFAULT_IMAGE_GEN_PORT
+ from .types import (
+     LLM,
+     TEE_LLM,
+     HistoricalInputQuery,
+     InferenceMode,
+     LlmInferenceMode,
+     ModelOutput,
+     TextGenerationOutput,
+     SchedulerParams,
+     InferenceResult,
+     ModelRepository,
+     FileUploadResult,
+ )
+ from .defaults import DEFAULT_IMAGE_GEN_HOST, DEFAULT_IMAGE_GEN_PORT, DEFAULT_SCHEDULER_ADDRESS
+ from .utils import convert_array_to_model_output, convert_to_model_input, convert_to_model_output

  _FIREBASE_CONFIG = {
      "apiKey": "AIzaSyDUVckVtfl-hiteBzPopy1pDD8Uvfncs7w",
@@ -45,10 +55,10 @@ class Client:
      _blockchain: Web3
      _wallet_account: LocalAccount

-     _hub_user: Dict
+     _hub_user: Optional[Dict]
      _inference_abi: Dict

-     def __init__(self, private_key: str, rpc_url: str, contract_address: str, email: str, password: str):
+     def __init__(self, private_key: str, rpc_url: str, contract_address: str, email: Optional[str], password: Optional[str]):
          """
          Initialize the Client with private key, RPC URL, and contract address.
 
@@ -80,7 +90,7 @@ class Client:
              logging.error(f"Authentication failed: {str(e)}")
              raise

-     def create_model(self, model_name: str, model_desc: str, version: str = "1.00") -> dict:
+     def create_model(self, model_name: str, model_desc: str, version: str = "1.00") -> ModelRepository:
          """
          Create a new model with the given model_name and model_desc, and a specified version.
 
@@ -103,41 +113,24 @@ class Client:
          payload = {"name": model_name, "description": model_desc}

          try:
-             logging.debug(f"Create Model URL: {url}")
-             logging.debug(f"Headers: {headers}")
-             logging.debug(f"Payload: {payload}")
-
              response = requests.post(url, json=payload, headers=headers)
              response.raise_for_status()
+         except requests.HTTPError as e:
+             error_details = f"HTTP {e.response.status_code}: {e.response.text}"
+             raise OpenGradientError(f"Model creation failed: {error_details}") from e

-             json_response = response.json()
-             model_name = json_response.get("name")
-             if not model_name:
-                 raise Exception(f"Model creation response missing 'name'. Full response: {json_response}")
-             logging.info(f"Model creation successful. Model name: {model_name}")
+         json_response = response.json()
+         model_name = json_response.get("name")
+         if not model_name:
+             raise Exception(f"Model creation response missing 'name'. Full response: {json_response}")

-             # Create the specified version for the newly created model
-             try:
-                 version_response = self.create_version(model_name, version)
-                 logging.info(f"Version creation successful. Version string: {version_response['versionString']}")
-             except Exception as ve:
-                 logging.error(f"Version creation failed, but model was created. Error: {str(ve)}")
-                 return {"name": model_name, "versionString": None, "version_error": str(ve)}
+         # Create the specified version for the newly created model
+         version_response = self.create_version(model_name, version)

-             return {"name": model_name, "versionString": version_response["versionString"]}
+         return ModelRepository(model_name, version_response["versionString"])

-         except requests.RequestException as e:
-             logging.error(f"Model creation failed: {str(e)}")
-             if hasattr(e, "response") and e.response is not None:
-                 logging.error(f"Response status code: {e.response.status_code}")
-                 logging.error(f"Response headers: {e.response.headers}")
-                 logging.error(f"Response content: {e.response.text}")
-             raise Exception(f"Model creation failed: {str(e)}")
-         except Exception as e:
-             logging.error(f"Unexpected error during model creation: {str(e)}")
-             raise

-     def create_version(self, model_name: str, notes: str = None, is_major: bool = False) -> dict:
+     def create_version(self, model_name: str, notes: str = "", is_major: bool = False) -> dict:
          """
          Create a new version for the specified model.
 
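Usage note: in 0.4.6, create_model returned a plain dict and swallowed create_version failures; in 0.4.8 it returns a ModelRepository and lets version-creation errors propagate. A minimal sketch of the new calling convention, assuming ModelRepository exposes its two constructor values as attributes (the attribute names below are a guess, not confirmed by this diff):

    from opengradient.client import Client

    # email and password may now be None (see the __init__ change above)
    client = Client(
        private_key="0x...",
        rpc_url="https://...",
        contract_address="0x...",
        email=None,
        password=None,
    )

    repo = client.create_model("my-model", "demo model", version="1.00")
    # 0.4.6: result["name"], result["versionString"]
    print(repo.name, repo.version)  # attribute names are an assumption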
@@ -196,7 +189,7 @@ class Client:
              logging.error(f"Unexpected error during version creation: {str(e)}")
              raise

-     def upload(self, model_path: str, model_name: str, version: str) -> dict:
+     def upload(self, model_path: str, model_name: str, version: str) -> FileUploadResult:
          """
          Upload a model file to the server.
@@ -251,12 +244,9 @@ class Client:
              if response.status_code == 201:
                  if response.content and response.content != b"null":
                      json_response = response.json()
-                     logging.info(f"JSON response: {json_response}") # Log the parsed JSON response
-                     logging.info(f"Upload successful. CID: {json_response.get('ipfsCid', 'N/A')}")
-                     result = {"model_cid": json_response.get("ipfsCid"), "size": json_response.get("size")}
+                     return FileUploadResult(json_response.get("ipfsCid"), json_response.get("size"))
                  else:
-                     logging.warning("Empty or null response content received. Assuming upload was successful.")
-                     result = {"model_cid": None, "size": None}
+                     raise RuntimeError("Empty or null response content received. Assuming upload was successful.")
              elif response.status_code == 500:
                  error_message = "Internal server error occurred. Please try again later or contact support."
                  logging.error(error_message)
@@ -266,16 +256,12 @@ class Client:
                  logging.error(f"Upload failed with status code {response.status_code}: {error_message}")
                  raise OpenGradientError(f"Upload failed: {error_message}", status_code=response.status_code)

-             return result
-
          except requests.RequestException as e:
              logging.error(f"Request exception during upload: {str(e)}")
              if hasattr(e, "response") and e.response is not None:
                  logging.error(f"Response status code: {e.response.status_code}")
                  logging.error(f"Response content: {e.response.text[:1000]}...") # Log first 1000 characters
-             raise OpenGradientError(
-                 f"Upload failed due to request exception: {str(e)}", status_code=e.response.status_code if hasattr(e, "response") else None
-             )
+             raise OpenGradientError(f"Upload failed due to request exception: {str(e)}")
          except Exception as e:
              logging.error(f"Unexpected error during upload: {str(e)}", exc_info=True)
              raise OpenGradientError(f"Unexpected error during upload: {str(e)}")
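Usage note: upload now returns a FileUploadResult instead of {"model_cid": ..., "size": ...}, raises RuntimeError on an empty 201 body instead of logging a warning, and no longer attaches a status_code when wrapping request exceptions. A sketch under those assumptions (FileUploadResult attribute names are inferred from the constructor call above, not confirmed):

    from opengradient.exceptions import OpenGradientError

    try:
        result = client.upload("model.onnx", "my-model", version="1.00")
        print(result.model_cid, result.size)  # attribute names are an assumption
    except RuntimeError:
        print("server returned 201 with an empty body")
    except OpenGradientError as e:
        print(f"upload failed: {e}")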
@@ -286,7 +272,7 @@ class Client:
          inference_mode: InferenceMode,
          model_input: Dict[str, Union[str, int, float, List, np.ndarray]],
          max_retries: Optional[int] = None,
-     ) -> Tuple[str, Dict[str, np.ndarray]]:
+     ) -> InferenceResult:
          """
          Perform inference on a model.
 
@@ -297,7 +283,7 @@ class Client:
              max_retries (int, optional): Maximum number of retry attempts. Defaults to 5.

          Returns:
-             Tuple[str, Dict[str, np.ndarray]]: The transaction hash and the model output.
+             InferenceResult: The transaction hash and the model output.

          Raises:
              OpenGradientError: If the inference fails.
@@ -306,8 +292,8 @@ class Client:
          def execute_transaction():
              contract = self._blockchain.eth.contract(address=self._inference_hub_contract_address, abi=self._inference_abi)

-             inference_mode_uint8 = int(inference_mode)
-             converted_model_input = utils.convert_to_model_input(model_input)
+             inference_mode_uint8 = inference_mode.value
+             converted_model_input = convert_to_model_input(model_input)

              run_function = contract.functions.run(model_cid, inference_mode_uint8, converted_model_input)
 
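Usage note: infer no longer returns a (tx_hash, model_output) tuple; as the next hunk shows, both values are wrapped in an InferenceResult. Callers that unpacked the tuple need a one-line change; a sketch, assuming InferenceResult exposes the two constructor values as attributes and that InferenceMode has a VANILLA member (both names are assumptions):

    from opengradient.types import InferenceMode

    # 0.4.6: tx_hash, output = client.infer(...)
    result = client.infer(
        model_cid="Qm...",
        inference_mode=InferenceMode.VANILLA,  # member name is an assumption
        model_input={"input": [1.0, 2.0, 3.0]},
    )
    print(result.transaction_hash, result.model_output)  # attribute names assumed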
@@ -336,15 +322,16 @@ class Client:
                  raise OpenGradientError("InferenceResult event not found in transaction logs")

              # TODO: This should return a ModelOutput class object
-             model_output = utils.convert_to_model_output(parsed_logs[0]["args"])
-             return tx_hash.hex(), model_output
+             model_output = convert_to_model_output(parsed_logs[0]["args"])
+
+             return InferenceResult(tx_hash.hex(), model_output)

          return run_with_retry(execute_transaction, max_retries)

      def llm_completion(
          self,
          model_cid: LLM,
-         inference_mode: InferenceMode,
+         inference_mode: LlmInferenceMode,
          prompt: str,
          max_tokens: int = 100,
          stop_sequence: Optional[List[str]] = None,
@@ -383,7 +370,7 @@ class Client:

              # Prepare LLM input
              llm_request = {
-                 "mode": inference_mode,
+                 "mode": inference_mode.value,
                  "modelCID": model_cid,
                  "prompt": prompt,
                  "max_tokens": max_tokens,
@@ -420,18 +407,15 @@ class Client:
                  raise OpenGradientError("LLM completion result event not found in transaction logs")

              llm_answer = parsed_logs[0]["args"]["response"]["answer"]
-
-             return TextGenerationOutput(
-                 transaction_hash=tx_hash.hex(),
-                 completion_output=llm_answer
-             )
+
+             return TextGenerationOutput(transaction_hash=tx_hash.hex(), completion_output=llm_answer)

          return run_with_retry(execute_transaction, max_retries)

      def llm_chat(
          self,
          model_cid: LLM,
-         inference_mode: InferenceMode,
+         inference_mode: LlmInferenceMode,
          messages: List[Dict],
          max_tokens: int = 100,
          stop_sequence: Optional[List[str]] = None,
@@ -536,7 +520,7 @@ class Client:

              # Prepare LLM input
              llm_request = {
-                 "mode": inference_mode,
+                 "mode": inference_mode.value,
                  "modelCID": model_cid,
                  "messages": messages,
                  "max_tokens": max_tokens,
@@ -624,234 +608,269 @@ class Client:
          if hasattr(e, "response") and e.response is not None:
              logging.error(f"Response status code: {e.response.status_code}")
              logging.error(f"Response content: {e.response.text[:1000]}...") # Log first 1000 characters
-         raise OpenGradientError(
-             f"File listing failed: {str(e)}", status_code=e.response.status_code if hasattr(e, "response") else None
-         )
+         raise OpenGradientError(f"File listing failed: {str(e)}")
      except Exception as e:
          logging.error(f"Unexpected error during file listing: {str(e)}", exc_info=True)
          raise OpenGradientError(f"Unexpected error during file listing: {str(e)}")

-     def generate_image(
-         self,
-         model_cid: str,
-         prompt: str,
-         host: str = DEFAULT_IMAGE_GEN_HOST,
-         port: int = DEFAULT_IMAGE_GEN_PORT,
-         width: int = 1024,
-         height: int = 1024,
-         timeout: int = 300, # 5 minute timeout
-         max_retries: int = 3,
-     ) -> bytes:
+     # def generate_image(
+     #     self,
+     #     model_cid: str,
+     #     prompt: str,
+     #     host: str = DEFAULT_IMAGE_GEN_HOST,
+     #     port: int = DEFAULT_IMAGE_GEN_PORT,
+     #     width: int = 1024,
+     #     height: int = 1024,
+     #     timeout: int = 300, # 5 minute timeout
+     #     max_retries: int = 3,
+     # ) -> bytes:
+     #     """
+     #     Generate an image using a diffusion model through gRPC.
+
+     #     Args:
+     #         model_cid (str): The model identifier (e.g. "stabilityai/stable-diffusion-xl-base-1.0")
+     #         prompt (str): The text prompt to generate the image from
+     #         host (str, optional): gRPC host address. Defaults to DEFAULT_IMAGE_GEN_HOST.
+     #         port (int, optional): gRPC port number. Defaults to DEFAULT_IMAGE_GEN_PORT.
+     #         width (int, optional): Output image width. Defaults to 1024.
+     #         height (int, optional): Output image height. Defaults to 1024.
+     #         timeout (int, optional): Maximum time to wait for generation in seconds. Defaults to 300.
+     #         max_retries (int, optional): Maximum number of retry attempts. Defaults to 3.
+
+     #     Returns:
+     #         bytes: The raw image data bytes
+
+     #     Raises:
+     #         OpenGradientError: If the image generation fails
+     #         TimeoutError: If the generation exceeds the timeout period
+     #     """
+
+     #     def exponential_backoff(attempt: int, max_delay: float = 30.0) -> None:
+     #         """Calculate and sleep for exponential backoff duration"""
+     #         delay = min(0.1 * (2**attempt), max_delay)
+     #         time.sleep(delay)
+
+     #     channel = None
+     #     start_time = time.time()
+     #     retry_count = 0
+
+     #     try:
+     #         while retry_count < max_retries:
+     #             try:
+     #                 # Initialize gRPC channel and stub
+     #                 channel = grpc.insecure_channel(f"{host}:{port}")
+     #                 stub = infer_pb2_grpc.InferenceServiceStub(channel)
+
+     #                 # Create image generation request
+     #                 image_request = infer_pb2.ImageGenerationRequest(model=model_cid, prompt=prompt, height=height, width=width)
+
+     #                 # Create inference request with random transaction ID
+     #                 tx_id = str(uuid.uuid4())
+     #                 request = infer_pb2.InferenceRequest(tx=tx_id, image_generation=image_request)
+
+     #                 # Send request with timeout
+     #                 response_id = stub.RunInferenceAsync(
+     #                     request,
+     #                     timeout=min(30, timeout), # Initial request timeout
+     #                 )
+
+     #                 # Poll for completion
+     #                 attempt = 0
+     #                 while True:
+     #                     # Check timeout
+     #                     if time.time() - start_time > timeout:
+     #                         raise TimeoutError(f"Image generation timed out after {timeout} seconds")
+
+     #                     status_request = infer_pb2.InferenceTxId(id=response_id.id)
+     #                     try:
+     #                         status = stub.GetInferenceStatus(
+     #                             status_request,
+     #                             timeout=min(5, timeout), # Status check timeout
+     #                         ).status
+     #                     except grpc.RpcError as e:
+     #                         logging.warning(f"Status check failed (attempt {attempt}): {str(e)}")
+     #                         exponential_backoff(attempt)
+     #                         attempt += 1
+     #                         continue
+
+     #                     if status == infer_pb2.InferenceStatus.STATUS_COMPLETED:
+     #                         break
+     #                     elif status == infer_pb2.InferenceStatus.STATUS_ERROR:
+     #                         raise OpenGradientError("Image generation failed on server")
+     #                     elif status != infer_pb2.InferenceStatus.STATUS_IN_PROGRESS:
+     #                         raise OpenGradientError(f"Unexpected status: {status}")
+
+     #                     exponential_backoff(attempt)
+     #                     attempt += 1
+
+     #                 # Get result
+     #                 result = stub.GetInferenceResult(
+     #                     response_id,
+     #                     timeout=min(30, timeout), # Result fetch timeout
+     #                 )
+     #                 return result.image_generation_result.image_data
+
+     #             except (grpc.RpcError, TimeoutError) as e:
+     #                 retry_count += 1
+     #                 if retry_count >= max_retries:
+     #                     raise OpenGradientError(f"Image generation failed after {max_retries} retries: {str(e)}")
+
+     #                 logging.warning(f"Attempt {retry_count} failed: {str(e)}. Retrying...")
+     #                 exponential_backoff(retry_count)
+
+     #     except grpc.RpcError as e:
+     #         logging.error(f"gRPC error: {str(e)}")
+     #         raise OpenGradientError(f"Image generation failed: {str(e)}")
+     #     except TimeoutError as e:
+     #         logging.error(f"Timeout error: {str(e)}")
+     #         raise
+     #     except Exception as e:
+     #         logging.error(f"Error in generate image method: {str(e)}", exc_info=True)
+     #         raise OpenGradientError(f"Image generation failed: {str(e)}")
+     #     finally:
+     #         if channel:
+     #             channel.close()
+
+     def _get_abi(self, abi_name) -> str:
          """
-         Generate an image using a diffusion model through gRPC.
-
-         Args:
-             model_cid (str): The model identifier (e.g. "stabilityai/stable-diffusion-xl-base-1.0")
-             prompt (str): The text prompt to generate the image from
-             host (str, optional): gRPC host address. Defaults to DEFAULT_IMAGE_GEN_HOST.
-             port (int, optional): gRPC port number. Defaults to DEFAULT_IMAGE_GEN_PORT.
-             width (int, optional): Output image width. Defaults to 1024.
-             height (int, optional): Output image height. Defaults to 1024.
-             timeout (int, optional): Maximum time to wait for generation in seconds. Defaults to 300.
-             max_retries (int, optional): Maximum number of retry attempts. Defaults to 3.
-
-         Returns:
-             bytes: The raw image data bytes
-
-         Raises:
-             OpenGradientError: If the image generation fails
-             TimeoutError: If the generation exceeds the timeout period
+         Returns the ABI for the requested contract.
          """
+         abi_path = Path(__file__).parent / "abi" / abi_name
+         with open(abi_path, "r") as f:
+             return json.load(f)

-         def exponential_backoff(attempt: int, max_delay: float = 30.0) -> None:
-             """Calculate and sleep for exponential backoff duration"""
-             delay = min(0.1 * (2**attempt), max_delay)
-             time.sleep(delay)
-
-         channel = None
-         start_time = time.time()
-         retry_count = 0
-
-         try:
-             while retry_count < max_retries:
-                 try:
-                     # Initialize gRPC channel and stub
-                     channel = grpc.insecure_channel(f"{host}:{port}")
-                     stub = infer_pb2_grpc.InferenceServiceStub(channel)
-
-                     # Create image generation request
-                     image_request = infer_pb2.ImageGenerationRequest(model=model_cid, prompt=prompt, height=height, width=width)
-
-                     # Create inference request with random transaction ID
-                     tx_id = str(uuid.uuid4())
-                     request = infer_pb2.InferenceRequest(tx=tx_id, image_generation=image_request)
-
-                     # Send request with timeout
-                     response_id = stub.RunInferenceAsync(
-                         request,
-                         timeout=min(30, timeout), # Initial request timeout
-                     )
-
-                     # Poll for completion
-                     attempt = 0
-                     while True:
-                         # Check timeout
-                         if time.time() - start_time > timeout:
-                             raise TimeoutError(f"Image generation timed out after {timeout} seconds")
-
-                         status_request = infer_pb2.InferenceTxId(id=response_id.id)
-                         try:
-                             status = stub.GetInferenceStatus(
-                                 status_request,
-                                 timeout=min(5, timeout), # Status check timeout
-                             ).status
-                         except grpc.RpcError as e:
-                             logging.warning(f"Status check failed (attempt {attempt}): {str(e)}")
-                             exponential_backoff(attempt)
-                             attempt += 1
-                             continue
-
-                         if status == infer_pb2.InferenceStatus.STATUS_COMPLETED:
-                             break
-                         elif status == infer_pb2.InferenceStatus.STATUS_ERROR:
-                             raise OpenGradientError("Image generation failed on server")
-                         elif status != infer_pb2.InferenceStatus.STATUS_IN_PROGRESS:
-                             raise OpenGradientError(f"Unexpected status: {status}")
-
-                         exponential_backoff(attempt)
-                         attempt += 1
-
-                     # Get result
-                     result = stub.GetInferenceResult(
-                         response_id,
-                         timeout=min(30, timeout), # Result fetch timeout
-                     )
-                     return result.image_generation_result.image_data
-
-                 except (grpc.RpcError, TimeoutError) as e:
-                     retry_count += 1
-                     if retry_count >= max_retries:
-                         raise OpenGradientError(f"Image generation failed after {max_retries} retries: {str(e)}")
-
-                     logging.warning(f"Attempt {retry_count} failed: {str(e)}. Retrying...")
-                     exponential_backoff(retry_count)
-
-         except grpc.RpcError as e:
-             logging.error(f"gRPC error: {str(e)}")
-             raise OpenGradientError(f"Image generation failed: {str(e)}")
-         except TimeoutError as e:
-             logging.error(f"Timeout error: {str(e)}")
-             raise
-         except Exception as e:
-             logging.error(f"Error in generate image method: {str(e)}", exc_info=True)
-             raise OpenGradientError(f"Image generation failed: {str(e)}")
-         finally:
-             if channel:
-                 channel.close()
-
-     def _get_model_executor_abi(self) -> List[Dict]:
+     def _get_bin(self, bin_name) -> str:
          """
-         Returns the ABI for the ModelExecutorHistorical contract.
+         Returns the bin for the requested contract.
          """
-         abi_path = Path(__file__).parent / "abi" / "ModelExecutorHistorical.abi"
-         with open(abi_path, "r") as f:
-             return json.load(f)
+         bin_path = Path(__file__).parent / "bin" / bin_name
+         # Read bytecode with explicit encoding
+         with open(bin_path, "r", encoding="utf-8") as f:
+             bytecode = f.read().strip()
+         if not bytecode.startswith("0x"):
+             bytecode = "0x" + bytecode
+         return bytecode

      def new_workflow(
          self,
          model_cid: str,
-         input_query: Union[Dict[str, Any], HistoricalInputQuery],
+         input_query: HistoricalInputQuery,
          input_tensor_name: str,
          scheduler_params: Optional[SchedulerParams] = None,
      ) -> str:
          """
          Deploy a new workflow contract with the specified parameters.
-         """
-         if isinstance(input_query, dict):
-             input_query = HistoricalInputQuery.from_dict(input_query)

-         # Get contract ABI and bytecode
-         abi = self._get_model_executor_abi()
-         bin_path = Path(__file__).parent / "bin" / "ModelExecutorHistorical.bin"
-
-         with open(bin_path, "r") as f:
-             bytecode = f.read().strip()
+         This function deploys a new workflow contract on OpenGradient that connects
+         an AI model with its required input data. When executed, the workflow will fetch
+         the specified model, evaluate the input query to get data, and perform inference.

-         print("📦 Deploying workflow contract...")
+         The workflow can be set to execute manually or automatically via a scheduler.

-         # Create contract instance
-         contract = self._blockchain.eth.contract(abi=abi, bytecode=bytecode)
+         Args:
+             model_cid (str): CID of the model to be executed from the Model Hub
+             input_query (HistoricalInputQuery): Input definition for the model inference,
+                 will be evaluated at runtime for each inference
+             input_tensor_name (str): Name of the input tensor expected by the model
+             scheduler_params (Optional[SchedulerParams]): Scheduler configuration for automated execution:
+                 - frequency: Execution frequency in seconds
+                 - duration_hours: How long the schedule should live for

-         # Deploy contract with constructor arguments
-         transaction = contract.constructor().build_transaction(
-             {
-                 "from": self._wallet_account.address,
-                 "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
-                 "gas": 15000000,
-                 "gasPrice": self._blockchain.eth.gas_price,
-                 "chainId": self._blockchain.eth.chain_id,
-             }
-         )
+         Returns:
+             str: Deployed contract address. If scheduler_params was provided, the workflow
+                 will be automatically executed according to the specified schedule.

-         signed_txn = self._wallet_account.sign_transaction(transaction)
-         tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction)
-         tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=REGULAR_TX_TIMEOUT)
-         contract_address = tx_receipt.contractAddress
+         Raises:
+             Exception: If transaction fails or gas estimation fails
+         """
+         # Get contract ABI and bytecode
+         abi = self._get_abi("PriceHistoryInference.abi")
+         bytecode = self._get_bin("PriceHistoryInference.bin")

-         print(f"✅ Workflow contract deployed at: {contract_address}")
+         def deploy_transaction():
+             contract = self._blockchain.eth.contract(abi=abi, bytecode=bytecode)
+             query_tuple = input_query.to_abi_format()
+             constructor_args = [model_cid, input_tensor_name, query_tuple]

-         # Register with scheduler if params provided
-         if scheduler_params:
-             print("\n⏰ Setting up automated execution schedule...")
-             print(f" • Frequency: Every {scheduler_params.frequency} seconds")
-             print(f" • Duration: {scheduler_params.duration_hours} hours")
-             print(f" End Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(scheduler_params.end_time))}")
+             try:
+                 # Estimate gas needed
+                 estimated_gas = contract.constructor(*constructor_args).estimate_gas({"from": self._wallet_account.address})
+                 gas_limit = int(estimated_gas * 1.2)
+             except Exception as e:
+                 print(f"⚠️ Gas estimation failed: {str(e)}")
+                 gas_limit = 5000000 # Conservative fallback
+                 print(f"📊 Using fallback gas limit: {gas_limit}")

-         scheduler_abi = [
+             transaction = contract.constructor(*constructor_args).build_transaction(
                  {
-                     "inputs": [
-                         {"internalType": "address", "name": "contractAddress", "type": "address"},
-                         {"internalType": "uint256", "name": "endTime", "type": "uint256"},
-                         {"internalType": "uint256", "name": "frequency", "type": "uint256"},
-                     ],
-                     "name": "registerTask",
-                     "outputs": [],
-                     "stateMutability": "nonpayable",
-                     "type": "function",
+                     "from": self._wallet_account.address,
+                     "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
+                     "gas": gas_limit,
+                     "gasPrice": self._blockchain.eth.gas_price,
+                     "chainId": self._blockchain.eth.chain_id,
                  }
-         ]
+             )

-         scheduler_address = "0x6F937b9f4Fa7723932827cd73063B70Be2b56748"
-         scheduler_contract = self._blockchain.eth.contract(address=scheduler_address, abi=scheduler_abi)
+             signed_txn = self._wallet_account.sign_transaction(transaction)
+             tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction)

-         try:
-             # Register the workflow with the scheduler
-             scheduler_tx = scheduler_contract.functions.registerTask(
-                 contract_address, scheduler_params.end_time, scheduler_params.frequency
-             ).build_transaction(
-                 {
-                     "from": self._wallet_account.address,
-                     "gas": 300000,
-                     "gasPrice": self._blockchain.eth.gas_price,
-                     "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
-                     "chainId": self._blockchain.eth.chain_id,
-                 }
-             )
+             tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=60)

-             signed_scheduler_tx = self._wallet_account.sign_transaction(scheduler_tx)
-             scheduler_tx_hash = self._blockchain.eth.send_raw_transaction(signed_scheduler_tx.raw_transaction)
-             self._blockchain.eth.wait_for_transaction_receipt(scheduler_tx_hash, timeout=REGULAR_TX_TIMEOUT)
+             if tx_receipt["status"] == 0:
+                 raise Exception(f"❌ Contract deployment failed, transaction hash: {tx_hash.hex()}")

-             print("✅ Automated execution schedule set successfully!")
-             print(f" Transaction hash: {scheduler_tx_hash.hex()}")
+             return tx_receipt.contractAddress

-         except Exception as e:
-             print("❌ Failed to set up automated execution schedule")
-             print(f" Error: {str(e)}")
-             print(" The workflow contract is still deployed and can be executed manually.")
+         contract_address = run_with_retry(deploy_transaction)
+
+         if scheduler_params:
+             self._register_with_scheduler(contract_address, scheduler_params)

          return contract_address

+     def _register_with_scheduler(self, contract_address: str, scheduler_params: SchedulerParams) -> None:
+         """
+         Register the deployed workflow contract with the scheduler for automated execution.
+
+         Args:
+             contract_address (str): Address of the deployed workflow contract
+             scheduler_params (SchedulerParams): Scheduler configuration containing:
+                 - frequency: Execution frequency in seconds
+                 - duration_hours: How long to run in hours
+                 - end_time: Unix timestamp when scheduling should end
+
+         Raises:
+             Exception: If registration with scheduler fails. The workflow contract will
+                 still be deployed and can be executed manually.
+         """
+
+         scheduler_abi = self._get_abi("WorkflowScheduler.abi")
+
+         # Scheduler contract address
+         scheduler_address = DEFAULT_SCHEDULER_ADDRESS
+         scheduler_contract = self._blockchain.eth.contract(address=scheduler_address, abi=scheduler_abi)
+
+         try:
+             # Register the workflow with the scheduler
+             scheduler_tx = scheduler_contract.functions.registerTask(
+                 contract_address, scheduler_params.end_time, scheduler_params.frequency
+             ).build_transaction(
+                 {
+                     "from": self._wallet_account.address,
+                     "gas": 300000,
+                     "gasPrice": self._blockchain.eth.gas_price,
+                     "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
+                     "chainId": self._blockchain.eth.chain_id,
+                 }
+             )
+
+             signed_scheduler_tx = self._wallet_account.sign_transaction(scheduler_tx)
+             scheduler_tx_hash = self._blockchain.eth.send_raw_transaction(signed_scheduler_tx.raw_transaction)
+             self._blockchain.eth.wait_for_transaction_receipt(scheduler_tx_hash, timeout=REGULAR_TX_TIMEOUT)
+         except Exception as e:
+             print(f"❌ Error registering contract with scheduler: {str(e)}")
+             print(" The workflow contract is still deployed and can be executed manually.")
+

      def read_workflow_result(self, contract_address: str) -> ModelOutput:
          """
          Reads the latest inference result from a deployed workflow contract.
@@ -867,12 +886,14 @@ class Client:
              Web3Error: If there are issues with the web3 connection or contract interaction
          """
          # Get the contract interface
-         contract = self._blockchain.eth.contract(address=Web3.to_checksum_address(contract_address), abi=self._get_model_executor_abi())
+         contract = self._blockchain.eth.contract(
+             address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
+         )

          # Get the result
          result = contract.functions.getInferenceResult().call()

-         return utils.convert_array_to_model_output(result)
+         return convert_array_to_model_output(result)

      def run_workflow(self, contract_address: str) -> ModelOutput:
          """
@@ -889,7 +910,9 @@ class Client:
              Web3Error: If there are issues with the web3 connection or contract interaction
          """
          # Get the contract interface
-         contract = self._blockchain.eth.contract(address=Web3.to_checksum_address(contract_address), abi=self._get_model_executor_abi())
+         contract = self._blockchain.eth.contract(
+             address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
+         )

          # Call run() function
          nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending")
@@ -915,7 +938,28 @@ class Client:
          # Get the inference result from the contract
          result = contract.functions.getInferenceResult().call()

-         return utils.convert_array_to_model_output(result)
+         return convert_array_to_model_output(result)
+
+     def read_workflow_history(self, contract_address: str, num_results: int) -> List[ModelOutput]:
+         """
+         Gets historical inference results from a workflow contract.
+
+         Retrieves the specified number of most recent inference results from the contract's
+         storage, with the most recent result first.
+
+         Args:
+             contract_address (str): Address of the deployed workflow contract
+             num_results (int): Number of historical results to retrieve
+
+         Returns:
+             List[ModelOutput]: List of historical inference results
+         """
+         contract = self._blockchain.eth.contract(
+             address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
+         )
+
+         results = contract.functions.getLastInferenceResults(num_results).call()
+         return [convert_array_to_model_output(result) for result in results]


  def run_with_retry(txn_function, max_retries=DEFAULT_MAX_RETRY, retry_delay=DEFAULT_RETRY_DELAY_SEC):
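Usage note: the workflow API changes the most in 0.4.8. new_workflow only accepts a HistoricalInputQuery (the dict-to-HistoricalInputQuery conversion was removed), deploys the PriceHistoryInference contract with gas estimation plus a run_with_retry wrapper, delegates scheduling to _register_with_scheduler using DEFAULT_SCHEDULER_ADDRESS, and a new read_workflow_history method returns recent results. An end-to-end sketch, with HistoricalInputQuery and SchedulerParams constructor arguments left as placeholders since their signatures are not shown in this diff:

    from opengradient.types import HistoricalInputQuery, SchedulerParams

    query = HistoricalInputQuery(...)  # constructor arguments not shown in this diff
    schedule = SchedulerParams(...)    # must carry frequency and end_time (see registerTask above)

    address = client.new_workflow(
        model_cid="Qm...",
        input_query=query,                  # a dict is no longer accepted in 0.4.8
        input_tensor_name="price_history",
        scheduler_params=schedule,          # omit to execute the workflow manually
    )

    latest = client.read_workflow_result(address)
    history = client.read_workflow_history(address, num_results=5)  # new in 0.4.8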