opengradient 0.5.8__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
opengradient/client.py CHANGED
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Union, Callable
 import firebase
 import numpy as np
 import requests
+import httpx
 from eth_account.account import LocalAccount
 from web3 import Web3
 from web3.exceptions import ContractLogicError
@@ -17,7 +18,9 @@ import urllib.parse
 import asyncio
 from x402.clients.httpx import x402HttpxClient
 from x402.clients.base import decode_x_payment_response, x402Client
+from x402.clients.httpx import x402HttpxClient
 
+from .x402_auth import X402Auth
 from .exceptions import OpenGradientError
 from .proto import infer_pb2, infer_pb2_grpc
 from .types import (
@@ -29,10 +32,12 @@ from .types import (
     LlmInferenceMode,
     ModelOutput,
     TextGenerationOutput,
+    TextGenerationStream,
     SchedulerParams,
     InferenceResult,
     ModelRepository,
     FileUploadResult,
+    StreamChunk,
 )
 from .defaults import (
     DEFAULT_IMAGE_GEN_HOST,
@@ -40,8 +45,10 @@ from .defaults import (
     DEFAULT_SCHEDULER_ADDRESS,
     DEFAULT_LLM_SERVER_URL,
     DEFAULT_OPENGRADIENT_LLM_SERVER_URL,
+    DEFAULT_OPENGRADIENT_LLM_STREAMING_SERVER_URL,
+    DEFAULT_NETWORK_FILTER,
 )
-from .utils import convert_array_to_model_output, convert_to_model_input, convert_to_model_output
+from .utils import convert_to_model_input, convert_to_model_output
 
 _FIREBASE_CONFIG = {
     "apiKey": "AIzaSyDUVckVtfl-hiteBzPopy1pDD8Uvfncs7w",
@@ -66,6 +73,18 @@ PRECOMPILE_CONTRACT_ADDRESS = "0x00000000000000000000000000000000000000F4"
 X402_PROCESSING_HASH_HEADER = "x-processing-hash"
 X402_PLACEHOLDER_API_KEY = "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"
 
+TIMEOUT = httpx.Timeout(
+    timeout=90.0,
+    connect=15.0,
+    read=15.0,
+    write=30.0,
+    pool=10.0,
+)
+LIMITS = httpx.Limits(
+    max_keepalive_connections=100,
+    max_connections=500,
+    keepalive_expiry=60 * 20,  # 20 minutes
+)
 
 class Client:
     _inference_hub_contract_address: str
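For orientation: TIMEOUT and LIMITS above are module-level httpx settings that the new x402 streaming client later in this file passes straight through. A minimal sketch of applying them, with a placeholder base URL rather than a real OpenGradient endpoint:

    import httpx

    client = httpx.AsyncClient(
        base_url="https://streaming.example.invalid",  # placeholder; the SDK uses DEFAULT_OPENGRADIENT_LLM_STREAMING_SERVER_URL
        timeout=TIMEOUT,  # the per-phase connect/read/write/pool timeouts defined above
        limits=LIMITS,    # the keep-alive pool sizing defined above
    )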
@@ -89,6 +108,7 @@ class Client:
         password: Optional[str] = None,
         llm_server_url: Optional[str] = DEFAULT_LLM_SERVER_URL,
         og_llm_server_url: Optional[str] = DEFAULT_OPENGRADIENT_LLM_SERVER_URL,
+        og_llm_streaming_server_url: Optional[str] = DEFAULT_OPENGRADIENT_LLM_STREAMING_SERVER_URL,
         openai_api_key: Optional[str] = None,
         anthropic_api_key: Optional[str] = None,
         google_api_key: Optional[str] = None,
@@ -123,6 +143,7 @@ class Client:
 
         self._llm_server_url = llm_server_url
         self._og_llm_server_url = og_llm_server_url
+        self._og_llm_streaming_server_url = og_llm_streaming_server_url
 
         self._external_api_keys = {}
         if openai_api_key or os.getenv("OPENAI_API_KEY"):
@@ -132,6 +153,25 @@ class Client:
         if google_api_key or os.getenv("GOOGLE_API_KEY"):
             self._external_api_keys["google"] = google_api_key or os.getenv("GOOGLE_API_KEY")
 
+        self._alpha = None  # Lazy initialization for alpha namespace
+
+    @property
+    def alpha(self):
+        """
+        Access Alpha Testnet features.
+
+        Returns:
+            Alpha: Alpha namespace with workflow and ML model execution methods.
+
+        Example:
+            client = og.new_client(...)
+            result = client.alpha.new_workflow(model_cid, input_query, input_tensor_name)
+        """
+        if self._alpha is None:
+            from .alpha import Alpha
+            self._alpha = Alpha(self)
+        return self._alpha
+
     def set_api_key(self, provider: str, api_key: str):
         """
         Set or update API key for an external provider.
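For orientation: the alpha property added above lazily constructs the Alpha namespace, and the workflow helpers removed from Client in the final hunk of this file appear to be exposed there instead. A hedged usage sketch following the docstring (constructor arguments stay elided, as in the docstring; variable names are illustrative):

    import opengradient as og

    client = og.new_client(...)  # args elided, as in the docstring above
    contract_address = client.alpha.new_workflow(model_cid, input_query, input_tensor_name)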
@@ -421,11 +461,11 @@ class Client:
 
         return run_with_retry(execute_transaction, max_retries)
 
-    def _og_payment_selector(self, accepts, network_filter=None, scheme_filter=None, max_value=None):
-        """Custom payment selector for OpenGradient network (og-devnet)."""
+    def _og_payment_selector(self, accepts, network_filter=DEFAULT_NETWORK_FILTER, scheme_filter=None, max_value=None):
+        """Custom payment selector for OpenGradient network."""
         return x402Client.default_payment_requirements_selector(
             accepts,
-            network_filter="og-devnet",
+            network_filter=network_filter,
             scheme_filter=scheme_filter,
             max_value=max_value,
         )
@@ -652,7 +692,8 @@ class Client:
         max_retries: Optional[int] = None,
         local_model: Optional[bool] = False,
         x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH,
-    ) -> TextGenerationOutput:
+        stream: bool = False,
+    ) -> Union[TextGenerationOutput, TextGenerationStream]:
         """
         Perform inference on an LLM model using chat.
 
@@ -672,13 +713,12 @@ class Client:
                 - SETTLE_BATCH: Aggregates multiple inferences into batch hashes (most cost-efficient).
                 - SETTLE_METADATA: Records full model info, complete input/output data, and all metadata.
                 Defaults to SETTLE_BATCH.
+            stream (bool, optional): Whether to stream the response. Default is False.
 
         Returns:
-            TextGenerationOutput: Generated text results including:
-                - chat_output: Dict with role, content, and tool_calls
-                - transaction_hash: Blockchain hash (or "external" for external providers)
-                - finish_reason: Reason for completion (e.g., "stop", "tool_call")
-                - payment_hash: Payment hash for x402 transactions (when using x402 settlement)
+            Union[TextGenerationOutput, TextGenerationStream]:
+                - If stream=False: TextGenerationOutput with chat_output, transaction_hash, finish_reason, and payment_hash
+                - If stream=True: TextGenerationStream yielding StreamChunk objects with typed deltas (true streaming via threading)
 
         Raises:
             OpenGradientError: If the inference fails.
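For orientation, a hedged sketch of how the new stream flag is consumed; "provider/model" is a placeholder CID and all other llm_chat parameters are left at their defaults:

    messages = [{"role": "user", "content": "Hello"}]

    # stream=False (default): a single TextGenerationOutput
    result = client.llm_chat(model_cid="provider/model", messages=messages)
    print(result.chat_output)

    # stream=True: an iterable of StreamChunk deltas, yielded as they arrive
    for chunk in client.llm_chat(model_cid="provider/model", messages=messages, stream=True):
        ...  # handle each StreamChunk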
@@ -689,16 +729,33 @@ class Client:
             if model_cid not in TEE_LLM:
                 return OpenGradientError("That model CID is not supported yet for TEE inference")
 
-            return self._external_llm_chat(
-                model=model_cid.split("/")[1],
-                messages=messages,
-                max_tokens=max_tokens,
-                stop_sequence=stop_sequence,
-                temperature=temperature,
-                tools=tools,
-                tool_choice=tool_choice,
-                x402_settlement_mode=x402_settlement_mode,
-            )
+            if stream:
+                # Use threading bridge for true sync streaming
+                return self._external_llm_chat_stream_sync(
+                    model=model_cid.split("/")[1],
+                    messages=messages,
+                    max_tokens=max_tokens,
+                    stop_sequence=stop_sequence,
+                    temperature=temperature,
+                    tools=tools,
+                    tool_choice=tool_choice,
+                    x402_settlement_mode=x402_settlement_mode,
+                    use_tee=True,
+                )
+            else:
+                # Non-streaming
+                return self._external_llm_chat(
+                    model=model_cid.split("/")[1],
+                    messages=messages,
+                    max_tokens=max_tokens,
+                    stop_sequence=stop_sequence,
+                    temperature=temperature,
+                    tools=tools,
+                    tool_choice=tool_choice,
+                    x402_settlement_mode=x402_settlement_mode,
+                    stream=False,
+                    use_tee=True,
+                )
 
         # Original local model logic
         def execute_transaction():
@@ -778,7 +835,9 @@ class Client:
         tools: Optional[List[Dict]] = None,
         tool_choice: Optional[str] = None,
         x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH,
-    ) -> TextGenerationOutput:
+        stream: bool = False,
+        use_tee: bool = False,
+    ) -> Union[TextGenerationOutput, TextGenerationStream]:
         """
         Route chat request to external LLM server with x402 payments.
 
@@ -790,18 +849,24 @@ class Client:
             temperature: Sampling temperature
             tools: Function calling tools
             tool_choice: Tool selection strategy
+            stream: Whether to stream the response
+            use_tee: Whether to use TEE
 
         Returns:
-            TextGenerationOutput with chat completion
+            Union[TextGenerationOutput, TextGenerationStream]: Chat completion or TextGenerationStream
 
         Raises:
             OpenGradientError: If request fails
         """
-        api_key = self._get_api_key_for_model(model)
+        api_key = None if use_tee else self._get_api_key_for_model(model)
 
         if api_key:
-            logging.debug("External LLM completion using API key")
-            url = f"{self._llm_server_url}/v1/chat/completions"
+            logging.debug("External LLM chat using API key")
+
+            if stream:
+                url = f"{self._llm_server_url}/v1/chat/completions/stream"
+            else:
+                url = f"{self._llm_server_url}/v1/chat/completions"
 
             headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
 
@@ -820,14 +885,23 @@ class Client:
                 payload["tool_choice"] = tool_choice or "auto"
 
             try:
-                response = requests.post(url, json=payload, headers=headers, timeout=60)
-                response.raise_for_status()
+                if stream:
+                    # Return streaming response wrapped in TextGenerationStream
+                    response = requests.post(url, json=payload, headers=headers, timeout=60, stream=True)
+                    response.raise_for_status()
+                    return TextGenerationStream(_iterator=response.iter_lines(decode_unicode=True), _is_async=False)
+                else:
+                    # Non-streaming response
+                    response = requests.post(url, json=payload, headers=headers, timeout=60)
+                    response.raise_for_status()
 
-                result = response.json()
+                    result = response.json()
 
-                return TextGenerationOutput(
-                    transaction_hash="external", finish_reason=result.get("finish_reason"), chat_output=result.get("message")
-                )
+                    return TextGenerationOutput(
+                        transaction_hash="external",
+                        finish_reason=result.get("finish_reason"),
+                        chat_output=result.get("message")
+                    )
 
             except requests.RequestException as e:
                 error_msg = f"External LLM chat failed: {str(e)}"
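For orientation: the streaming branch above wraps requests' line iterator in TextGenerationStream, whose definition lives in opengradient/types.py and is not part of this diff. Assuming it simply walks SSE "data:" lines (mirroring the async parser added further down), consuming such an iterator amounts to roughly:

    import json

    def iter_sse_json(lines):
        # lines: decoded SSE lines, e.g. response.iter_lines(decode_unicode=True)
        for line in lines:
            if not line or not line.startswith("data: "):
                continue
            data = line[6:].strip()
            if data == "[DONE]":
                return
            try:
                yield json.loads(data)  # each payload would back one StreamChunk
            except json.JSONDecodeError:
                continue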
@@ -840,6 +914,7 @@ class Client:
                 logging.error(error_msg)
                 raise OpenGradientError(error_msg)
 
+        # x402 payment path - non-streaming only here
         async def make_request():
             async with x402HttpxClient(
                 account=self._wallet_account,
@@ -867,13 +942,13 @@ class Client:
                     payload["tool_choice"] = tool_choice or "auto"
 
                 try:
-                    response = await client.post("/v1/chat/completions", json=payload, headers=headers, timeout=60)
+                    # Non-streaming with x402
+                    endpoint = "/v1/chat/completions"
+                    response = await client.post(endpoint, json=payload, headers=headers, timeout=60)
 
                     # Read the response content
                     content = await response.aread()
                     result = json.loads(content.decode())
-                    # print(f"Response: {response}")
-                    # print(f"Response Headers: {response.headers}")
 
                     payment_hash = ""
                     if X402_PROCESSING_HASH_HEADER in response.headers:
@@ -909,6 +984,234 @@ class Client:
             logging.error(error_msg)
             raise OpenGradientError(error_msg)
 
+    def _external_llm_chat_stream_sync(
+        self,
+        model: str,
+        messages: List[Dict],
+        max_tokens: int = 100,
+        stop_sequence: Optional[List[str]] = None,
+        temperature: float = 0.0,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: Optional[str] = None,
+        x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH,
+        use_tee: bool = False,
+    ):
+        """
+        Sync streaming using threading bridge - TRUE real-time streaming.
+
+        Yields StreamChunk objects as they arrive from the background thread.
+        NO buffering, NO conversion, just direct pass-through.
+        """
+        import threading
+        from queue import Queue
+
+        queue = Queue()
+        exception_holder = []
+
+        def _run_async():
+            """Run async streaming in background thread"""
+            loop = None
+            try:
+                loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(loop)
+
+                async def _stream():
+                    try:
+                        async for chunk in self._external_llm_chat_stream_async(
+                            model=model,
+                            messages=messages,
+                            max_tokens=max_tokens,
+                            stop_sequence=stop_sequence,
+                            temperature=temperature,
+                            tools=tools,
+                            tool_choice=tool_choice,
+                            x402_settlement_mode=x402_settlement_mode,
+                            use_tee=use_tee,
+                        ):
+                            queue.put(chunk)  # Put chunk immediately
+                    except Exception as e:
+                        exception_holder.append(e)
+                    finally:
+                        queue.put(None)  # Signal completion
+
+                loop.run_until_complete(_stream())
+            except Exception as e:
+                exception_holder.append(e)
+                queue.put(None)
+            finally:
+                if loop:
+                    try:
+                        pending = asyncio.all_tasks(loop)
+                        for task in pending:
+                            task.cancel()
+                        loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
+                    finally:
+                        loop.close()
+
+        # Start background thread
+        thread = threading.Thread(target=_run_async, daemon=True)
+        thread.start()
+
+        # Yield chunks DIRECTLY as they arrive - NO buffering
+        try:
+            while True:
+                chunk = queue.get()  # Blocks until chunk available
+                if chunk is None:
+                    break
+                yield chunk  # Yield immediately!
+
+            thread.join(timeout=5)
+
+            if exception_holder:
+                raise exception_holder[0]
+        except Exception as e:
+            thread.join(timeout=1)
+            raise
+
+
+    async def _external_llm_chat_stream_async(
+        self,
+        model: str,
+        messages: List[Dict],
+        max_tokens: int = 100,
+        stop_sequence: Optional[List[str]] = None,
+        temperature: float = 0.0,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: Optional[str] = None,
+        x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH,
+        use_tee: bool = False,
+    ):
+        """
+        Internal async streaming implementation.
+
+        Yields StreamChunk objects as they arrive from the server.
+        """
+        api_key = None if use_tee else self._get_api_key_for_model(model)
+
+        if api_key:
+            # API key path - streaming to local llm-server
+            url = f"{self._og_llm_streaming_server_url}/v1/chat/completions"
+            headers = {
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {api_key}"
+            }
+
+            payload = {
+                "model": model,
+                "messages": messages,
+                "max_tokens": max_tokens,
+                "temperature": temperature,
+                "stream": True,
+            }
+
+            if stop_sequence:
+                payload["stop"] = stop_sequence
+            if tools:
+                payload["tools"] = tools
+                payload["tool_choice"] = tool_choice or "auto"
+
+            async with httpx.AsyncClient(verify=False, timeout=None) as client:
+                async with client.stream("POST", url, json=payload, headers=headers) as response:
+                    buffer = b""
+                    async for chunk in response.aiter_raw():
+                        if not chunk:
+                            continue
+
+                        buffer += chunk
+
+                        # Process all complete lines in buffer
+                        while b"\n" in buffer:
+                            line_bytes, buffer = buffer.split(b"\n", 1)
+
+                            if not line_bytes.strip():
+                                continue
+
+                            try:
+                                line = line_bytes.decode('utf-8').strip()
+                            except UnicodeDecodeError:
+                                continue
+
+                            if not line.startswith("data: "):
+                                continue
+
+                            data_str = line[6:]  # Strip "data: " prefix
+                            if data_str.strip() == "[DONE]":
+                                return
+
+                            try:
+                                data = json.loads(data_str)
+                                yield StreamChunk.from_sse_data(data)
+                            except json.JSONDecodeError:
+                                continue
+        else:
+            # x402 payment path
+            async with httpx.AsyncClient(
+                base_url=self._og_llm_streaming_server_url,
+                headers={"Authorization": f"Bearer {X402_PLACEHOLDER_API_KEY}"},
+                timeout=TIMEOUT,
+                limits=LIMITS,
+                http2=False,
+                follow_redirects=False,
+                auth=X402Auth(account=self._wallet_account),  # type: ignore
+            ) as client:
+                headers = {
+                    "Content-Type": "application/json",
+                    "Authorization": f"Bearer {X402_PLACEHOLDER_API_KEY}",
+                    "X-SETTLEMENT-TYPE": x402_settlement_mode,
+                }
+
+                payload = {
+                    "model": model,
+                    "messages": messages,
+                    "max_tokens": max_tokens,
+                    "temperature": temperature,
+                    "stream": True,
+                }
+
+                if stop_sequence:
+                    payload["stop"] = stop_sequence
+                if tools:
+                    payload["tools"] = tools
+                    payload["tool_choice"] = tool_choice or "auto"
+
+                async with client.stream(
+                    "POST",
+                    "/v1/chat/completions",
+                    json=payload,
+                    headers=headers,
+                ) as response:
+                    buffer = b""
+                    async for chunk in response.aiter_raw():
+                        if not chunk:
+                            continue
+
+                        buffer += chunk
+
+                        # Process complete lines from buffer
+                        while b"\n" in buffer:
+                            line_bytes, buffer = buffer.split(b"\n", 1)
+
+                            if not line_bytes.strip():
+                                continue
+
+                            try:
+                                line = line_bytes.decode('utf-8').strip()
+                            except UnicodeDecodeError:
+                                continue
+
+                            if not line.startswith("data: "):
+                                continue
+
+                            data_str = line[6:]
+                            if data_str.strip() == "[DONE]":
+                                return
+
+                            try:
+                                data = json.loads(data_str)
+                                yield StreamChunk.from_sse_data(data)
+                            except json.JSONDecodeError:
+                                continue
+
     def list_files(self, model_name: str, version: str) -> List[Dict]:
         """
         List files for a specific version of a model.
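For orientation: _external_llm_chat_stream_sync above is a sync-over-async bridge; it runs the async generator on a private event loop in a background thread and hands chunks to the caller through a queue. A minimal, generic sketch of the same pattern (not OpenGradient-specific):

    import asyncio
    import threading
    from queue import Queue

    def sync_stream(async_gen_factory):
        # Consume an async generator from synchronous code via a background thread.
        q = Queue()
        errors = []

        def runner():
            async def pump():
                try:
                    async for item in async_gen_factory():
                        q.put(item)        # hand each item over immediately
                except Exception as exc:   # surface failures to the sync side
                    errors.append(exc)
                finally:
                    q.put(None)            # sentinel: stream finished
            asyncio.run(pump())

        threading.Thread(target=runner, daemon=True).start()
        while True:
            item = q.get()                 # blocks until the next item arrives
            if item is None:
                break
            yield item
        if errors:
            raise errors[0]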
@@ -1141,216 +1444,6 @@ class Client:
 
         return tx_hash, tx_receipt
 
-    def new_workflow(
-        self,
-        model_cid: str,
-        input_query: HistoricalInputQuery,
-        input_tensor_name: str,
-        scheduler_params: Optional[SchedulerParams] = None,
-    ) -> str:
-        """
-        Deploy a new workflow contract with the specified parameters.
-
-        This function deploys a new workflow contract on OpenGradient that connects
-        an AI model with its required input data. When executed, the workflow will fetch
-        the specified model, evaluate the input query to get data, and perform inference.
-
-        The workflow can be set to execute manually or automatically via a scheduler.
-
-        Args:
-            model_cid (str): CID of the model to be executed from the Model Hub
-            input_query (HistoricalInputQuery): Input definition for the model inference,
-                will be evaluated at runtime for each inference
-            input_tensor_name (str): Name of the input tensor expected by the model
-            scheduler_params (Optional[SchedulerParams]): Scheduler configuration for automated execution:
-                - frequency: Execution frequency in seconds
-                - duration_hours: How long the schedule should live for
-
-        Returns:
-            str: Deployed contract address. If scheduler_params was provided, the workflow
-                will be automatically executed according to the specified schedule.
-
-        Raises:
-            Exception: If transaction fails or gas estimation fails
-        """
-        # Get contract ABI and bytecode
-        abi = self._get_abi("PriceHistoryInference.abi")
-        bytecode = self._get_bin("PriceHistoryInference.bin")
-
-        def deploy_transaction():
-            contract = self._blockchain.eth.contract(abi=abi, bytecode=bytecode)
-            query_tuple = input_query.to_abi_format()
-            constructor_args = [model_cid, input_tensor_name, query_tuple]
-
-            try:
-                # Estimate gas needed
-                estimated_gas = contract.constructor(*constructor_args).estimate_gas({"from": self._wallet_account.address})
-                gas_limit = int(estimated_gas * 1.2)
-            except Exception as e:
-                print(f"⚠️ Gas estimation failed: {str(e)}")
-                gas_limit = 5000000  # Conservative fallback
-                print(f"📊 Using fallback gas limit: {gas_limit}")
-
-            transaction = contract.constructor(*constructor_args).build_transaction(
-                {
-                    "from": self._wallet_account.address,
-                    "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
-                    "gas": gas_limit,
-                    "gasPrice": self._blockchain.eth.gas_price,
-                    "chainId": self._blockchain.eth.chain_id,
-                }
-            )
-
-            signed_txn = self._wallet_account.sign_transaction(transaction)
-            tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction)
-
-            tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=60)
-
-            if tx_receipt["status"] == 0:
-                raise Exception(f"❌ Contract deployment failed, transaction hash: {tx_hash.hex()}")
-
-            return tx_receipt.contractAddress
-
-        contract_address = run_with_retry(deploy_transaction)
-
-        if scheduler_params:
-            self._register_with_scheduler(contract_address, scheduler_params)
-
-        return contract_address
-
-    def _register_with_scheduler(self, contract_address: str, scheduler_params: SchedulerParams) -> None:
-        """
-        Register the deployed workflow contract with the scheduler for automated execution.
-
-        Args:
-            contract_address (str): Address of the deployed workflow contract
-            scheduler_params (SchedulerParams): Scheduler configuration containing:
-                - frequency: Execution frequency in seconds
-                - duration_hours: How long to run in hours
-                - end_time: Unix timestamp when scheduling should end
-
-        Raises:
-            Exception: If registration with scheduler fails. The workflow contract will
-                still be deployed and can be executed manually.
-        """
-
-        scheduler_abi = self._get_abi("WorkflowScheduler.abi")
-
-        # Scheduler contract address
-        scheduler_address = DEFAULT_SCHEDULER_ADDRESS
-        scheduler_contract = self._blockchain.eth.contract(address=scheduler_address, abi=scheduler_abi)
-
-        try:
-            # Register the workflow with the scheduler
-            scheduler_tx = scheduler_contract.functions.registerTask(
-                contract_address, scheduler_params.end_time, scheduler_params.frequency
-            ).build_transaction(
-                {
-                    "from": self._wallet_account.address,
-                    "gas": 300000,
-                    "gasPrice": self._blockchain.eth.gas_price,
-                    "nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
-                    "chainId": self._blockchain.eth.chain_id,
-                }
-            )
-
-            signed_scheduler_tx = self._wallet_account.sign_transaction(scheduler_tx)
-            scheduler_tx_hash = self._blockchain.eth.send_raw_transaction(signed_scheduler_tx.raw_transaction)
-            self._blockchain.eth.wait_for_transaction_receipt(scheduler_tx_hash, timeout=REGULAR_TX_TIMEOUT)
-        except Exception as e:
-            print(f"❌ Error registering contract with scheduler: {str(e)}")
-            print(" The workflow contract is still deployed and can be executed manually.")
-
-    def read_workflow_result(self, contract_address: str) -> ModelOutput:
-        """
-        Reads the latest inference result from a deployed workflow contract.
-
-        Args:
-            contract_address (str): Address of the deployed workflow contract
-
-        Returns:
-            ModelOutput: The inference result from the contract
-
-        Raises:
-            ContractLogicError: If the transaction fails
-            Web3Error: If there are issues with the web3 connection or contract interaction
-        """
-        # Get the contract interface
-        contract = self._blockchain.eth.contract(
-            address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
-        )
-
-        # Get the result
-        result = contract.functions.getInferenceResult().call()
-
-        return convert_array_to_model_output(result)
-
-    def run_workflow(self, contract_address: str) -> ModelOutput:
-        """
-        Triggers the run() function on a deployed workflow contract and returns the result.
-
-        Args:
-            contract_address (str): Address of the deployed workflow contract
-
-        Returns:
-            ModelOutput: The inference result from the contract
-
-        Raises:
-            ContractLogicError: If the transaction fails
-            Web3Error: If there are issues with the web3 connection or contract interaction
-        """
-        # Get the contract interface
-        contract = self._blockchain.eth.contract(
-            address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
-        )
-
-        # Call run() function
-        nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending")
-
-        run_function = contract.functions.run()
-        transaction = run_function.build_transaction(
-            {
-                "from": self._wallet_account.address,
-                "nonce": nonce,
-                "gas": 30000000,
-                "gasPrice": self._blockchain.eth.gas_price,
-                "chainId": self._blockchain.eth.chain_id,
-            }
-        )
-
-        signed_txn = self._wallet_account.sign_transaction(transaction)
-        tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction)
-        tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=INFERENCE_TX_TIMEOUT)
-
-        if tx_receipt.status == 0:
-            raise ContractLogicError(f"Run transaction failed. Receipt: {tx_receipt}")
-
-        # Get the inference result from the contract
-        result = contract.functions.getInferenceResult().call()
-
-        return convert_array_to_model_output(result)
-
-    def read_workflow_history(self, contract_address: str, num_results: int) -> List[ModelOutput]:
-        """
-        Gets historical inference results from a workflow contract.
-
-        Retrieves the specified number of most recent inference results from the contract's
-        storage, with the most recent result first.
-
-        Args:
-            contract_address (str): Address of the deployed workflow contract
-            num_results (int): Number of historical results to retrieve
-
-        Returns:
-            List[ModelOutput]: List of historical inference results
-        """
-        contract = self._blockchain.eth.contract(
-            address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
-        )
-
-        results = contract.functions.getLastInferenceResults(num_results).call()
-        return [convert_array_to_model_output(result) for result in results]
-
     def _get_inference_result_from_node(self, inference_id: str, inference_mode: InferenceMode) -> Dict:
         """
         Get the inference result from node.