opengradient 0.5.7__py3-none-any.whl → 0.5.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
opengradient/client.py CHANGED
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Union, Callable
  import firebase
  import numpy as np
  import requests
+ import httpx
  from eth_account.account import LocalAccount
  from web3 import Web3
  from web3.exceptions import ContractLogicError
@@ -17,7 +18,9 @@ import urllib.parse
  import asyncio
  from x402.clients.httpx import x402HttpxClient
  from x402.clients.base import decode_x_payment_response, x402Client
+ from x402.clients.httpx import x402HttpxClient

+ from .x402_auth import X402Auth
  from .exceptions import OpenGradientError
  from .proto import infer_pb2, infer_pb2_grpc
  from .types import (
@@ -29,17 +32,22 @@ from .types import (
  LlmInferenceMode,
  ModelOutput,
  TextGenerationOutput,
+ TextGenerationStream,
  SchedulerParams,
  InferenceResult,
  ModelRepository,
  FileUploadResult,
+ StreamChunk,
  )
  from .defaults import (
- DEFAULT_IMAGE_GEN_HOST,
- DEFAULT_IMAGE_GEN_PORT,
+ DEFAULT_IMAGE_GEN_HOST,
+ DEFAULT_IMAGE_GEN_PORT,
  DEFAULT_SCHEDULER_ADDRESS,
- DEFAULT_LLM_SERVER_URL,
- DEFAULT_OPENGRADIENT_LLM_SERVER_URL)
+ DEFAULT_LLM_SERVER_URL,
+ DEFAULT_OPENGRADIENT_LLM_SERVER_URL,
+ DEFAULT_OPENGRADIENT_LLM_STREAMING_SERVER_URL,
+ DEFAULT_NETWORK_FILTER,
+ )
  from .utils import convert_array_to_model_output, convert_to_model_input, convert_to_model_output

  _FIREBASE_CONFIG = {
@@ -65,6 +73,19 @@ PRECOMPILE_CONTRACT_ADDRESS = "0x00000000000000000000000000000000000000F4"
  X402_PROCESSING_HASH_HEADER = "x-processing-hash"
  X402_PLACEHOLDER_API_KEY = "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"

+ TIMEOUT = httpx.Timeout(
+ timeout=90.0,
+ connect=15.0,
+ read=15.0,
+ write=30.0,
+ pool=10.0,
+ )
+ LIMITS = httpx.Limits(
+ max_keepalive_connections=100,
+ max_connections=500,
+ keepalive_expiry=60 * 20, # 20 minutes
+ )
+
  class Client:
  _inference_hub_contract_address: str
  _blockchain: Web3
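
Note on the hunk above: the new module-level TIMEOUT and LIMITS constants tune httpx connection behavior for the streaming client introduced later in this diff (they are passed to the x402 streaming httpx.AsyncClient). As a minimal sketch, assuming only the public httpx API and a placeholder base URL, this is the pattern they support:

    import httpx

    # Same values as the constants added in the diff.
    TIMEOUT = httpx.Timeout(timeout=90.0, connect=15.0, read=15.0, write=30.0, pool=10.0)
    LIMITS = httpx.Limits(max_keepalive_connections=100, max_connections=500, keepalive_expiry=60 * 20)

    async def fetch(path: str) -> httpx.Response:
        # A short-lived client is shown for brevity; a long-lived client would
        # reuse keep-alive connections governed by LIMITS across many requests.
        async with httpx.AsyncClient(base_url="https://example.invalid", timeout=TIMEOUT, limits=LIMITS) as client:
            return await client.get(path)
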
@@ -76,20 +97,22 @@ class Client:
  _precompile_abi: Dict
  _llm_server_url: str
  _external_api_keys: Dict[str, str]
+
  def __init__(
- self,
- private_key: str,
- rpc_url: str,
- api_url: str,
- contract_address: str,
- email: Optional[str] = None,
- password: Optional[str] = None,
+ self,
+ private_key: str,
+ rpc_url: str,
+ api_url: str,
+ contract_address: str,
+ email: Optional[str] = None,
+ password: Optional[str] = None,
  llm_server_url: Optional[str] = DEFAULT_LLM_SERVER_URL,
  og_llm_server_url: Optional[str] = DEFAULT_OPENGRADIENT_LLM_SERVER_URL,
+ og_llm_streaming_server_url: Optional[str] = DEFAULT_OPENGRADIENT_LLM_STREAMING_SERVER_URL,
  openai_api_key: Optional[str] = None,
  anthropic_api_key: Optional[str] = None,
  google_api_key: Optional[str] = None,
- ):
+ ):
  """
  Initialize the Client with private key, RPC URL, and contract address.

@@ -120,7 +143,8 @@ class Client:

  self._llm_server_url = llm_server_url
  self._og_llm_server_url = og_llm_server_url
-
+ self._og_llm_streaming_server_url = og_llm_streaming_server_url
+
  self._external_api_keys = {}
  if openai_api_key or os.getenv("OPENAI_API_KEY"):
  self._external_api_keys["openai"] = openai_api_key or os.getenv("OPENAI_API_KEY")
@@ -132,7 +156,7 @@ class Client:
  def set_api_key(self, provider: str, api_key: str):
  """
  Set or update API key for an external provider.
-
+
  Args:
  provider: Provider name (e.g., 'openai', 'anthropic', 'google')
  api_key: The API key for the provider
@@ -142,10 +166,10 @@ class Client:
  def _is_local_model(self, model_cid: str) -> bool:
  """
  Check if a model is hosted locally on OpenGradient.
-
+
  Args:
  model_cid: Model identifier
-
+
  Returns:
  True if model is local, False if it should use external provider
  """
@@ -158,7 +182,7 @@ class Client:
  def _get_provider_from_model(self, model: str) -> str:
  """Infer provider from model name."""
  model_lower = model.lower()
-
+
  if "gpt" in model_lower or model.startswith("openai/"):
  return "openai"
  elif "claude" in model_lower or model.startswith("anthropic/"):
@@ -173,10 +197,10 @@ class Client:
  def _get_api_key_for_model(self, model: str) -> Optional[str]:
  """
  Get the appropriate API key for a model.
-
+
  Args:
  model: Model identifier
-
+
  Returns:
  API key string or None
  """
@@ -418,11 +442,11 @@ class Client:

  return run_with_retry(execute_transaction, max_retries)

- def _og_payment_selector(self, accepts, network_filter=None, scheme_filter=None, max_value=None):
- """Custom payment selector for OpenGradient network (og-devnet)."""
+ def _og_payment_selector(self, accepts, network_filter=DEFAULT_NETWORK_FILTER, scheme_filter=None, max_value=None):
+ """Custom payment selector for OpenGradient network."""
  return x402Client.default_payment_requirements_selector(
  accepts,
- network_filter="og-devnet",
+ network_filter=network_filter,
  scheme_filter=scheme_filter,
  max_value=max_value,
  )
@@ -451,11 +475,17 @@ class Client:
  temperature (float): Temperature for LLM inference, between 0 and 1. Default is 0.0.
  max_retries (int, optional): Maximum number of retry attempts for blockchain transactions.
  local_model (bool, optional): Force use of local model even if not in LLM enum.
+ x402_settlement_mode (x402SettlementMode, optional): Settlement mode for x402 payments.
+ - SETTLE: Records input/output hashes only (most privacy-preserving).
+ - SETTLE_BATCH: Aggregates multiple inferences into batch hashes (most cost-efficient).
+ - SETTLE_METADATA: Records full model info, complete input/output data, and all metadata.
+ Defaults to SETTLE_BATCH.

  Returns:
  TextGenerationOutput: Generated text results including:
  - Transaction hash (or "external" for external providers)
  - String of completion output
+ - Payment hash for x402 transactions (when using x402 settlement)

  Raises:
  OpenGradientError: If the inference fails.
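
The expanded docstring above documents the new x402_settlement_mode knob and the payment hash returned on x402-settled calls. A hedged usage sketch, assuming the Client constructor arguments shown in this diff, a placeholder TEE model CID, and that x402SettlementMode is importable from opengradient.types (import path not shown in this diff):

    from opengradient.client import Client
    from opengradient.types import x402SettlementMode  # assumed import path for the enum

    client = Client(
        private_key="0x<redacted>",
        rpc_url="https://<your-rpc>",
        api_url="https://<your-api>",
        contract_address="0x<hub-address>",
    )

    result = client.llm_completion(
        model_cid="<TEE_LLM model CID>",  # placeholder; must be a member of the TEE_LLM enum
        prompt="Summarize x402 settlement modes in one sentence.",
        max_tokens=64,
        x402_settlement_mode=x402SettlementMode.SETTLE,  # hash-only, most privacy-preserving
    )
    print(result.completion_output)
    print(result.payment_hash)  # populated when the x402 payment path is used
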
@@ -467,14 +497,14 @@ class Client:
  return OpenGradientError("That model CID is not supported yet for TEE inference")

  return self._external_llm_completion(
- model=model_cid.split('/')[1],
+ model=model_cid.split("/")[1],
  prompt=prompt,
  max_tokens=max_tokens,
  stop_sequence=stop_sequence,
  temperature=temperature,
  x402_settlement_mode=x402_settlement_mode,
  )
-
+
  # Original local model logic
  def execute_transaction():
  if inference_mode != LlmInferenceMode.VANILLA:
@@ -482,10 +512,10 @@ class Client:

  if model_cid not in [llm.value for llm in LLM]:
  raise OpenGradientError("That model CID is not yet supported for inference")
-
+
  model_name = model_cid
  if model_cid in [llm.value for llm in TEE_LLM]:
- model_name = model_cid.split('/')[1]
+ model_name = model_cid.split("/")[1]

  contract = self._blockchain.eth.contract(address=self._inference_hub_contract_address, abi=self._inference_abi)

@@ -523,55 +553,49 @@ class Client:
  ) -> TextGenerationOutput:
  """
  Route completion request to external LLM server with x402 payments.
-
+
  Args:
  model: Model identifier
  prompt: Input prompt
  max_tokens: Maximum tokens to generate
  stop_sequence: Stop sequences
  temperature: Sampling temperature
-
+
  Returns:
  TextGenerationOutput with completion
-
+
  Raises:
  OpenGradientError: If request fails
  """
  api_key = self._get_api_key_for_model(model)
-
+
  if api_key:
  logging.debug("External LLM completions using API key")
  url = f"{self._llm_server_url}/v1/completions"
-
- headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {api_key}"
- }
-
+
+ headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
+
  payload = {
  "model": model,
  "prompt": prompt,
  "max_tokens": max_tokens,
  "temperature": temperature,
  }
-
+
  if stop_sequence:
  payload["stop"] = stop_sequence
-
+
  try:
  response = requests.post(url, json=payload, headers=headers, timeout=60)
  response.raise_for_status()
-
+
  result = response.json()
-
- return TextGenerationOutput(
- transaction_hash="external",
- completion_output=result.get("completion")
- )
-
+
+ return TextGenerationOutput(transaction_hash="external", completion_output=result.get("completion"))
+
  except requests.RequestException as e:
  error_msg = f"External LLM completion failed: {str(e)}"
- if hasattr(e, 'response') and e.response is not None:
+ if hasattr(e, "response") and e.response is not None:
  try:
  error_detail = e.response.json()
  error_msg += f" - {error_detail}"
@@ -591,20 +615,20 @@ class Client:
  "Authorization": f"Bearer {X402_PLACEHOLDER_API_KEY}",
  "X-SETTLEMENT-TYPE": x402_settlement_mode,
  }
-
+
  payload = {
  "model": model,
  "prompt": prompt,
  "max_tokens": max_tokens,
  "temperature": temperature,
  }
-
+
  if stop_sequence:
  payload["stop"] = stop_sequence
-
+
  try:
  response = await client.post("/v1/completions", json=payload, headers=headers, timeout=60)
-
+
  # Read the response content
  content = await response.aread()
  result = json.loads(content.decode())
@@ -612,24 +636,22 @@ class Client:

  if X402_PROCESSING_HASH_HEADER in response.headers:
  payment_hash = response.headers[X402_PROCESSING_HASH_HEADER]
-
+
  return TextGenerationOutput(
- transaction_hash="external",
- completion_output=result.get("completion"),
- payment_hash=payment_hash
+ transaction_hash="external", completion_output=result.get("completion"), payment_hash=payment_hash
  )
-
+
  except Exception as e:
  error_msg = f"External LLM completion request failed: {str(e)}"
  logging.error(error_msg)
  raise OpenGradientError(error_msg)
-
+
  try:
  # Run the async function in a sync context
  return asyncio.run(make_request())
  except Exception as e:
  error_msg = f"External LLM completion failed: {str(e)}"
- if hasattr(e, 'response') and e.response is not None:
+ if hasattr(e, "response") and e.response is not None:
  try:
  error_detail = e.response.json()
  error_msg += f" - {error_detail}"
@@ -651,7 +673,8 @@ class Client:
  max_retries: Optional[int] = None,
  local_model: Optional[bool] = False,
  x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH,
- ) -> TextGenerationOutput:
+ stream: bool = False,
+ ) -> Union[TextGenerationOutput, TextGenerationStream]:
  """
  Perform inference on an LLM model using chat.

@@ -666,9 +689,17 @@ class Client:
  tool_choice (str, optional): Sets a specific tool to choose.
  max_retries (int, optional): Maximum number of retry attempts.
  local_model (bool, optional): Force use of local model.
+ x402_settlement_mode (x402SettlementMode, optional): Settlement mode for x402 payments.
+ - SETTLE: Records input/output hashes only (most privacy-preserving).
+ - SETTLE_BATCH: Aggregates multiple inferences into batch hashes (most cost-efficient).
+ - SETTLE_METADATA: Records full model info, complete input/output data, and all metadata.
+ Defaults to SETTLE_BATCH.
+ stream (bool, optional): Whether to stream the response. Default is False.

  Returns:
- TextGenerationOutput: Generated text results.
+ Union[TextGenerationOutput, TextGenerationStream]:
+ - If stream=False: TextGenerationOutput with chat_output, transaction_hash, finish_reason, and payment_hash
+ - If stream=True: TextGenerationStream yielding StreamChunk objects with typed deltas (true streaming via threading)

  Raises:
  OpenGradientError: If the inference fails.
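
With the new stream flag documented above, llm_chat can return a TextGenerationStream of StreamChunk objects instead of a single TextGenerationOutput. A hedged sketch of consuming it, assuming placeholder credentials and model CID; the exact StreamChunk fields are defined in opengradient.types and are not spelled out in this diff:

    from opengradient.client import Client

    client = Client(
        private_key="0x<redacted>",
        rpc_url="https://<your-rpc>",
        api_url="https://<your-api>",
        contract_address="0x<hub-address>",
    )

    stream = client.llm_chat(
        model_cid="<TEE_LLM model CID>",  # placeholder; must be a supported TEE model
        messages=[{"role": "user", "content": "Stream a short haiku."}],
        max_tokens=64,
        stream=True,
    )

    for chunk in stream:  # StreamChunk objects arrive as the server emits SSE events
        print(chunk)      # see opengradient.types.StreamChunk for the typed delta fields
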
@@ -679,28 +710,45 @@ class Client:
  if model_cid not in TEE_LLM:
  return OpenGradientError("That model CID is not supported yet for TEE inference")

- return self._external_llm_chat(
- model=model_cid.split('/')[1],
- messages=messages,
- max_tokens=max_tokens,
- stop_sequence=stop_sequence,
- temperature=temperature,
- tools=tools,
- tool_choice=tool_choice,
- x402_settlement_mode=x402_settlement_mode,
- )
-
+ if stream:
+ # Use threading bridge for true sync streaming
+ return self._external_llm_chat_stream_sync(
+ model=model_cid.split("/")[1],
+ messages=messages,
+ max_tokens=max_tokens,
+ stop_sequence=stop_sequence,
+ temperature=temperature,
+ tools=tools,
+ tool_choice=tool_choice,
+ x402_settlement_mode=x402_settlement_mode,
+ use_tee=True,
+ )
+ else:
+ # Non-streaming
+ return self._external_llm_chat(
+ model=model_cid.split("/")[1],
+ messages=messages,
+ max_tokens=max_tokens,
+ stop_sequence=stop_sequence,
+ temperature=temperature,
+ tools=tools,
+ tool_choice=tool_choice,
+ x402_settlement_mode=x402_settlement_mode,
+ stream=False,
+ use_tee=True,
+ )
+
  # Original local model logic
  def execute_transaction():
  if inference_mode != LlmInferenceMode.VANILLA:
  raise OpenGradientError("Invalid inference mode %s: Inference mode must be VANILLA or TEE" % inference_mode)
-
+
  if model_cid not in [llm.value for llm in LLM]:
  raise OpenGradientError("That model CID is not yet supported for inference")
-
+
  model_name = model_cid
  if model_cid in [llm.value for llm in TEE_LLM]:
- model_name = model_cid.split('/')[1]
+ model_name = model_cid.split("/")[1]

  contract = self._blockchain.eth.contract(address=self._inference_hub_contract_address, abi=self._inference_abi)

@@ -768,10 +816,12 @@ class Client:
  tools: Optional[List[Dict]] = None,
  tool_choice: Optional[str] = None,
  x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH,
- ) -> TextGenerationOutput:
+ stream: bool = False,
+ use_tee: bool = False,
+ ) -> Union[TextGenerationOutput, TextGenerationStream]:
  """
  Route chat request to external LLM server with x402 payments.
-
+
  Args:
  model: Model identifier
  messages: List of chat messages
@@ -780,53 +830,63 @@ class Client:
  temperature: Sampling temperature
  tools: Function calling tools
  tool_choice: Tool selection strategy
-
+ stream: Whether to stream the response
+ use_tee: Whether to use TEE
+
  Returns:
- TextGenerationOutput with chat completion
-
+ Union[TextGenerationOutput, TextGenerationStream]: Chat completion or TextGenerationStream
+
  Raises:
  OpenGradientError: If request fails
  """
- api_key = self._get_api_key_for_model(model)
-
+ api_key = None if use_tee else self._get_api_key_for_model(model)
+
  if api_key:
- logging.debug("External LLM completion using API key")
- url = f"{self._llm_server_url}/v1/chat/completions"
-
- headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {api_key}"
- }
+ logging.debug("External LLM chat using API key")

+ if stream:
+ url = f"{self._llm_server_url}/v1/chat/completions/stream"
+ else:
+ url = f"{self._llm_server_url}/v1/chat/completions"
+
+ headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
+
  payload = {
  "model": model,
  "messages": messages,
  "max_tokens": max_tokens,
  "temperature": temperature,
  }
-
+
  if stop_sequence:
  payload["stop"] = stop_sequence
-
+
  if tools:
  payload["tools"] = tools
  payload["tool_choice"] = tool_choice or "auto"
-
+
  try:
- response = requests.post(url, json=payload, headers=headers, timeout=60)
- response.raise_for_status()
-
- result = response.json()
-
- return TextGenerationOutput(
- transaction_hash="external",
- finish_reason=result.get("finish_reason"),
- chat_output=result.get("message")
- )
-
+ if stream:
+ # Return streaming response wrapped in TextGenerationStream
+ response = requests.post(url, json=payload, headers=headers, timeout=60, stream=True)
+ response.raise_for_status()
+ return TextGenerationStream(_iterator=response.iter_lines(decode_unicode=True), _is_async=False)
+ else:
+ # Non-streaming response
+ response = requests.post(url, json=payload, headers=headers, timeout=60)
+ response.raise_for_status()
+
+ result = response.json()
+
+ return TextGenerationOutput(
+ transaction_hash="external",
+ finish_reason=result.get("finish_reason"),
+ chat_output=result.get("message")
+ )
+
  except requests.RequestException as e:
  error_msg = f"External LLM chat failed: {str(e)}"
- if hasattr(e, 'response') and e.response is not None:
+ if hasattr(e, "response") and e.response is not None:
  try:
  error_detail = e.response.json()
  error_msg += f" - {error_detail}"
@@ -835,6 +895,7 @@ class Client:
  logging.error(error_msg)
  raise OpenGradientError(error_msg)

+ # x402 payment path - non-streaming only here
  async def make_request():
  async with x402HttpxClient(
  account=self._wallet_account,
@@ -844,58 +905,58 @@ class Client:
  headers = {
  "Content-Type": "application/json",
  "Authorization": f"Bearer {X402_PLACEHOLDER_API_KEY}",
- "X-SETTLEMENT-TYPE": x402_settlement_mode
+ "X-SETTLEMENT-TYPE": x402_settlement_mode,
  }
-
+
  payload = {
  "model": model,
  "messages": messages,
  "max_tokens": max_tokens,
  "temperature": temperature,
  }
-
+
  if stop_sequence:
  payload["stop"] = stop_sequence
-
+
  if tools:
  payload["tools"] = tools
  payload["tool_choice"] = tool_choice or "auto"
-
+
  try:
- response = await client.post("/v1/chat/completions", json=payload, headers=headers, timeout=60)
-
+ # Non-streaming with x402
+ endpoint = "/v1/chat/completions"
+ response = await client.post(endpoint, json=payload, headers=headers, timeout=60)
+
  # Read the response content
  content = await response.aread()
  result = json.loads(content.decode())
- # print(f"Response: {response}")
- # print(f"Response Headers: {response.headers}")

  payment_hash = ""
  if X402_PROCESSING_HASH_HEADER in response.headers:
  payment_hash = response.headers[X402_PROCESSING_HASH_HEADER]
-
+
  choices = result.get("choices")
  if not choices:
  raise OpenGradientError(f"Invalid response: 'choices' missing or empty in {result}")
-
+
  return TextGenerationOutput(
  transaction_hash="external",
  finish_reason=choices[0].get("finish_reason"),
  chat_output=choices[0].get("message"),
- payment_hash=payment_hash
+ payment_hash=payment_hash,
  )
-
+
  except Exception as e:
  error_msg = f"External LLM chat request failed: {str(e)}"
  logging.error(error_msg)
  raise OpenGradientError(error_msg)
-
+
  try:
  # Run the async function in a sync context
  return asyncio.run(make_request())
  except Exception as e:
  error_msg = f"External LLM chat failed: {str(e)}"
- if hasattr(e, 'response') and e.response is not None:
+ if hasattr(e, "response") and e.response is not None:
  try:
  error_detail = e.response.json()
  error_msg += f" - {error_detail}"
@@ -904,6 +965,234 @@ class Client:
  logging.error(error_msg)
  raise OpenGradientError(error_msg)

+ def _external_llm_chat_stream_sync(
+ self,
+ model: str,
+ messages: List[Dict],
+ max_tokens: int = 100,
+ stop_sequence: Optional[List[str]] = None,
+ temperature: float = 0.0,
+ tools: Optional[List[Dict]] = None,
+ tool_choice: Optional[str] = None,
+ x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH,
+ use_tee: bool = False,
+ ):
+ """
+ Sync streaming using threading bridge - TRUE real-time streaming.
+
+ Yields StreamChunk objects as they arrive from the background thread.
+ NO buffering, NO conversion, just direct pass-through.
+ """
+ import threading
+ from queue import Queue
+
+ queue = Queue()
+ exception_holder = []
+
+ def _run_async():
+ """Run async streaming in background thread"""
+ loop = None
+ try:
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
+ async def _stream():
+ try:
+ async for chunk in self._external_llm_chat_stream_async(
+ model=model,
+ messages=messages,
+ max_tokens=max_tokens,
+ stop_sequence=stop_sequence,
+ temperature=temperature,
+ tools=tools,
+ tool_choice=tool_choice,
+ x402_settlement_mode=x402_settlement_mode,
+ use_tee=use_tee,
+ ):
+ queue.put(chunk) # Put chunk immediately
+ except Exception as e:
+ exception_holder.append(e)
+ finally:
+ queue.put(None) # Signal completion
+
+ loop.run_until_complete(_stream())
+ except Exception as e:
+ exception_holder.append(e)
+ queue.put(None)
+ finally:
+ if loop:
+ try:
+ pending = asyncio.all_tasks(loop)
+ for task in pending:
+ task.cancel()
+ loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
+ finally:
+ loop.close()
+
+ # Start background thread
+ thread = threading.Thread(target=_run_async, daemon=True)
+ thread.start()
+
+ # Yield chunks DIRECTLY as they arrive - NO buffering
+ try:
+ while True:
+ chunk = queue.get() # Blocks until chunk available
+ if chunk is None:
+ break
+ yield chunk # Yield immediately!
+
+ thread.join(timeout=5)
+
+ if exception_holder:
+ raise exception_holder[0]
+ except Exception as e:
+ thread.join(timeout=1)
+ raise
+
+
+ async def _external_llm_chat_stream_async(
+ self,
+ model: str,
+ messages: List[Dict],
+ max_tokens: int = 100,
+ stop_sequence: Optional[List[str]] = None,
+ temperature: float = 0.0,
+ tools: Optional[List[Dict]] = None,
+ tool_choice: Optional[str] = None,
+ x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH,
+ use_tee: bool = False,
+ ):
+ """
+ Internal async streaming implementation.
+
+ Yields StreamChunk objects as they arrive from the server.
+ """
+ api_key = None if use_tee else self._get_api_key_for_model(model)
+
+ if api_key:
+ # API key path - streaming to local llm-server
+ url = f"{self._og_llm_streaming_server_url}/v1/chat/completions"
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {api_key}"
+ }
+
+ payload = {
+ "model": model,
+ "messages": messages,
+ "max_tokens": max_tokens,
+ "temperature": temperature,
+ "stream": True,
+ }
+
+ if stop_sequence:
+ payload["stop"] = stop_sequence
+ if tools:
+ payload["tools"] = tools
+ payload["tool_choice"] = tool_choice or "auto"
+
+ async with httpx.AsyncClient(verify=False, timeout=None) as client:
+ async with client.stream("POST", url, json=payload, headers=headers) as response:
+ buffer = b""
+ async for chunk in response.aiter_raw():
+ if not chunk:
+ continue
+
+ buffer += chunk
+
+ # Process all complete lines in buffer
+ while b"\n" in buffer:
+ line_bytes, buffer = buffer.split(b"\n", 1)
+
+ if not line_bytes.strip():
+ continue
+
+ try:
+ line = line_bytes.decode('utf-8').strip()
+ except UnicodeDecodeError:
+ continue
+
+ if not line.startswith("data: "):
+ continue
+
+ data_str = line[6:] # Strip "data: " prefix
+ if data_str.strip() == "[DONE]":
+ return
+
+ try:
+ data = json.loads(data_str)
+ yield StreamChunk.from_sse_data(data)
+ except json.JSONDecodeError:
+ continue
+ else:
+ # x402 payment path
+ async with httpx.AsyncClient(
+ base_url=self._og_llm_streaming_server_url,
+ headers={"Authorization": f"Bearer {X402_PLACEHOLDER_API_KEY}"},
+ timeout=TIMEOUT,
+ limits=LIMITS,
+ http2=False,
+ follow_redirects=False,
+ auth=X402Auth(account=self._wallet_account), # type: ignore
+ ) as client:
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {X402_PLACEHOLDER_API_KEY}",
+ "X-SETTLEMENT-TYPE": x402_settlement_mode,
+ }
+
+ payload = {
+ "model": model,
+ "messages": messages,
+ "max_tokens": max_tokens,
+ "temperature": temperature,
+ "stream": True,
+ }
+
+ if stop_sequence:
+ payload["stop"] = stop_sequence
+ if tools:
+ payload["tools"] = tools
+ payload["tool_choice"] = tool_choice or "auto"
+
+ async with client.stream(
+ "POST",
+ "/v1/chat/completions",
+ json=payload,
+ headers=headers,
+ ) as response:
+ buffer = b""
+ async for chunk in response.aiter_raw():
+ if not chunk:
+ continue
+
+ buffer += chunk
+
+ # Process complete lines from buffer
+ while b"\n" in buffer:
+ line_bytes, buffer = buffer.split(b"\n", 1)
+
+ if not line_bytes.strip():
+ continue
+
+ try:
+ line = line_bytes.decode('utf-8').strip()
+ except UnicodeDecodeError:
+ continue
+
+ if not line.startswith("data: "):
+ continue
+
+ data_str = line[6:]
+ if data_str.strip() == "[DONE]":
+ return
+
+ try:
+ data = json.loads(data_str)
+ yield StreamChunk.from_sse_data(data)
+ except json.JSONDecodeError:
+ continue
+
  def list_files(self, model_name: str, version: str) -> List[Dict]:
  """
  List files for a specific version of a model.
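
The new _external_llm_chat_stream_sync method above bridges the async SSE stream to a synchronous generator by running an event loop in a daemon thread and handing each chunk across a Queue. A minimal, self-contained sketch of that pattern, with a toy async producer standing in for the real SSE stream:

    import asyncio
    import threading
    from queue import Queue

    async def produce():
        for i in range(3):
            await asyncio.sleep(0.1)      # stands in for awaiting an SSE chunk
            yield f"chunk-{i}"

    def consume_sync():
        queue: Queue = Queue()

        def run():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

            async def pump():
                async for item in produce():
                    queue.put(item)       # hand each item to the sync side immediately
                queue.put(None)           # sentinel: stream finished

            try:
                loop.run_until_complete(pump())
            finally:
                loop.close()

        threading.Thread(target=run, daemon=True).start()
        while (item := queue.get()) is not None:
            yield item

    for chunk in consume_sync():
        print(chunk)
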
@@ -1104,12 +1393,12 @@ class Client:
  except ContractLogicError as e:
  try:
  run_function.call({"from": self._wallet_account.address})
-
+
  except ContractLogicError as call_err:
  raise ContractLogicError(f"simulation failed with revert reason: {call_err.args[0]}")
-
+
  raise ContractLogicError(f"simulation failed with no revert reason. Reason: {e}")
-
+
  gas_limit = int(estimated_gas * 3)

  transaction = run_function.build_transaction(
@@ -1128,10 +1417,10 @@ class Client:
  if tx_receipt["status"] == 0:
  try:
  run_function.call({"from": self._wallet_account.address})
-
+
  except ContractLogicError as call_err:
  raise ContractLogicError(f"Transaction failed with revert reason: {call_err.args[0]}")
-
+
  raise ContractLogicError(f"Transaction failed with no revert reason. Receipt: {tx_receipt}")

  return tx_hash, tx_receipt
@@ -1346,45 +1635,42 @@ class Client:
  results = contract.functions.getLastInferenceResults(num_results).call()
  return [convert_array_to_model_output(result) for result in results]

-
  def _get_inference_result_from_node(self, inference_id: str, inference_mode: InferenceMode) -> Dict:
  """
  Get the inference result from node.
-
+
  Args:
  inference_id (str): Inference id for a inference request
-
+
  Returns:
  Dict: The inference result as returned by the node
-
+
  Raises:
  OpenGradientError: If the request fails or returns an error
  """
  try:
- encoded_id = urllib.parse.quote(inference_id, safe='')
+ encoded_id = urllib.parse.quote(inference_id, safe="")
  url = f"{self._api_url}/artela-network/artela-rollkit/inference/tx/{encoded_id}"
-
+
  response = requests.get(url)
  if response.status_code == 200:
  resp = response.json()
  inference_result = resp.get("inference_results", {})
  if inference_result:
  decoded_bytes = base64.b64decode(inference_result[0])
- decoded_string = decoded_bytes.decode('utf-8')
- output = json.loads(decoded_string).get("InferenceResult",{})
+ decoded_string = decoded_bytes.decode("utf-8")
+ output = json.loads(decoded_string).get("InferenceResult", {})
  if output is None:
  raise OpenGradientError("Missing InferenceResult in inference output")
-
+
  match inference_mode:
  case InferenceMode.VANILLA:
  if "VanillaResult" not in output:
  raise OpenGradientError("Missing VanillaResult in inference output")
  if "model_output" not in output["VanillaResult"]:
  raise OpenGradientError("Missing model_output in VanillaResult")
- return {
- "output": output["VanillaResult"]["model_output"]
- }
-
+ return {"output": output["VanillaResult"]["model_output"]}
+
  case InferenceMode.TEE:
  if "TeeNodeResult" not in output:
  raise OpenGradientError("Missing TeeNodeResult in inference output")
@@ -1393,34 +1679,30 @@ class Client:
  if "VanillaResponse" in output["TeeNodeResult"]["Response"]:
  if "model_output" not in output["TeeNodeResult"]["Response"]["VanillaResponse"]:
  raise OpenGradientError("Missing model_output in VanillaResponse")
- return {
- "output": output["TeeNodeResult"]["Response"]["VanillaResponse"]["model_output"]
- }
-
+ return {"output": output["TeeNodeResult"]["Response"]["VanillaResponse"]["model_output"]}
+
  else:
  raise OpenGradientError("Missing VanillaResponse in TeeNodeResult Response")
-
+
  case InferenceMode.ZKML:
  if "ZkmlResult" not in output:
  raise OpenGradientError("Missing ZkmlResult in inference output")
  if "model_output" not in output["ZkmlResult"]:
  raise OpenGradientError("Missing model_output in ZkmlResult")
- return {
- "output": output["ZkmlResult"]["model_output"]
- }
-
+ return {"output": output["ZkmlResult"]["model_output"]}
+
  case _:
  raise OpenGradientError(f"Invalid inference mode: {inference_mode}")
  else:
  return None
-
+
  else:
  error_message = f"Failed to get inference result: HTTP {response.status_code}"
  if response.text:
  error_message += f" - {response.text}"
  logging.error(error_message)
  raise OpenGradientError(error_message)
-
+
  except requests.RequestException as e:
  logging.error(f"Request exception when getting inference result: {str(e)}")
  raise OpenGradientError(f"Failed to get inference result: {str(e)}")
@@ -1428,6 +1710,7 @@ class Client:
  logging.error(f"Unexpected error when getting inference result: {str(e)}", exc_info=True)
  raise OpenGradientError(f"Failed to get inference result: {str(e)}")

+
  def run_with_retry(txn_function: Callable, max_retries=DEFAULT_MAX_RETRY, retry_delay=DEFAULT_RETRY_DELAY_SEC):
  """
  Execute a blockchain transaction with retry logic.