opengradient 0.4.5__py3-none-any.whl → 0.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opengradient/__init__.py +59 -67
- opengradient/abi/PriceHistoryInference.abi +1 -0
- opengradient/abi/WorkflowScheduler.abi +13 -0
- opengradient/alphasense/read_workflow_tool.py +1 -1
- opengradient/alphasense/run_model_tool.py +3 -3
- opengradient/bin/PriceHistoryInference.bin +1 -0
- opengradient/cli.py +12 -8
- opengradient/client.py +296 -218
- opengradient/defaults.py +1 -0
- opengradient/llm/__init__.py +1 -1
- opengradient/llm/og_langchain.py +36 -20
- opengradient/llm/og_openai.py +4 -2
- opengradient/types.py +36 -17
- opengradient/utils.py +2 -0
- opengradient-0.4.7.dist-info/METADATA +159 -0
- opengradient-0.4.7.dist-info/RECORD +29 -0
- {opengradient-0.4.5.dist-info → opengradient-0.4.7.dist-info}/WHEEL +1 -1
- opengradient/abi/ModelExecutorHistorical.abi +0 -1
- opengradient-0.4.5.dist-info/METADATA +0 -189
- opengradient-0.4.5.dist-info/RECORD +0 -27
- {opengradient-0.4.5.dist-info → opengradient-0.4.7.dist-info}/LICENSE +0 -0
- {opengradient-0.4.5.dist-info → opengradient-0.4.7.dist-info}/entry_points.txt +0 -0
- {opengradient-0.4.5.dist-info → opengradient-0.4.7.dist-info}/top_level.txt +0 -0
opengradient/client.py
CHANGED
|
@@ -2,12 +2,10 @@ import json
|
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
4
|
import time
|
|
5
|
-
import uuid
|
|
6
5
|
from pathlib import Path
|
|
7
|
-
from typing import Any, Dict, List, Optional,
|
|
6
|
+
from typing import Any, Dict, List, Optional, Union
|
|
8
7
|
|
|
9
8
|
import firebase
|
|
10
|
-
import grpc
|
|
11
9
|
import numpy as np
|
|
12
10
|
import requests
|
|
13
11
|
from eth_account.account import LocalAccount
|
|
@@ -18,8 +16,18 @@ from web3.logs import DISCARD
|
|
|
18
16
|
from . import utils
|
|
19
17
|
from .exceptions import OpenGradientError
|
|
20
18
|
from .proto import infer_pb2, infer_pb2_grpc
|
|
21
|
-
from .types import
|
|
22
|
-
|
|
19
|
+
from .types import (
|
|
20
|
+
LLM,
|
|
21
|
+
TEE_LLM,
|
|
22
|
+
HistoricalInputQuery,
|
|
23
|
+
InferenceMode,
|
|
24
|
+
LlmInferenceMode,
|
|
25
|
+
ModelOutput,
|
|
26
|
+
TextGenerationOutput,
|
|
27
|
+
SchedulerParams,
|
|
28
|
+
InferenceResult,
|
|
29
|
+
)
|
|
30
|
+
from .defaults import DEFAULT_IMAGE_GEN_HOST, DEFAULT_IMAGE_GEN_PORT, DEFAULT_SCHEDULER_ADDRESS
|
|
23
31
|
|
|
24
32
|
_FIREBASE_CONFIG = {
|
|
25
33
|
"apiKey": "AIzaSyDUVckVtfl-hiteBzPopy1pDD8Uvfncs7w",
|
|
@@ -48,7 +56,7 @@ class Client:
|
|
|
48
56
|
_hub_user: Dict
|
|
49
57
|
_inference_abi: Dict
|
|
50
58
|
|
|
51
|
-
def __init__(self, private_key: str, rpc_url: str, contract_address: str, email: str, password: str):
|
|
59
|
+
def __init__(self, private_key: str, rpc_url: str, contract_address: str, email: Optional[str], password: Optional[str]):
|
|
52
60
|
"""
|
|
53
61
|
Initialize the Client with private key, RPC URL, and contract address.
|
|
54
62
|
|
|
@@ -137,7 +145,7 @@ class Client:
|
|
|
137
145
|
logging.error(f"Unexpected error during model creation: {str(e)}")
|
|
138
146
|
raise
|
|
139
147
|
|
|
140
|
-
def create_version(self, model_name: str, notes: str =
|
|
148
|
+
def create_version(self, model_name: str, notes: str = "", is_major: bool = False) -> dict:
|
|
141
149
|
"""
|
|
142
150
|
Create a new version for the specified model.
|
|
143
151
|
|
|
@@ -273,9 +281,7 @@ class Client:
|
|
|
273
281
|
if hasattr(e, "response") and e.response is not None:
|
|
274
282
|
logging.error(f"Response status code: {e.response.status_code}")
|
|
275
283
|
logging.error(f"Response content: {e.response.text[:1000]}...") # Log first 1000 characters
|
|
276
|
-
raise OpenGradientError(
|
|
277
|
-
f"Upload failed due to request exception: {str(e)}", status_code=e.response.status_code if hasattr(e, "response") else None
|
|
278
|
-
)
|
|
284
|
+
raise OpenGradientError(f"Upload failed due to request exception: {str(e)}")
|
|
279
285
|
except Exception as e:
|
|
280
286
|
logging.error(f"Unexpected error during upload: {str(e)}", exc_info=True)
|
|
281
287
|
raise OpenGradientError(f"Unexpected error during upload: {str(e)}")
|
|
@@ -286,7 +292,7 @@ class Client:
|
|
|
286
292
|
inference_mode: InferenceMode,
|
|
287
293
|
model_input: Dict[str, Union[str, int, float, List, np.ndarray]],
|
|
288
294
|
max_retries: Optional[int] = None,
|
|
289
|
-
) ->
|
|
295
|
+
) -> InferenceResult:
|
|
290
296
|
"""
|
|
291
297
|
Perform inference on a model.
|
|
292
298
|
|
|
@@ -297,7 +303,7 @@ class Client:
|
|
|
297
303
|
max_retries (int, optional): Maximum number of retry attempts. Defaults to 5.
|
|
298
304
|
|
|
299
305
|
Returns:
|
|
300
|
-
|
|
306
|
+
InferenceResult: The transaction hash and the model output.
|
|
301
307
|
|
|
302
308
|
Raises:
|
|
303
309
|
OpenGradientError: If the inference fails.
|
|
@@ -306,7 +312,7 @@ class Client:
|
|
|
306
312
|
def execute_transaction():
|
|
307
313
|
contract = self._blockchain.eth.contract(address=self._inference_hub_contract_address, abi=self._inference_abi)
|
|
308
314
|
|
|
309
|
-
inference_mode_uint8 =
|
|
315
|
+
inference_mode_uint8 = inference_mode.value
|
|
310
316
|
converted_model_input = utils.convert_to_model_input(model_input)
|
|
311
317
|
|
|
312
318
|
run_function = contract.functions.run(model_cid, inference_mode_uint8, converted_model_input)
|
|
@@ -337,20 +343,21 @@ class Client:
|
|
|
337
343
|
|
|
338
344
|
# TODO: This should return a ModelOutput class object
|
|
339
345
|
model_output = utils.convert_to_model_output(parsed_logs[0]["args"])
|
|
340
|
-
|
|
346
|
+
|
|
347
|
+
return InferenceResult(tx_hash.hex(), model_output)
|
|
341
348
|
|
|
342
349
|
return run_with_retry(execute_transaction, max_retries)
|
|
343
350
|
|
|
344
351
|
def llm_completion(
|
|
345
352
|
self,
|
|
346
353
|
model_cid: LLM,
|
|
347
|
-
inference_mode:
|
|
354
|
+
inference_mode: LlmInferenceMode,
|
|
348
355
|
prompt: str,
|
|
349
356
|
max_tokens: int = 100,
|
|
350
357
|
stop_sequence: Optional[List[str]] = None,
|
|
351
358
|
temperature: float = 0.0,
|
|
352
359
|
max_retries: Optional[int] = None,
|
|
353
|
-
) ->
|
|
360
|
+
) -> TextGenerationOutput:
|
|
354
361
|
"""
|
|
355
362
|
Perform inference on an LLM model using completions.
|
|
356
363
|
|
|
@@ -363,7 +370,9 @@ class Client:
|
|
|
363
370
|
temperature (float): Temperature for LLM inference, between 0 and 1. Default is 0.0.
|
|
364
371
|
|
|
365
372
|
Returns:
|
|
366
|
-
|
|
373
|
+
TextGenerationOutput: Generated text results including:
|
|
374
|
+
- Transaction hash
|
|
375
|
+
- String of completion output
|
|
367
376
|
|
|
368
377
|
Raises:
|
|
369
378
|
OpenGradientError: If the inference fails.
|
|
@@ -381,7 +390,7 @@ class Client:
|
|
|
381
390
|
|
|
382
391
|
# Prepare LLM input
|
|
383
392
|
llm_request = {
|
|
384
|
-
"mode": inference_mode,
|
|
393
|
+
"mode": inference_mode.value,
|
|
385
394
|
"modelCID": model_cid,
|
|
386
395
|
"prompt": prompt,
|
|
387
396
|
"max_tokens": max_tokens,
|
|
@@ -418,14 +427,15 @@ class Client:
|
|
|
418
427
|
raise OpenGradientError("LLM completion result event not found in transaction logs")
|
|
419
428
|
|
|
420
429
|
llm_answer = parsed_logs[0]["args"]["response"]["answer"]
|
|
421
|
-
|
|
430
|
+
|
|
431
|
+
return TextGenerationOutput(transaction_hash=tx_hash.hex(), completion_output=llm_answer)
|
|
422
432
|
|
|
423
433
|
return run_with_retry(execute_transaction, max_retries)
|
|
424
434
|
|
|
425
435
|
def llm_chat(
|
|
426
436
|
self,
|
|
427
437
|
model_cid: LLM,
|
|
428
|
-
inference_mode:
|
|
438
|
+
inference_mode: LlmInferenceMode,
|
|
429
439
|
messages: List[Dict],
|
|
430
440
|
max_tokens: int = 100,
|
|
431
441
|
stop_sequence: Optional[List[str]] = None,
|
|
@@ -433,7 +443,7 @@ class Client:
|
|
|
433
443
|
tools: Optional[List[Dict]] = [],
|
|
434
444
|
tool_choice: Optional[str] = None,
|
|
435
445
|
max_retries: Optional[int] = None,
|
|
436
|
-
) ->
|
|
446
|
+
) -> TextGenerationOutput:
|
|
437
447
|
"""
|
|
438
448
|
Perform inference on an LLM model using chat.
|
|
439
449
|
|
|
@@ -485,7 +495,10 @@ class Client:
|
|
|
485
495
|
tool_choice (str, optional): Sets a specific tool to choose. Default value is "auto".
|
|
486
496
|
|
|
487
497
|
Returns:
|
|
488
|
-
|
|
498
|
+
TextGenerationOutput: Generated text results including:
|
|
499
|
+
- Transaction hash
|
|
500
|
+
- Finish reason (tool_call, stop, etc.)
|
|
501
|
+
- Dictionary of chat message output (role, content, tool_call, etc.)
|
|
489
502
|
|
|
490
503
|
Raises:
|
|
491
504
|
OpenGradientError: If the inference fails.
|
|
@@ -527,7 +540,7 @@ class Client:
|
|
|
527
540
|
|
|
528
541
|
# Prepare LLM input
|
|
529
542
|
llm_request = {
|
|
530
|
-
"mode": inference_mode,
|
|
543
|
+
"mode": inference_mode.value,
|
|
531
544
|
"modelCID": model_cid,
|
|
532
545
|
"messages": messages,
|
|
533
546
|
"max_tokens": max_tokens,
|
|
@@ -570,7 +583,11 @@ class Client:
|
|
|
570
583
|
if (tool_calls := message.get("tool_calls")) is not None:
|
|
571
584
|
message["tool_calls"] = [dict(tool_call) for tool_call in tool_calls]
|
|
572
585
|
|
|
573
|
-
return
|
|
586
|
+
return TextGenerationOutput(
|
|
587
|
+
transaction_hash=tx_hash.hex(),
|
|
588
|
+
finish_reason=llm_result["finish_reason"],
|
|
589
|
+
chat_output=message,
|
|
590
|
+
)
|
|
574
591
|
|
|
575
592
|
return run_with_retry(execute_transaction, max_retries)
|
|
576
593
|
|
|
@@ -611,233 +628,265 @@ class Client:
|
|
|
611
628
|
if hasattr(e, "response") and e.response is not None:
|
|
612
629
|
logging.error(f"Response status code: {e.response.status_code}")
|
|
613
630
|
logging.error(f"Response content: {e.response.text[:1000]}...") # Log first 1000 characters
|
|
614
|
-
raise OpenGradientError(
|
|
615
|
-
f"File listing failed: {str(e)}", status_code=e.response.status_code if hasattr(e, "response") else None
|
|
616
|
-
)
|
|
631
|
+
raise OpenGradientError(f"File listing failed: {str(e)}")
|
|
617
632
|
except Exception as e:
|
|
618
633
|
logging.error(f"Unexpected error during file listing: {str(e)}", exc_info=True)
|
|
619
634
|
raise OpenGradientError(f"Unexpected error during file listing: {str(e)}")
|
|
620
635
|
|
|
621
|
-
def generate_image(
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
) -> bytes:
|
|
636
|
+
# def generate_image(
|
|
637
|
+
# self,
|
|
638
|
+
# model_cid: str,
|
|
639
|
+
# prompt: str,
|
|
640
|
+
# host: str = DEFAULT_IMAGE_GEN_HOST,
|
|
641
|
+
# port: int = DEFAULT_IMAGE_GEN_PORT,
|
|
642
|
+
# width: int = 1024,
|
|
643
|
+
# height: int = 1024,
|
|
644
|
+
# timeout: int = 300, # 5 minute timeout
|
|
645
|
+
# max_retries: int = 3,
|
|
646
|
+
# ) -> bytes:
|
|
647
|
+
# """
|
|
648
|
+
# Generate an image using a diffusion model through gRPC.
|
|
649
|
+
|
|
650
|
+
# Args:
|
|
651
|
+
# model_cid (str): The model identifier (e.g. "stabilityai/stable-diffusion-xl-base-1.0")
|
|
652
|
+
# prompt (str): The text prompt to generate the image from
|
|
653
|
+
# host (str, optional): gRPC host address. Defaults to DEFAULT_IMAGE_GEN_HOST.
|
|
654
|
+
# port (int, optional): gRPC port number. Defaults to DEFAULT_IMAGE_GEN_PORT.
|
|
655
|
+
# width (int, optional): Output image width. Defaults to 1024.
|
|
656
|
+
# height (int, optional): Output image height. Defaults to 1024.
|
|
657
|
+
# timeout (int, optional): Maximum time to wait for generation in seconds. Defaults to 300.
|
|
658
|
+
# max_retries (int, optional): Maximum number of retry attempts. Defaults to 3.
|
|
659
|
+
|
|
660
|
+
# Returns:
|
|
661
|
+
# bytes: The raw image data bytes
|
|
662
|
+
|
|
663
|
+
# Raises:
|
|
664
|
+
# OpenGradientError: If the image generation fails
|
|
665
|
+
# TimeoutError: If the generation exceeds the timeout period
|
|
666
|
+
# """
|
|
667
|
+
|
|
668
|
+
# def exponential_backoff(attempt: int, max_delay: float = 30.0) -> None:
|
|
669
|
+
# """Calculate and sleep for exponential backoff duration"""
|
|
670
|
+
# delay = min(0.1 * (2**attempt), max_delay)
|
|
671
|
+
# time.sleep(delay)
|
|
672
|
+
|
|
673
|
+
# channel = None
|
|
674
|
+
# start_time = time.time()
|
|
675
|
+
# retry_count = 0
|
|
676
|
+
|
|
677
|
+
# try:
|
|
678
|
+
# while retry_count < max_retries:
|
|
679
|
+
# try:
|
|
680
|
+
# # Initialize gRPC channel and stub
|
|
681
|
+
# channel = grpc.insecure_channel(f"{host}:{port}")
|
|
682
|
+
# stub = infer_pb2_grpc.InferenceServiceStub(channel)
|
|
683
|
+
|
|
684
|
+
# # Create image generation request
|
|
685
|
+
# image_request = infer_pb2.ImageGenerationRequest(model=model_cid, prompt=prompt, height=height, width=width)
|
|
686
|
+
|
|
687
|
+
# # Create inference request with random transaction ID
|
|
688
|
+
# tx_id = str(uuid.uuid4())
|
|
689
|
+
# request = infer_pb2.InferenceRequest(tx=tx_id, image_generation=image_request)
|
|
690
|
+
|
|
691
|
+
# # Send request with timeout
|
|
692
|
+
# response_id = stub.RunInferenceAsync(
|
|
693
|
+
# request,
|
|
694
|
+
# timeout=min(30, timeout), # Initial request timeout
|
|
695
|
+
# )
|
|
696
|
+
|
|
697
|
+
# # Poll for completion
|
|
698
|
+
# attempt = 0
|
|
699
|
+
# while True:
|
|
700
|
+
# # Check timeout
|
|
701
|
+
# if time.time() - start_time > timeout:
|
|
702
|
+
# raise TimeoutError(f"Image generation timed out after {timeout} seconds")
|
|
703
|
+
|
|
704
|
+
# status_request = infer_pb2.InferenceTxId(id=response_id.id)
|
|
705
|
+
# try:
|
|
706
|
+
# status = stub.GetInferenceStatus(
|
|
707
|
+
# status_request,
|
|
708
|
+
# timeout=min(5, timeout), # Status check timeout
|
|
709
|
+
# ).status
|
|
710
|
+
# except grpc.RpcError as e:
|
|
711
|
+
# logging.warning(f"Status check failed (attempt {attempt}): {str(e)}")
|
|
712
|
+
# exponential_backoff(attempt)
|
|
713
|
+
# attempt += 1
|
|
714
|
+
# continue
|
|
715
|
+
|
|
716
|
+
# if status == infer_pb2.InferenceStatus.STATUS_COMPLETED:
|
|
717
|
+
# break
|
|
718
|
+
# elif status == infer_pb2.InferenceStatus.STATUS_ERROR:
|
|
719
|
+
# raise OpenGradientError("Image generation failed on server")
|
|
720
|
+
# elif status != infer_pb2.InferenceStatus.STATUS_IN_PROGRESS:
|
|
721
|
+
# raise OpenGradientError(f"Unexpected status: {status}")
|
|
722
|
+
|
|
723
|
+
# exponential_backoff(attempt)
|
|
724
|
+
# attempt += 1
|
|
725
|
+
|
|
726
|
+
# # Get result
|
|
727
|
+
# result = stub.GetInferenceResult(
|
|
728
|
+
# response_id,
|
|
729
|
+
# timeout=min(30, timeout), # Result fetch timeout
|
|
730
|
+
# )
|
|
731
|
+
# return result.image_generation_result.image_data
|
|
732
|
+
|
|
733
|
+
# except (grpc.RpcError, TimeoutError) as e:
|
|
734
|
+
# retry_count += 1
|
|
735
|
+
# if retry_count >= max_retries:
|
|
736
|
+
# raise OpenGradientError(f"Image generation failed after {max_retries} retries: {str(e)}")
|
|
737
|
+
|
|
738
|
+
# logging.warning(f"Attempt {retry_count} failed: {str(e)}. Retrying...")
|
|
739
|
+
# exponential_backoff(retry_count)
|
|
740
|
+
|
|
741
|
+
# except grpc.RpcError as e:
|
|
742
|
+
# logging.error(f"gRPC error: {str(e)}")
|
|
743
|
+
# raise OpenGradientError(f"Image generation failed: {str(e)}")
|
|
744
|
+
# except TimeoutError as e:
|
|
745
|
+
# logging.error(f"Timeout error: {str(e)}")
|
|
746
|
+
# raise
|
|
747
|
+
# except Exception as e:
|
|
748
|
+
# logging.error(f"Error in generate image method: {str(e)}", exc_info=True)
|
|
749
|
+
# raise OpenGradientError(f"Image generation failed: {str(e)}")
|
|
750
|
+
# finally:
|
|
751
|
+
# if channel:
|
|
752
|
+
# channel.close()
|
|
753
|
+
|
|
754
|
+
def _get_abi(self, abi_name) -> List[Dict]:
|
|
632
755
|
"""
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
Args:
|
|
636
|
-
model_cid (str): The model identifier (e.g. "stabilityai/stable-diffusion-xl-base-1.0")
|
|
637
|
-
prompt (str): The text prompt to generate the image from
|
|
638
|
-
host (str, optional): gRPC host address. Defaults to DEFAULT_IMAGE_GEN_HOST.
|
|
639
|
-
port (int, optional): gRPC port number. Defaults to DEFAULT_IMAGE_GEN_PORT.
|
|
640
|
-
width (int, optional): Output image width. Defaults to 1024.
|
|
641
|
-
height (int, optional): Output image height. Defaults to 1024.
|
|
642
|
-
timeout (int, optional): Maximum time to wait for generation in seconds. Defaults to 300.
|
|
643
|
-
max_retries (int, optional): Maximum number of retry attempts. Defaults to 3.
|
|
644
|
-
|
|
645
|
-
Returns:
|
|
646
|
-
bytes: The raw image data bytes
|
|
647
|
-
|
|
648
|
-
Raises:
|
|
649
|
-
OpenGradientError: If the image generation fails
|
|
650
|
-
TimeoutError: If the generation exceeds the timeout period
|
|
756
|
+
Returns the ABI for the requested contract.
|
|
651
757
|
"""
|
|
758
|
+
abi_path = Path(__file__).parent / "abi" / abi_name
|
|
759
|
+
with open(abi_path, "r") as f:
|
|
760
|
+
return json.load(f)
|
|
652
761
|
|
|
653
|
-
|
|
654
|
-
"""Calculate and sleep for exponential backoff duration"""
|
|
655
|
-
delay = min(0.1 * (2**attempt), max_delay)
|
|
656
|
-
time.sleep(delay)
|
|
657
|
-
|
|
658
|
-
channel = None
|
|
659
|
-
start_time = time.time()
|
|
660
|
-
retry_count = 0
|
|
661
|
-
|
|
662
|
-
try:
|
|
663
|
-
while retry_count < max_retries:
|
|
664
|
-
try:
|
|
665
|
-
# Initialize gRPC channel and stub
|
|
666
|
-
channel = grpc.insecure_channel(f"{host}:{port}")
|
|
667
|
-
stub = infer_pb2_grpc.InferenceServiceStub(channel)
|
|
668
|
-
|
|
669
|
-
# Create image generation request
|
|
670
|
-
image_request = infer_pb2.ImageGenerationRequest(model=model_cid, prompt=prompt, height=height, width=width)
|
|
671
|
-
|
|
672
|
-
# Create inference request with random transaction ID
|
|
673
|
-
tx_id = str(uuid.uuid4())
|
|
674
|
-
request = infer_pb2.InferenceRequest(tx=tx_id, image_generation=image_request)
|
|
675
|
-
|
|
676
|
-
# Send request with timeout
|
|
677
|
-
response_id = stub.RunInferenceAsync(
|
|
678
|
-
request,
|
|
679
|
-
timeout=min(30, timeout), # Initial request timeout
|
|
680
|
-
)
|
|
681
|
-
|
|
682
|
-
# Poll for completion
|
|
683
|
-
attempt = 0
|
|
684
|
-
while True:
|
|
685
|
-
# Check timeout
|
|
686
|
-
if time.time() - start_time > timeout:
|
|
687
|
-
raise TimeoutError(f"Image generation timed out after {timeout} seconds")
|
|
688
|
-
|
|
689
|
-
status_request = infer_pb2.InferenceTxId(id=response_id.id)
|
|
690
|
-
try:
|
|
691
|
-
status = stub.GetInferenceStatus(
|
|
692
|
-
status_request,
|
|
693
|
-
timeout=min(5, timeout), # Status check timeout
|
|
694
|
-
).status
|
|
695
|
-
except grpc.RpcError as e:
|
|
696
|
-
logging.warning(f"Status check failed (attempt {attempt}): {str(e)}")
|
|
697
|
-
exponential_backoff(attempt)
|
|
698
|
-
attempt += 1
|
|
699
|
-
continue
|
|
700
|
-
|
|
701
|
-
if status == infer_pb2.InferenceStatus.STATUS_COMPLETED:
|
|
702
|
-
break
|
|
703
|
-
elif status == infer_pb2.InferenceStatus.STATUS_ERROR:
|
|
704
|
-
raise OpenGradientError("Image generation failed on server")
|
|
705
|
-
elif status != infer_pb2.InferenceStatus.STATUS_IN_PROGRESS:
|
|
706
|
-
raise OpenGradientError(f"Unexpected status: {status}")
|
|
707
|
-
|
|
708
|
-
exponential_backoff(attempt)
|
|
709
|
-
attempt += 1
|
|
710
|
-
|
|
711
|
-
# Get result
|
|
712
|
-
result = stub.GetInferenceResult(
|
|
713
|
-
response_id,
|
|
714
|
-
timeout=min(30, timeout), # Result fetch timeout
|
|
715
|
-
)
|
|
716
|
-
return result.image_generation_result.image_data
|
|
717
|
-
|
|
718
|
-
except (grpc.RpcError, TimeoutError) as e:
|
|
719
|
-
retry_count += 1
|
|
720
|
-
if retry_count >= max_retries:
|
|
721
|
-
raise OpenGradientError(f"Image generation failed after {max_retries} retries: {str(e)}")
|
|
722
|
-
|
|
723
|
-
logging.warning(f"Attempt {retry_count} failed: {str(e)}. Retrying...")
|
|
724
|
-
exponential_backoff(retry_count)
|
|
725
|
-
|
|
726
|
-
except grpc.RpcError as e:
|
|
727
|
-
logging.error(f"gRPC error: {str(e)}")
|
|
728
|
-
raise OpenGradientError(f"Image generation failed: {str(e)}")
|
|
729
|
-
except TimeoutError as e:
|
|
730
|
-
logging.error(f"Timeout error: {str(e)}")
|
|
731
|
-
raise
|
|
732
|
-
except Exception as e:
|
|
733
|
-
logging.error(f"Error in generate image method: {str(e)}", exc_info=True)
|
|
734
|
-
raise OpenGradientError(f"Image generation failed: {str(e)}")
|
|
735
|
-
finally:
|
|
736
|
-
if channel:
|
|
737
|
-
channel.close()
|
|
738
|
-
|
|
739
|
-
def _get_model_executor_abi(self) -> List[Dict]:
|
|
762
|
+
def _get_bin(self, bin_name) -> List[Dict]:
|
|
740
763
|
"""
|
|
741
|
-
Returns the
|
|
764
|
+
Returns the bin for the requested contract.
|
|
742
765
|
"""
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
766
|
+
bin_path = Path(__file__).parent / "bin" / bin_name
|
|
767
|
+
# Read bytecode with explicit encoding
|
|
768
|
+
with open(bin_path, "r", encoding="utf-8") as f:
|
|
769
|
+
bytecode = f.read().strip()
|
|
770
|
+
if not bytecode.startswith("0x"):
|
|
771
|
+
bytecode = "0x" + bytecode
|
|
772
|
+
return bytecode
|
|
746
773
|
|
|
747
774
|
def new_workflow(
|
|
748
775
|
self,
|
|
749
776
|
model_cid: str,
|
|
750
|
-
input_query:
|
|
777
|
+
input_query: HistoricalInputQuery,
|
|
751
778
|
input_tensor_name: str,
|
|
752
779
|
scheduler_params: Optional[SchedulerParams] = None,
|
|
753
780
|
) -> str:
|
|
754
781
|
"""
|
|
755
782
|
Deploy a new workflow contract with the specified parameters.
|
|
756
|
-
"""
|
|
757
|
-
if isinstance(input_query, dict):
|
|
758
|
-
input_query = HistoricalInputQuery.from_dict(input_query)
|
|
759
783
|
|
|
784
|
+
This function deploys a new workflow contract and optionally registers it with
|
|
785
|
+
the scheduler for automated execution. If scheduler_params is not provided,
|
|
786
|
+
the workflow will be deployed without automated execution scheduling.
|
|
787
|
+
|
|
788
|
+
Args:
|
|
789
|
+
model_cid (str): IPFS CID of the model to be executed
|
|
790
|
+
input_query (HistoricalInputQuery): Query parameters for data input
|
|
791
|
+
input_tensor_name (str): Name of the input tensor expected by the model
|
|
792
|
+
scheduler_params (Optional[SchedulerParams]): Scheduler configuration for automated execution:
|
|
793
|
+
- frequency: Execution frequency in seconds
|
|
794
|
+
- duration_hours: How long to run in hours
|
|
795
|
+
|
|
796
|
+
Returns:
|
|
797
|
+
str: Deployed contract address. If scheduler_params was provided, the workflow
|
|
798
|
+
will be automatically executed according to the specified schedule.
|
|
799
|
+
|
|
800
|
+
Raises:
|
|
801
|
+
Exception: If transaction fails or gas estimation fails
|
|
802
|
+
"""
|
|
760
803
|
# Get contract ABI and bytecode
|
|
761
|
-
abi = self.
|
|
762
|
-
|
|
804
|
+
abi = self._get_abi("PriceHistoryInference.abi")
|
|
805
|
+
bytecode = self._get_bin("PriceHistoryInference.bin")
|
|
763
806
|
|
|
764
|
-
|
|
765
|
-
|
|
807
|
+
def deploy_transaction():
|
|
808
|
+
contract = self._blockchain.eth.contract(abi=abi, bytecode=bytecode)
|
|
809
|
+
query_tuple = input_query.to_abi_format()
|
|
810
|
+
constructor_args = [model_cid, input_tensor_name, query_tuple]
|
|
766
811
|
|
|
767
|
-
|
|
812
|
+
try:
|
|
813
|
+
# Estimate gas needed
|
|
814
|
+
estimated_gas = contract.constructor(*constructor_args).estimate_gas({"from": self._wallet_account.address})
|
|
815
|
+
gas_limit = int(estimated_gas * 1.2)
|
|
816
|
+
except Exception as e:
|
|
817
|
+
print(f"⚠️ Gas estimation failed: {str(e)}")
|
|
818
|
+
gas_limit = 5000000 # Conservative fallback
|
|
819
|
+
print(f"📊 Using fallback gas limit: {gas_limit}")
|
|
768
820
|
|
|
769
|
-
|
|
770
|
-
|
|
821
|
+
transaction = contract.constructor(*constructor_args).build_transaction(
|
|
822
|
+
{
|
|
823
|
+
"from": self._wallet_account.address,
|
|
824
|
+
"nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
|
|
825
|
+
"gas": gas_limit,
|
|
826
|
+
"gasPrice": self._blockchain.eth.gas_price,
|
|
827
|
+
"chainId": self._blockchain.eth.chain_id,
|
|
828
|
+
}
|
|
829
|
+
)
|
|
771
830
|
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
{
|
|
775
|
-
"from": self._wallet_account.address,
|
|
776
|
-
"nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
|
|
777
|
-
"gas": 15000000,
|
|
778
|
-
"gasPrice": self._blockchain.eth.gas_price,
|
|
779
|
-
"chainId": self._blockchain.eth.chain_id,
|
|
780
|
-
}
|
|
781
|
-
)
|
|
831
|
+
signed_txn = self._wallet_account.sign_transaction(transaction)
|
|
832
|
+
tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction)
|
|
782
833
|
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
834
|
+
tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=60)
|
|
835
|
+
|
|
836
|
+
if tx_receipt["status"] == 0:
|
|
837
|
+
raise Exception(f"❌ Contract deployment failed, transaction hash: {tx_hash.hex()}")
|
|
787
838
|
|
|
788
|
-
|
|
839
|
+
return tx_receipt.contractAddress
|
|
840
|
+
|
|
841
|
+
contract_address = run_with_retry(deploy_transaction)
|
|
789
842
|
|
|
790
|
-
# Register with scheduler if params provided
|
|
791
843
|
if scheduler_params:
|
|
792
|
-
|
|
793
|
-
print(f" • Frequency: Every {scheduler_params.frequency} seconds")
|
|
794
|
-
print(f" • Duration: {scheduler_params.duration_hours} hours")
|
|
795
|
-
print(f" • End Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(scheduler_params.end_time))}")
|
|
844
|
+
self._register_with_scheduler(contract_address, scheduler_params)
|
|
796
845
|
|
|
797
|
-
|
|
798
|
-
{
|
|
799
|
-
"inputs": [
|
|
800
|
-
{"internalType": "address", "name": "contractAddress", "type": "address"},
|
|
801
|
-
{"internalType": "uint256", "name": "endTime", "type": "uint256"},
|
|
802
|
-
{"internalType": "uint256", "name": "frequency", "type": "uint256"},
|
|
803
|
-
],
|
|
804
|
-
"name": "registerTask",
|
|
805
|
-
"outputs": [],
|
|
806
|
-
"stateMutability": "nonpayable",
|
|
807
|
-
"type": "function",
|
|
808
|
-
}
|
|
809
|
-
]
|
|
846
|
+
return contract_address
|
|
810
847
|
|
|
811
|
-
|
|
812
|
-
|
|
848
|
+
def _register_with_scheduler(self, contract_address: str, scheduler_params: SchedulerParams) -> None:
|
|
849
|
+
"""
|
|
850
|
+
Register the deployed workflow contract with the scheduler for automated execution.
|
|
813
851
|
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
"from": self._wallet_account.address,
|
|
821
|
-
"gas": 300000,
|
|
822
|
-
"gasPrice": self._blockchain.eth.gas_price,
|
|
823
|
-
"nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
|
|
824
|
-
"chainId": self._blockchain.eth.chain_id,
|
|
825
|
-
}
|
|
826
|
-
)
|
|
852
|
+
Args:
|
|
853
|
+
contract_address (str): Address of the deployed workflow contract
|
|
854
|
+
scheduler_params (SchedulerParams): Scheduler configuration containing:
|
|
855
|
+
- frequency: Execution frequency in seconds
|
|
856
|
+
- duration_hours: How long to run in hours
|
|
857
|
+
- end_time: Unix timestamp when scheduling should end
|
|
827
858
|
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
859
|
+
Raises:
|
|
860
|
+
Exception: If registration with scheduler fails. The workflow contract will
|
|
861
|
+
still be deployed and can be executed manually.
|
|
862
|
+
"""
|
|
831
863
|
|
|
832
|
-
|
|
833
|
-
print(f" Transaction hash: {scheduler_tx_hash.hex()}")
|
|
864
|
+
scheduler_abi = self._get_abi("WorkflowScheduler.abi")
|
|
834
865
|
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
print(" The workflow contract is still deployed and can be executed manually.")
|
|
866
|
+
# Scheduler contract address
|
|
867
|
+
scheduler_address = DEFAULT_SCHEDULER_ADDRESS
|
|
868
|
+
scheduler_contract = self._blockchain.eth.contract(address=scheduler_address, abi=scheduler_abi)
|
|
839
869
|
|
|
840
|
-
|
|
870
|
+
try:
|
|
871
|
+
# Register the workflow with the scheduler
|
|
872
|
+
scheduler_tx = scheduler_contract.functions.registerTask(
|
|
873
|
+
contract_address, scheduler_params.end_time, scheduler_params.frequency
|
|
874
|
+
).build_transaction(
|
|
875
|
+
{
|
|
876
|
+
"from": self._wallet_account.address,
|
|
877
|
+
"gas": 300000,
|
|
878
|
+
"gasPrice": self._blockchain.eth.gas_price,
|
|
879
|
+
"nonce": self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending"),
|
|
880
|
+
"chainId": self._blockchain.eth.chain_id,
|
|
881
|
+
}
|
|
882
|
+
)
|
|
883
|
+
|
|
884
|
+
signed_scheduler_tx = self._wallet_account.sign_transaction(scheduler_tx)
|
|
885
|
+
scheduler_tx_hash = self._blockchain.eth.send_raw_transaction(signed_scheduler_tx.raw_transaction)
|
|
886
|
+
self._blockchain.eth.wait_for_transaction_receipt(scheduler_tx_hash, timeout=REGULAR_TX_TIMEOUT)
|
|
887
|
+
except Exception as e:
|
|
888
|
+
print(f"❌ Error registering contract with scheduler: {str(e)}")
|
|
889
|
+
print(" The workflow contract is still deployed and can be executed manually.")
|
|
841
890
|
|
|
842
891
|
def read_workflow_result(self, contract_address: str) -> ModelOutput:
|
|
843
892
|
"""
|
|
@@ -854,7 +903,9 @@ class Client:
|
|
|
854
903
|
Web3Error: If there are issues with the web3 connection or contract interaction
|
|
855
904
|
"""
|
|
856
905
|
# Get the contract interface
|
|
857
|
-
contract = self._blockchain.eth.contract(
|
|
906
|
+
contract = self._blockchain.eth.contract(
|
|
907
|
+
address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
|
|
908
|
+
)
|
|
858
909
|
|
|
859
910
|
# Get the result
|
|
860
911
|
result = contract.functions.getInferenceResult().call()
|
|
@@ -876,7 +927,9 @@ class Client:
|
|
|
876
927
|
Web3Error: If there are issues with the web3 connection or contract interaction
|
|
877
928
|
"""
|
|
878
929
|
# Get the contract interface
|
|
879
|
-
contract = self._blockchain.eth.contract(
|
|
930
|
+
contract = self._blockchain.eth.contract(
|
|
931
|
+
address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
|
|
932
|
+
)
|
|
880
933
|
|
|
881
934
|
# Call run() function
|
|
882
935
|
nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending")
|
|
@@ -904,6 +957,31 @@ class Client:
|
|
|
904
957
|
|
|
905
958
|
return utils.convert_array_to_model_output(result)
|
|
906
959
|
|
|
960
|
+
def read_workflow_history(self, contract_address: str, num_results: int) -> List[Dict]:
|
|
961
|
+
"""
|
|
962
|
+
Gets historical inference results from a workflow contract.
|
|
963
|
+
|
|
964
|
+
Retrieves the specified number of most recent inference results from the contract's
|
|
965
|
+
storage, with the most recent result first.
|
|
966
|
+
|
|
967
|
+
Args:
|
|
968
|
+
contract_address (str): Address of the deployed workflow contract
|
|
969
|
+
num_results (int): Number of historical results to retrieve
|
|
970
|
+
|
|
971
|
+
Returns:
|
|
972
|
+
List[Dict]: List of historical inference results, each containing:
|
|
973
|
+
- prediction values
|
|
974
|
+
- timestamps
|
|
975
|
+
- any additional metadata stored with the result
|
|
976
|
+
|
|
977
|
+
"""
|
|
978
|
+
contract = self._blockchain.eth.contract(
|
|
979
|
+
address=Web3.to_checksum_address(contract_address), abi=self._get_abi("PriceHistoryInference.abi")
|
|
980
|
+
)
|
|
981
|
+
|
|
982
|
+
results = contract.functions.getLastInferenceResults(num_results).call()
|
|
983
|
+
return [utils.convert_array_to_model_output(result) for result in results]
|
|
984
|
+
|
|
907
985
|
|
|
908
986
|
def run_with_retry(txn_function, max_retries=DEFAULT_MAX_RETRY, retry_delay=DEFAULT_RETRY_DELAY_SEC):
|
|
909
987
|
"""
|