opengradient 0.5.7__py3-none-any.whl → 0.5.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opengradient/__init__.py +34 -6
- opengradient/cli.py +155 -55
- opengradient/client.py +429 -146
- opengradient/defaults.py +3 -1
- opengradient/llm/og_langchain.py +6 -1
- opengradient/types.py +229 -11
- opengradient/x402_auth.py +60 -0
- {opengradient-0.5.7.dist-info → opengradient-0.5.9.dist-info}/METADATA +6 -3
- {opengradient-0.5.7.dist-info → opengradient-0.5.9.dist-info}/RECORD +13 -12
- {opengradient-0.5.7.dist-info → opengradient-0.5.9.dist-info}/WHEEL +1 -1
- {opengradient-0.5.7.dist-info → opengradient-0.5.9.dist-info}/entry_points.txt +0 -0
- {opengradient-0.5.7.dist-info → opengradient-0.5.9.dist-info}/licenses/LICENSE +0 -0
- {opengradient-0.5.7.dist-info → opengradient-0.5.9.dist-info}/top_level.txt +0 -0
opengradient/client.py
CHANGED
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Union, Callable
 import firebase
 import numpy as np
 import requests
+import httpx
 from eth_account.account import LocalAccount
 from web3 import Web3
 from web3.exceptions import ContractLogicError
@@ -17,7 +18,9 @@ import urllib.parse
 import asyncio
 from x402.clients.httpx import x402HttpxClient
 from x402.clients.base import decode_x_payment_response, x402Client
+from x402.clients.httpx import x402HttpxClient
 
+from .x402_auth import X402Auth
 from .exceptions import OpenGradientError
 from .proto import infer_pb2, infer_pb2_grpc
 from .types import (
@@ -29,17 +32,22 @@ from .types import (
     LlmInferenceMode,
     ModelOutput,
     TextGenerationOutput,
+    TextGenerationStream,
     SchedulerParams,
     InferenceResult,
     ModelRepository,
     FileUploadResult,
+    StreamChunk,
 )
 from .defaults import (
-    DEFAULT_IMAGE_GEN_HOST,
-    DEFAULT_IMAGE_GEN_PORT,
+    DEFAULT_IMAGE_GEN_HOST,
+    DEFAULT_IMAGE_GEN_PORT,
     DEFAULT_SCHEDULER_ADDRESS,
-    DEFAULT_LLM_SERVER_URL,
-    DEFAULT_OPENGRADIENT_LLM_SERVER_URL
+    DEFAULT_LLM_SERVER_URL,
+    DEFAULT_OPENGRADIENT_LLM_SERVER_URL,
+    DEFAULT_OPENGRADIENT_LLM_STREAMING_SERVER_URL,
+    DEFAULT_NETWORK_FILTER,
+)
 from .utils import convert_array_to_model_output, convert_to_model_input, convert_to_model_output
 
 _FIREBASE_CONFIG = {
@@ -65,6 +73,19 @@ PRECOMPILE_CONTRACT_ADDRESS = "0x00000000000000000000000000000000000000F4"
 X402_PROCESSING_HASH_HEADER = "x-processing-hash"
 X402_PLACEHOLDER_API_KEY = "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"
 
+TIMEOUT = httpx.Timeout(
+    timeout=90.0,
+    connect=15.0,
+    read=15.0,
+    write=30.0,
+    pool=10.0,
+)
+LIMITS = httpx.Limits(
+    max_keepalive_connections=100,
+    max_connections=500,
+    keepalive_expiry=60 * 20, # 20 minutes
+)
+
 class Client:
     _inference_hub_contract_address: str
     _blockchain: Web3
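Note: the TIMEOUT and LIMITS constants added above are ordinary httpx configuration objects shared by the new streaming code. A minimal sketch of reusing them in an unrelated call, assuming a hypothetical base URL and endpoint (neither appears in this diff):

```python
import asyncio
import httpx

# Same settings as the module-level constants introduced in 0.5.9.
TIMEOUT = httpx.Timeout(timeout=90.0, connect=15.0, read=15.0, write=30.0, pool=10.0)
LIMITS = httpx.Limits(max_keepalive_connections=100, max_connections=500, keepalive_expiry=60 * 20)

async def ping(base_url: str) -> int:
    # Reuse the shared timeout and connection-pool settings for any HTTP call.
    async with httpx.AsyncClient(base_url=base_url, timeout=TIMEOUT, limits=LIMITS) as client:
        response = await client.get("/healthz")  # hypothetical endpoint, illustration only
        return response.status_code

# asyncio.run(ping("https://example.invalid"))
```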
@@ -76,20 +97,22 @@ class Client:
     _precompile_abi: Dict
     _llm_server_url: str
     _external_api_keys: Dict[str, str]
+
     def __init__(
-        self,
-        private_key: str,
-        rpc_url: str,
-        api_url: str,
-        contract_address: str,
-        email: Optional[str] = None,
-        password: Optional[str] = None,
+        self,
+        private_key: str,
+        rpc_url: str,
+        api_url: str,
+        contract_address: str,
+        email: Optional[str] = None,
+        password: Optional[str] = None,
         llm_server_url: Optional[str] = DEFAULT_LLM_SERVER_URL,
         og_llm_server_url: Optional[str] = DEFAULT_OPENGRADIENT_LLM_SERVER_URL,
+        og_llm_streaming_server_url: Optional[str] = DEFAULT_OPENGRADIENT_LLM_STREAMING_SERVER_URL,
         openai_api_key: Optional[str] = None,
         anthropic_api_key: Optional[str] = None,
         google_api_key: Optional[str] = None,
-
+    ):
         """
         Initialize the Client with private key, RPC URL, and contract address.
 
@@ -120,7 +143,8 @@ class Client:
 
         self._llm_server_url = llm_server_url
         self._og_llm_server_url = og_llm_server_url
-
+        self._og_llm_streaming_server_url = og_llm_streaming_server_url
+
         self._external_api_keys = {}
         if openai_api_key or os.getenv("OPENAI_API_KEY"):
             self._external_api_keys["openai"] = openai_api_key or os.getenv("OPENAI_API_KEY")
@@ -132,7 +156,7 @@ class Client:
     def set_api_key(self, provider: str, api_key: str):
         """
         Set or update API key for an external provider.
-
+
         Args:
             provider: Provider name (e.g., 'openai', 'anthropic', 'google')
             api_key: The API key for the provider
@@ -142,10 +166,10 @@ class Client:
     def _is_local_model(self, model_cid: str) -> bool:
         """
         Check if a model is hosted locally on OpenGradient.
-
+
         Args:
             model_cid: Model identifier
-
+
         Returns:
             True if model is local, False if it should use external provider
         """
@@ -158,7 +182,7 @@ class Client:
     def _get_provider_from_model(self, model: str) -> str:
         """Infer provider from model name."""
         model_lower = model.lower()
-
+
         if "gpt" in model_lower or model.startswith("openai/"):
             return "openai"
         elif "claude" in model_lower or model.startswith("anthropic/"):
@@ -173,10 +197,10 @@ class Client:
     def _get_api_key_for_model(self, model: str) -> Optional[str]:
         """
         Get the appropriate API key for a model.
-
+
         Args:
             model: Model identifier
-
+
         Returns:
             API key string or None
         """
@@ -418,11 +442,11 @@ class Client:
 
         return run_with_retry(execute_transaction, max_retries)
 
-    def _og_payment_selector(self, accepts, network_filter=
-        """Custom payment selector for OpenGradient network
+    def _og_payment_selector(self, accepts, network_filter=DEFAULT_NETWORK_FILTER, scheme_filter=None, max_value=None):
+        """Custom payment selector for OpenGradient network."""
         return x402Client.default_payment_requirements_selector(
             accepts,
-            network_filter=
+            network_filter=network_filter,
             scheme_filter=scheme_filter,
             max_value=max_value,
         )
@@ -451,11 +475,17 @@ class Client:
             temperature (float): Temperature for LLM inference, between 0 and 1. Default is 0.0.
             max_retries (int, optional): Maximum number of retry attempts for blockchain transactions.
             local_model (bool, optional): Force use of local model even if not in LLM enum.
+            x402_settlement_mode (x402SettlementMode, optional): Settlement mode for x402 payments.
+                - SETTLE: Records input/output hashes only (most privacy-preserving).
+                - SETTLE_BATCH: Aggregates multiple inferences into batch hashes (most cost-efficient).
+                - SETTLE_METADATA: Records full model info, complete input/output data, and all metadata.
+                Defaults to SETTLE_BATCH.
 
         Returns:
             TextGenerationOutput: Generated text results including:
                 - Transaction hash (or "external" for external providers)
                 - String of completion output
+                - Payment hash for x402 transactions (when using x402 settlement)
 
         Raises:
             OpenGradientError: If the inference fails.
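Note: the settlement mode documented above is forwarded to the server as the X-SETTLEMENT-TYPE header in the x402 request code later in this diff. A minimal sketch of choosing a mode and building those headers; the enum below is a stand-in for opengradient.types.x402SettlementMode (the real one ships in types.py), and its string values here are assumptions for illustration only:

```python
from enum import Enum

# Stand-in for opengradient.types.x402SettlementMode; member values are assumed.
class x402SettlementMode(str, Enum):
    SETTLE = "settle"                    # input/output hashes only
    SETTLE_BATCH = "settle_batch"        # batched hashes, cheapest
    SETTLE_METADATA = "settle_metadata"  # full payloads and metadata

PLACEHOLDER_API_KEY = "0x" + "0" * 64  # stand-in for X402_PLACEHOLDER_API_KEY

def build_x402_headers(mode: x402SettlementMode) -> dict:
    # Mirrors the header dict assembled in the x402 request paths of this diff.
    return {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {PLACEHOLDER_API_KEY}",
        "X-SETTLEMENT-TYPE": mode,
    }

print(build_x402_headers(x402SettlementMode.SETTLE_BATCH))
```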
@@ -467,14 +497,14 @@ class Client:
                 return OpenGradientError("That model CID is not supported yet for TEE inference")
 
             return self._external_llm_completion(
-                model=model_cid.split(
+                model=model_cid.split("/")[1],
                 prompt=prompt,
                 max_tokens=max_tokens,
                 stop_sequence=stop_sequence,
                 temperature=temperature,
                 x402_settlement_mode=x402_settlement_mode,
             )
-
+
         # Original local model logic
         def execute_transaction():
             if inference_mode != LlmInferenceMode.VANILLA:
@@ -482,10 +512,10 @@ class Client:
 
             if model_cid not in [llm.value for llm in LLM]:
                 raise OpenGradientError("That model CID is not yet supported for inference")
-
+
             model_name = model_cid
             if model_cid in [llm.value for llm in TEE_LLM]:
-                model_name = model_cid.split(
+                model_name = model_cid.split("/")[1]
 
             contract = self._blockchain.eth.contract(address=self._inference_hub_contract_address, abi=self._inference_abi)
 
@@ -523,55 +553,49 @@ class Client:
     ) -> TextGenerationOutput:
         """
         Route completion request to external LLM server with x402 payments.
-
+
         Args:
             model: Model identifier
             prompt: Input prompt
             max_tokens: Maximum tokens to generate
             stop_sequence: Stop sequences
             temperature: Sampling temperature
-
+
         Returns:
             TextGenerationOutput with completion
-
+
         Raises:
             OpenGradientError: If request fails
         """
         api_key = self._get_api_key_for_model(model)
-
+
         if api_key:
             logging.debug("External LLM completions using API key")
             url = f"{self._llm_server_url}/v1/completions"
-
-            headers = {
-
-                "Authorization": f"Bearer {api_key}"
-            }
-
+
+            headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
+
             payload = {
                 "model": model,
                 "prompt": prompt,
                 "max_tokens": max_tokens,
                 "temperature": temperature,
             }
-
+
             if stop_sequence:
                 payload["stop"] = stop_sequence
-
+
             try:
                 response = requests.post(url, json=payload, headers=headers, timeout=60)
                 response.raise_for_status()
-
+
                 result = response.json()
-
-                return TextGenerationOutput(
-
-                    completion_output=result.get("completion")
-                )
-
+
+                return TextGenerationOutput(transaction_hash="external", completion_output=result.get("completion"))
+
             except requests.RequestException as e:
                 error_msg = f"External LLM completion failed: {str(e)}"
-                if hasattr(e,
+                if hasattr(e, "response") and e.response is not None:
                     try:
                         error_detail = e.response.json()
                         error_msg += f" - {error_detail}"
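Note: the API-key branch above is a plain OpenAI-style completions POST. A standalone sketch of the same request shape, with a placeholder server URL, key, and model name (all assumptions, not values from this diff):

```python
import requests

def external_completion(server_url: str, api_key: str, model: str, prompt: str) -> str:
    # Mirrors the payload and headers built in _external_llm_completion's API-key branch.
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
    payload = {"model": model, "prompt": prompt, "max_tokens": 100, "temperature": 0.0}
    response = requests.post(f"{server_url}/v1/completions", json=payload, headers=headers, timeout=60)
    response.raise_for_status()
    return response.json().get("completion")

# external_completion("https://llm.example.invalid", "sk-placeholder", "gpt-4o-mini", "Say hello.")
```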
@@ -591,20 +615,20 @@ class Client:
                 "Authorization": f"Bearer {X402_PLACEHOLDER_API_KEY}",
                 "X-SETTLEMENT-TYPE": x402_settlement_mode,
             }
-
+
             payload = {
                 "model": model,
                 "prompt": prompt,
                 "max_tokens": max_tokens,
                 "temperature": temperature,
             }
-
+
             if stop_sequence:
                 payload["stop"] = stop_sequence
-
+
             try:
                 response = await client.post("/v1/completions", json=payload, headers=headers, timeout=60)
-
+
                 # Read the response content
                 content = await response.aread()
                 result = json.loads(content.decode())
@@ -612,24 +636,22 @@ class Client:
 
                 if X402_PROCESSING_HASH_HEADER in response.headers:
                     payment_hash = response.headers[X402_PROCESSING_HASH_HEADER]
-
+
                 return TextGenerationOutput(
-                    transaction_hash="external",
-                    completion_output=result.get("completion"),
-                    payment_hash=payment_hash
+                    transaction_hash="external", completion_output=result.get("completion"), payment_hash=payment_hash
                 )
-
+
             except Exception as e:
                 error_msg = f"External LLM completion request failed: {str(e)}"
                 logging.error(error_msg)
                 raise OpenGradientError(error_msg)
-
+
         try:
             # Run the async function in a sync context
             return asyncio.run(make_request())
         except Exception as e:
             error_msg = f"External LLM completion failed: {str(e)}"
-            if hasattr(e,
+            if hasattr(e, "response") and e.response is not None:
                 try:
                     error_detail = e.response.json()
                     error_msg += f" - {error_detail}"
@@ -651,7 +673,8 @@ class Client:
         max_retries: Optional[int] = None,
         local_model: Optional[bool] = False,
         x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH,
-
+        stream: bool = False,
+    ) -> Union[TextGenerationOutput, TextGenerationStream]:
         """
         Perform inference on an LLM model using chat.
 
@@ -666,9 +689,17 @@ class Client:
             tool_choice (str, optional): Sets a specific tool to choose.
             max_retries (int, optional): Maximum number of retry attempts.
             local_model (bool, optional): Force use of local model.
+            x402_settlement_mode (x402SettlementMode, optional): Settlement mode for x402 payments.
+                - SETTLE: Records input/output hashes only (most privacy-preserving).
+                - SETTLE_BATCH: Aggregates multiple inferences into batch hashes (most cost-efficient).
+                - SETTLE_METADATA: Records full model info, complete input/output data, and all metadata.
+                Defaults to SETTLE_BATCH.
+            stream (bool, optional): Whether to stream the response. Default is False.
 
         Returns:
-            TextGenerationOutput:
+            Union[TextGenerationOutput, TextGenerationStream]:
+                - If stream=False: TextGenerationOutput with chat_output, transaction_hash, finish_reason, and payment_hash
+                - If stream=True: TextGenerationStream yielding StreamChunk objects with typed deltas (true streaming via threading)
 
         Raises:
             OpenGradientError: If the inference fails.
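Note: with the return type above now depending on `stream`, callers branch on it. A hedged usage sketch: the client construction values and model CID are placeholders, any required parameters not named in this hunk are assumed to have defaults, and StreamChunk's exact fields live in types.py rather than in this diff:

```python
from opengradient.client import Client

# Illustrative construction only; real credentials come from your own configuration.
client = Client(
    private_key="0x" + "0" * 64,
    rpc_url="https://rpc.example.invalid",
    api_url="https://api.example.invalid",
    contract_address="0x0000000000000000000000000000000000000000",
)

messages = [{"role": "user", "content": "Summarize x402 settlement modes in one line."}]
model = "anthropic/claude-haiku"  # placeholder TEE model CID, not taken from this diff

# Non-streaming: a single TextGenerationOutput comes back.
output = client.llm_chat(model_cid=model, messages=messages)
print(output.chat_output, output.payment_hash)

# Streaming: StreamChunk objects are yielded as they arrive.
for chunk in client.llm_chat(model_cid=model, messages=messages, stream=True):
    print(chunk)  # exact StreamChunk fields are defined in types.py
```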
@@ -679,28 +710,45 @@ class Client:
             if model_cid not in TEE_LLM:
                 return OpenGradientError("That model CID is not supported yet for TEE inference")
 
-
-
-
-
-
-
-
-
-
-
-
+            if stream:
+                # Use threading bridge for true sync streaming
+                return self._external_llm_chat_stream_sync(
+                    model=model_cid.split("/")[1],
+                    messages=messages,
+                    max_tokens=max_tokens,
+                    stop_sequence=stop_sequence,
+                    temperature=temperature,
+                    tools=tools,
+                    tool_choice=tool_choice,
+                    x402_settlement_mode=x402_settlement_mode,
+                    use_tee=True,
+                )
+            else:
+                # Non-streaming
+                return self._external_llm_chat(
+                    model=model_cid.split("/")[1],
+                    messages=messages,
+                    max_tokens=max_tokens,
+                    stop_sequence=stop_sequence,
+                    temperature=temperature,
+                    tools=tools,
+                    tool_choice=tool_choice,
+                    x402_settlement_mode=x402_settlement_mode,
+                    stream=False,
+                    use_tee=True,
+                )
+
         # Original local model logic
         def execute_transaction():
             if inference_mode != LlmInferenceMode.VANILLA:
                 raise OpenGradientError("Invalid inference mode %s: Inference mode must be VANILLA or TEE" % inference_mode)
-
+
             if model_cid not in [llm.value for llm in LLM]:
                 raise OpenGradientError("That model CID is not yet supported for inference")
-
+
             model_name = model_cid
             if model_cid in [llm.value for llm in TEE_LLM]:
-                model_name = model_cid.split(
+                model_name = model_cid.split("/")[1]
 
             contract = self._blockchain.eth.contract(address=self._inference_hub_contract_address, abi=self._inference_abi)
 
@@ -768,10 +816,12 @@ class Client:
         tools: Optional[List[Dict]] = None,
         tool_choice: Optional[str] = None,
         x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH,
-
+        stream: bool = False,
+        use_tee: bool = False,
+    ) -> Union[TextGenerationOutput, TextGenerationStream]:
         """
         Route chat request to external LLM server with x402 payments.
-
+
         Args:
             model: Model identifier
             messages: List of chat messages
@@ -780,53 +830,63 @@ class Client:
             temperature: Sampling temperature
             tools: Function calling tools
             tool_choice: Tool selection strategy
-
+            stream: Whether to stream the response
+            use_tee: Whether to use TEE
+
         Returns:
-            TextGenerationOutput
-
+            Union[TextGenerationOutput, TextGenerationStream]: Chat completion or TextGenerationStream
+
         Raises:
             OpenGradientError: If request fails
         """
-        api_key = self._get_api_key_for_model(model)
-
+        api_key = None if use_tee else self._get_api_key_for_model(model)
+
         if api_key:
-            logging.debug("External LLM
-            url = f"{self._llm_server_url}/v1/chat/completions"
-
-            headers = {
-                "Content-Type": "application/json",
-                "Authorization": f"Bearer {api_key}"
-            }
+            logging.debug("External LLM chat using API key")
 
+            if stream:
+                url = f"{self._llm_server_url}/v1/chat/completions/stream"
+            else:
+                url = f"{self._llm_server_url}/v1/chat/completions"
+
+            headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
+
             payload = {
                 "model": model,
                 "messages": messages,
                 "max_tokens": max_tokens,
                 "temperature": temperature,
             }
-
+
             if stop_sequence:
                 payload["stop"] = stop_sequence
-
+
             if tools:
                 payload["tools"] = tools
                 payload["tool_choice"] = tool_choice or "auto"
-
+
             try:
-
-
-
-
-
-
-
-
-
-
-
+                if stream:
+                    # Return streaming response wrapped in TextGenerationStream
+                    response = requests.post(url, json=payload, headers=headers, timeout=60, stream=True)
+                    response.raise_for_status()
+                    return TextGenerationStream(_iterator=response.iter_lines(decode_unicode=True), _is_async=False)
+                else:
+                    # Non-streaming response
+                    response = requests.post(url, json=payload, headers=headers, timeout=60)
+                    response.raise_for_status()
+
+                    result = response.json()
+
+                    return TextGenerationOutput(
+                        transaction_hash="external",
+                        finish_reason=result.get("finish_reason"),
+                        chat_output=result.get("message")
+                    )
+
             except requests.RequestException as e:
                 error_msg = f"External LLM chat failed: {str(e)}"
-                if hasattr(e,
+                if hasattr(e, "response") and e.response is not None:
                     try:
                         error_detail = e.response.json()
                         error_msg += f" - {error_detail}"
@@ -835,6 +895,7 @@ class Client:
                 logging.error(error_msg)
                 raise OpenGradientError(error_msg)
 
+        # x402 payment path - non-streaming only here
         async def make_request():
             async with x402HttpxClient(
                 account=self._wallet_account,
@@ -844,58 +905,58 @@ class Client:
                 headers = {
                     "Content-Type": "application/json",
                     "Authorization": f"Bearer {X402_PLACEHOLDER_API_KEY}",
-                    "X-SETTLEMENT-TYPE": x402_settlement_mode
+                    "X-SETTLEMENT-TYPE": x402_settlement_mode,
                 }
-
+
                 payload = {
                     "model": model,
                     "messages": messages,
                     "max_tokens": max_tokens,
                     "temperature": temperature,
                 }
-
+
                 if stop_sequence:
                     payload["stop"] = stop_sequence
-
+
                 if tools:
                     payload["tools"] = tools
                     payload["tool_choice"] = tool_choice or "auto"
-
+
                 try:
-
-
+                    # Non-streaming with x402
+                    endpoint = "/v1/chat/completions"
+                    response = await client.post(endpoint, json=payload, headers=headers, timeout=60)
+
                     # Read the response content
                     content = await response.aread()
                     result = json.loads(content.decode())
-                    # print(f"Response: {response}")
-                    # print(f"Response Headers: {response.headers}")
 
                     payment_hash = ""
                     if X402_PROCESSING_HASH_HEADER in response.headers:
                         payment_hash = response.headers[X402_PROCESSING_HASH_HEADER]
-
+
                     choices = result.get("choices")
                     if not choices:
                         raise OpenGradientError(f"Invalid response: 'choices' missing or empty in {result}")
-
+
                     return TextGenerationOutput(
                         transaction_hash="external",
                         finish_reason=choices[0].get("finish_reason"),
                         chat_output=choices[0].get("message"),
-                        payment_hash=payment_hash
+                        payment_hash=payment_hash,
                     )
-
+
                 except Exception as e:
                     error_msg = f"External LLM chat request failed: {str(e)}"
                     logging.error(error_msg)
                     raise OpenGradientError(error_msg)
-
+
         try:
             # Run the async function in a sync context
             return asyncio.run(make_request())
         except Exception as e:
             error_msg = f"External LLM chat failed: {str(e)}"
-            if hasattr(e,
+            if hasattr(e, "response") and e.response is not None:
                 try:
                     error_detail = e.response.json()
                     error_msg += f" - {error_detail}"
@@ -904,6 +965,234 @@ class Client:
             logging.error(error_msg)
             raise OpenGradientError(error_msg)
 
+    def _external_llm_chat_stream_sync(
+        self,
+        model: str,
+        messages: List[Dict],
+        max_tokens: int = 100,
+        stop_sequence: Optional[List[str]] = None,
+        temperature: float = 0.0,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: Optional[str] = None,
+        x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH,
+        use_tee: bool = False,
+    ):
+        """
+        Sync streaming using threading bridge - TRUE real-time streaming.
+
+        Yields StreamChunk objects as they arrive from the background thread.
+        NO buffering, NO conversion, just direct pass-through.
+        """
+        import threading
+        from queue import Queue
+
+        queue = Queue()
+        exception_holder = []
+
+        def _run_async():
+            """Run async streaming in background thread"""
+            loop = None
+            try:
+                loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(loop)
+
+                async def _stream():
+                    try:
+                        async for chunk in self._external_llm_chat_stream_async(
+                            model=model,
+                            messages=messages,
+                            max_tokens=max_tokens,
+                            stop_sequence=stop_sequence,
+                            temperature=temperature,
+                            tools=tools,
+                            tool_choice=tool_choice,
+                            x402_settlement_mode=x402_settlement_mode,
+                            use_tee=use_tee,
+                        ):
+                            queue.put(chunk) # Put chunk immediately
+                    except Exception as e:
+                        exception_holder.append(e)
+                    finally:
+                        queue.put(None) # Signal completion
+
+                loop.run_until_complete(_stream())
+            except Exception as e:
+                exception_holder.append(e)
+                queue.put(None)
+            finally:
+                if loop:
+                    try:
+                        pending = asyncio.all_tasks(loop)
+                        for task in pending:
+                            task.cancel()
+                        loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
+                    finally:
+                        loop.close()
+
+        # Start background thread
+        thread = threading.Thread(target=_run_async, daemon=True)
+        thread.start()
+
+        # Yield chunks DIRECTLY as they arrive - NO buffering
+        try:
+            while True:
+                chunk = queue.get() # Blocks until chunk available
+                if chunk is None:
+                    break
+                yield chunk # Yield immediately!
+
+            thread.join(timeout=5)
+
+            if exception_holder:
+                raise exception_holder[0]
+        except Exception as e:
+            thread.join(timeout=1)
+            raise
+
+
+    async def _external_llm_chat_stream_async(
+        self,
+        model: str,
+        messages: List[Dict],
+        max_tokens: int = 100,
+        stop_sequence: Optional[List[str]] = None,
+        temperature: float = 0.0,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: Optional[str] = None,
+        x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH,
+        use_tee: bool = False,
+    ):
+        """
+        Internal async streaming implementation.
+
+        Yields StreamChunk objects as they arrive from the server.
+        """
+        api_key = None if use_tee else self._get_api_key_for_model(model)
+
+        if api_key:
+            # API key path - streaming to local llm-server
+            url = f"{self._og_llm_streaming_server_url}/v1/chat/completions"
+            headers = {
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {api_key}"
+            }
+
+            payload = {
+                "model": model,
+                "messages": messages,
+                "max_tokens": max_tokens,
+                "temperature": temperature,
+                "stream": True,
+            }
+
+            if stop_sequence:
+                payload["stop"] = stop_sequence
+            if tools:
+                payload["tools"] = tools
+                payload["tool_choice"] = tool_choice or "auto"
+
+            async with httpx.AsyncClient(verify=False, timeout=None) as client:
+                async with client.stream("POST", url, json=payload, headers=headers) as response:
+                    buffer = b""
+                    async for chunk in response.aiter_raw():
+                        if not chunk:
+                            continue
+
+                        buffer += chunk
+
+                        # Process all complete lines in buffer
+                        while b"\n" in buffer:
+                            line_bytes, buffer = buffer.split(b"\n", 1)
+
+                            if not line_bytes.strip():
+                                continue
+
+                            try:
+                                line = line_bytes.decode('utf-8').strip()
+                            except UnicodeDecodeError:
+                                continue
+
+                            if not line.startswith("data: "):
+                                continue
+
+                            data_str = line[6:] # Strip "data: " prefix
+                            if data_str.strip() == "[DONE]":
+                                return
+
+                            try:
+                                data = json.loads(data_str)
+                                yield StreamChunk.from_sse_data(data)
+                            except json.JSONDecodeError:
+                                continue
+        else:
+            # x402 payment path
+            async with httpx.AsyncClient(
+                base_url=self._og_llm_streaming_server_url,
+                headers={"Authorization": f"Bearer {X402_PLACEHOLDER_API_KEY}"},
+                timeout=TIMEOUT,
+                limits=LIMITS,
+                http2=False,
+                follow_redirects=False,
+                auth=X402Auth(account=self._wallet_account), # type: ignore
+            ) as client:
+                headers = {
+                    "Content-Type": "application/json",
+                    "Authorization": f"Bearer {X402_PLACEHOLDER_API_KEY}",
+                    "X-SETTLEMENT-TYPE": x402_settlement_mode,
+                }
+
+                payload = {
+                    "model": model,
+                    "messages": messages,
+                    "max_tokens": max_tokens,
+                    "temperature": temperature,
+                    "stream": True,
+                }
+
+                if stop_sequence:
+                    payload["stop"] = stop_sequence
+                if tools:
+                    payload["tools"] = tools
+                    payload["tool_choice"] = tool_choice or "auto"
+
+                async with client.stream(
+                    "POST",
+                    "/v1/chat/completions",
+                    json=payload,
+                    headers=headers,
+                ) as response:
+                    buffer = b""
+                    async for chunk in response.aiter_raw():
+                        if not chunk:
+                            continue
+
+                        buffer += chunk
+
+                        # Process complete lines from buffer
+                        while b"\n" in buffer:
+                            line_bytes, buffer = buffer.split(b"\n", 1)
+
+                            if not line_bytes.strip():
+                                continue
+
+                            try:
+                                line = line_bytes.decode('utf-8').strip()
+                            except UnicodeDecodeError:
+                                continue
+
+                            if not line.startswith("data: "):
+                                continue
+
+                            data_str = line[6:]
+                            if data_str.strip() == "[DONE]":
+                                return
+
+                            try:
+                                data = json.loads(data_str)
+                                yield StreamChunk.from_sse_data(data)
+                            except json.JSONDecodeError:
+                                continue
+
     def list_files(self, model_name: str, version: str) -> List[Dict]:
         """
         List files for a specific version of a model.
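Note: both streaming branches added above parse the response as Server-Sent Events: buffer raw bytes, split on newlines, keep only "data: " lines, and stop at "[DONE]". The same parsing logic, factored into a standalone helper for clarity (a sketch, not part of the package):

```python
import json
from typing import Iterable, Iterator

def parse_sse_lines(raw_chunks: Iterable[bytes]) -> Iterator[dict]:
    """Yield decoded JSON payloads from an SSE byte stream, mirroring the buffering in the diff."""
    buffer = b""
    for chunk in raw_chunks:
        buffer += chunk
        while b"\n" in buffer:
            line_bytes, buffer = buffer.split(b"\n", 1)
            line = line_bytes.decode("utf-8", errors="ignore").strip()
            if not line.startswith("data: "):
                continue
            data_str = line[len("data: "):]
            if data_str.strip() == "[DONE]":
                return
            try:
                yield json.loads(data_str)
            except json.JSONDecodeError:
                continue

# Example: two SSE frames followed by the terminator.
frames = [b'data: {"delta": "Hel"}\n', b'data: {"delta": "lo"}\ndata: [DONE]\n']
print(list(parse_sse_lines(frames)))  # [{'delta': 'Hel'}, {'delta': 'lo'}]
```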
@@ -1104,12 +1393,12 @@ class Client:
         except ContractLogicError as e:
             try:
                 run_function.call({"from": self._wallet_account.address})
-
+
             except ContractLogicError as call_err:
                 raise ContractLogicError(f"simulation failed with revert reason: {call_err.args[0]}")
-
+
             raise ContractLogicError(f"simulation failed with no revert reason. Reason: {e}")
-
+
         gas_limit = int(estimated_gas * 3)
 
         transaction = run_function.build_transaction(
@@ -1128,10 +1417,10 @@ class Client:
         if tx_receipt["status"] == 0:
             try:
                 run_function.call({"from": self._wallet_account.address})
-
+
             except ContractLogicError as call_err:
                 raise ContractLogicError(f"Transaction failed with revert reason: {call_err.args[0]}")
-
+
             raise ContractLogicError(f"Transaction failed with no revert reason. Receipt: {tx_receipt}")
 
         return tx_hash, tx_receipt
@@ -1346,45 +1635,42 @@ class Client:
         results = contract.functions.getLastInferenceResults(num_results).call()
         return [convert_array_to_model_output(result) for result in results]
 
-
     def _get_inference_result_from_node(self, inference_id: str, inference_mode: InferenceMode) -> Dict:
         """
         Get the inference result from node.
-
+
         Args:
             inference_id (str): Inference id for a inference request
-
+
         Returns:
             Dict: The inference result as returned by the node
-
+
         Raises:
             OpenGradientError: If the request fails or returns an error
         """
         try:
-            encoded_id = urllib.parse.quote(inference_id, safe=
+            encoded_id = urllib.parse.quote(inference_id, safe="")
             url = f"{self._api_url}/artela-network/artela-rollkit/inference/tx/{encoded_id}"
-
+
             response = requests.get(url)
             if response.status_code == 200:
                 resp = response.json()
                 inference_result = resp.get("inference_results", {})
                 if inference_result:
                     decoded_bytes = base64.b64decode(inference_result[0])
-                    decoded_string = decoded_bytes.decode(
-                    output = json.loads(decoded_string).get("InferenceResult",{})
+                    decoded_string = decoded_bytes.decode("utf-8")
+                    output = json.loads(decoded_string).get("InferenceResult", {})
                     if output is None:
                         raise OpenGradientError("Missing InferenceResult in inference output")
-
+
                     match inference_mode:
                         case InferenceMode.VANILLA:
                             if "VanillaResult" not in output:
                                 raise OpenGradientError("Missing VanillaResult in inference output")
                             if "model_output" not in output["VanillaResult"]:
                                 raise OpenGradientError("Missing model_output in VanillaResult")
-                            return {
-
-                            }
-
+                            return {"output": output["VanillaResult"]["model_output"]}
+
                         case InferenceMode.TEE:
                             if "TeeNodeResult" not in output:
                                 raise OpenGradientError("Missing TeeNodeResult in inference output")
@@ -1393,34 +1679,30 @@ class Client:
                             if "VanillaResponse" in output["TeeNodeResult"]["Response"]:
                                 if "model_output" not in output["TeeNodeResult"]["Response"]["VanillaResponse"]:
                                     raise OpenGradientError("Missing model_output in VanillaResponse")
-                                return {
-
-                                }
-
+                                return {"output": output["TeeNodeResult"]["Response"]["VanillaResponse"]["model_output"]}
+
                             else:
                                 raise OpenGradientError("Missing VanillaResponse in TeeNodeResult Response")
-
+
                         case InferenceMode.ZKML:
                             if "ZkmlResult" not in output:
                                 raise OpenGradientError("Missing ZkmlResult in inference output")
                             if "model_output" not in output["ZkmlResult"]:
                                 raise OpenGradientError("Missing model_output in ZkmlResult")
-                            return {
-
-                            }
-
+                            return {"output": output["ZkmlResult"]["model_output"]}
+
                         case _:
                             raise OpenGradientError(f"Invalid inference mode: {inference_mode}")
                 else:
                     return None
-
+
             else:
                 error_message = f"Failed to get inference result: HTTP {response.status_code}"
                 if response.text:
                     error_message += f" - {response.text}"
                 logging.error(error_message)
                 raise OpenGradientError(error_message)
-
+
         except requests.RequestException as e:
             logging.error(f"Request exception when getting inference result: {str(e)}")
             raise OpenGradientError(f"Failed to get inference result: {str(e)}")
@@ -1428,6 +1710,7 @@ class Client:
             logging.error(f"Unexpected error when getting inference result: {str(e)}", exc_info=True)
             raise OpenGradientError(f"Failed to get inference result: {str(e)}")
 
+
 def run_with_retry(txn_function: Callable, max_retries=DEFAULT_MAX_RETRY, retry_delay=DEFAULT_RETRY_DELAY_SEC):
     """
     Execute a blockchain transaction with retry logic.