openhands-sdk 1.10.0__py3-none-any.whl → 1.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openhands/sdk/agent/agent.py +60 -27
- openhands/sdk/agent/base.py +1 -1
- openhands/sdk/context/condenser/base.py +36 -3
- openhands/sdk/context/condenser/llm_summarizing_condenser.py +65 -1
- openhands/sdk/context/prompts/templates/system_message_suffix.j2 +2 -1
- openhands/sdk/context/skills/skill.py +15 -30
- openhands/sdk/conversation/base.py +31 -0
- openhands/sdk/conversation/conversation.py +5 -0
- openhands/sdk/conversation/impl/local_conversation.py +63 -13
- openhands/sdk/conversation/impl/remote_conversation.py +128 -13
- openhands/sdk/conversation/state.py +19 -0
- openhands/sdk/conversation/stuck_detector.py +18 -9
- openhands/sdk/llm/__init__.py +16 -0
- openhands/sdk/llm/auth/__init__.py +28 -0
- openhands/sdk/llm/auth/credentials.py +157 -0
- openhands/sdk/llm/auth/openai.py +762 -0
- openhands/sdk/llm/llm.py +175 -20
- openhands/sdk/llm/message.py +21 -11
- openhands/sdk/llm/options/responses_options.py +8 -7
- openhands/sdk/llm/utils/model_features.py +2 -0
- openhands/sdk/llm/utils/verified_models.py +3 -0
- openhands/sdk/mcp/tool.py +27 -4
- openhands/sdk/secret/secrets.py +13 -1
- openhands/sdk/workspace/remote/base.py +8 -3
- openhands/sdk/workspace/remote/remote_workspace_mixin.py +40 -7
- {openhands_sdk-1.10.0.dist-info → openhands_sdk-1.11.1.dist-info}/METADATA +1 -1
- {openhands_sdk-1.10.0.dist-info → openhands_sdk-1.11.1.dist-info}/RECORD +29 -26
- {openhands_sdk-1.10.0.dist-info → openhands_sdk-1.11.1.dist-info}/WHEEL +0 -0
- {openhands_sdk-1.10.0.dist-info → openhands_sdk-1.11.1.dist-info}/top_level.txt +0 -0
|
@@ -6,7 +6,8 @@ import threading
|
|
|
6
6
|
import time
|
|
7
7
|
import uuid
|
|
8
8
|
from collections.abc import Mapping
|
|
9
|
-
from
|
|
9
|
+
from queue import Empty, Queue
|
|
10
|
+
from typing import TYPE_CHECKING, SupportsIndex, overload
|
|
10
11
|
from urllib.parse import urlparse
|
|
11
12
|
|
|
12
13
|
import httpx
|
|
@@ -14,6 +15,10 @@ import websockets
|
|
|
14
15
|
|
|
15
16
|
from openhands.sdk.agent.base import AgentBase
|
|
16
17
|
from openhands.sdk.conversation.base import BaseConversation, ConversationStateProtocol
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from openhands.sdk.tool.schema import Action, Observation
|
|
17
22
|
from openhands.sdk.conversation.conversation_stats import ConversationStats
|
|
18
23
|
from openhands.sdk.conversation.events_list_base import EventsListBase
|
|
19
24
|
from openhands.sdk.conversation.exceptions import (
|
|
@@ -555,6 +560,8 @@ class RemoteConversation(BaseConversation):
|
|
|
555
560
|
_client: httpx.Client
|
|
556
561
|
_hook_processor: HookEventProcessor | None
|
|
557
562
|
_cleanup_initiated: bool
|
|
563
|
+
_terminal_status_queue: Queue[str] # Thread-safe queue for terminal status from WS
|
|
564
|
+
delete_on_close: bool = False
|
|
558
565
|
|
|
559
566
|
def __init__(
|
|
560
567
|
self,
|
|
@@ -573,6 +580,7 @@ class RemoteConversation(BaseConversation):
|
|
|
573
580
|
type[ConversationVisualizerBase] | ConversationVisualizerBase | None
|
|
574
581
|
) = DefaultConversationVisualizer,
|
|
575
582
|
secrets: Mapping[str, SecretValue] | None = None,
|
|
583
|
+
delete_on_close: bool = False,
|
|
576
584
|
**_: object,
|
|
577
585
|
) -> None:
|
|
578
586
|
"""Remote conversation proxy that talks to an agent server.
|
|
@@ -607,6 +615,7 @@ class RemoteConversation(BaseConversation):
|
|
|
607
615
|
self._client = workspace.client
|
|
608
616
|
self._hook_processor = None
|
|
609
617
|
self._cleanup_initiated = False
|
|
618
|
+
self._terminal_status_queue: Queue[str] = Queue()
|
|
610
619
|
|
|
611
620
|
should_create = conversation_id is None
|
|
612
621
|
if conversation_id is not None:
|
|
@@ -706,8 +715,21 @@ class RemoteConversation(BaseConversation):
|
|
|
706
715
|
# No visualization (visualizer is None)
|
|
707
716
|
self._visualizer = None
|
|
708
717
|
|
|
718
|
+
# Add a callback that signals when run completes via WebSocket
|
|
719
|
+
# This ensures we wait for all events to be delivered before run() returns
|
|
720
|
+
def run_complete_callback(event: Event) -> None:
|
|
721
|
+
if isinstance(event, ConversationStateUpdateEvent):
|
|
722
|
+
if event.key == "execution_status":
|
|
723
|
+
try:
|
|
724
|
+
status = ConversationExecutionStatus(event.value)
|
|
725
|
+
if status.is_terminal():
|
|
726
|
+
self._terminal_status_queue.put(event.value)
|
|
727
|
+
except ValueError:
|
|
728
|
+
pass # Unknown status value, ignore
|
|
729
|
+
|
|
709
730
|
# Compose all callbacks into a single callback
|
|
710
|
-
|
|
731
|
+
all_callbacks = self._callbacks + [run_complete_callback]
|
|
732
|
+
composed_callback = BaseConversation.compose_callbacks(all_callbacks)
|
|
711
733
|
|
|
712
734
|
# Initialize WebSocket client for callbacks
|
|
713
735
|
self._ws_client = WebSocketCallbackClient(
|
|
@@ -765,6 +787,7 @@ class RemoteConversation(BaseConversation):
|
|
|
765
787
|
)
|
|
766
788
|
self._hook_processor = HookEventProcessor(hook_manager=hook_manager)
|
|
767
789
|
self._hook_processor.run_session_start()
|
|
790
|
+
self.delete_on_close = delete_on_close
|
|
768
791
|
|
|
769
792
|
def _create_llm_completion_log_callback(self) -> ConversationCallbackType:
|
|
770
793
|
"""Create a callback that writes LLM completion logs to client filesystem."""
|
|
@@ -859,6 +882,14 @@ class RemoteConversation(BaseConversation):
|
|
|
859
882
|
Raises:
|
|
860
883
|
ConversationRunError: If the run fails or times out.
|
|
861
884
|
"""
|
|
885
|
+
# Drain any stale terminal status events from previous runs.
|
|
886
|
+
# This prevents stale events from causing early returns.
|
|
887
|
+
while True:
|
|
888
|
+
try:
|
|
889
|
+
self._terminal_status_queue.get_nowait()
|
|
890
|
+
except Empty:
|
|
891
|
+
break
|
|
892
|
+
|
|
862
893
|
# Trigger a run on the server using the dedicated run endpoint.
|
|
863
894
|
# Let the server tell us if it's already running (409), avoiding an extra GET.
|
|
864
895
|
try:
|
|
@@ -886,10 +917,20 @@ class RemoteConversation(BaseConversation):
|
|
|
886
917
|
poll_interval: float = 1.0,
|
|
887
918
|
timeout: float = 1800.0,
|
|
888
919
|
) -> None:
|
|
889
|
-
"""
|
|
920
|
+
"""Wait for the conversation run to complete.
|
|
921
|
+
|
|
922
|
+
This method waits for the run to complete by listening for the terminal
|
|
923
|
+
status event via WebSocket. This ensures all events are delivered before
|
|
924
|
+
returning, avoiding the race condition where polling sees "finished"
|
|
925
|
+
status before WebSocket delivers the final events.
|
|
926
|
+
|
|
927
|
+
As a fallback, it also polls the server periodically. If the WebSocket
|
|
928
|
+
is delayed or disconnected, we return after multiple consecutive polls
|
|
929
|
+
show a terminal status, and reconcile events to catch any that were
|
|
930
|
+
missed via WebSocket.
|
|
890
931
|
|
|
891
932
|
Args:
|
|
892
|
-
poll_interval: Time in seconds between status polls.
|
|
933
|
+
poll_interval: Time in seconds between status polls (fallback).
|
|
893
934
|
timeout: Maximum time in seconds to wait.
|
|
894
935
|
|
|
895
936
|
Raises:
|
|
@@ -898,6 +939,14 @@ class RemoteConversation(BaseConversation):
|
|
|
898
939
|
responses are retried until timeout.
|
|
899
940
|
"""
|
|
900
941
|
start_time = time.monotonic()
|
|
942
|
+
consecutive_terminal_polls = 0
|
|
943
|
+
# Return after this many consecutive terminal polls (fallback for WS issues).
|
|
944
|
+
# We use 3 polls to balance latency vs reliability:
|
|
945
|
+
# - 1 poll could be a transient state during shutdown
|
|
946
|
+
# - 2 polls might still catch a race condition
|
|
947
|
+
# - 3 polls (with default 1s interval = 3s total) provides high confidence
|
|
948
|
+
# that the run is truly complete while keeping fallback latency reasonable
|
|
949
|
+
TERMINAL_POLL_THRESHOLD = 3
|
|
901
950
|
|
|
902
951
|
while True:
|
|
903
952
|
elapsed = time.monotonic() - start_time
|
|
@@ -910,20 +959,57 @@ class RemoteConversation(BaseConversation):
|
|
|
910
959
|
),
|
|
911
960
|
)
|
|
912
961
|
|
|
962
|
+
# Wait for either:
|
|
963
|
+
# 1. WebSocket delivers terminal status event (preferred)
|
|
964
|
+
# 2. Poll interval expires (fallback - check status via REST)
|
|
965
|
+
try:
|
|
966
|
+
ws_status = self._terminal_status_queue.get(timeout=poll_interval)
|
|
967
|
+
# Handle ERROR/STUCK states - raises ConversationRunError
|
|
968
|
+
self._handle_conversation_status(ws_status)
|
|
969
|
+
|
|
970
|
+
logger.info(
|
|
971
|
+
"Run completed via WebSocket notification "
|
|
972
|
+
"(status: %s, elapsed: %.1fs)",
|
|
973
|
+
ws_status,
|
|
974
|
+
elapsed,
|
|
975
|
+
)
|
|
976
|
+
return
|
|
977
|
+
except Empty:
|
|
978
|
+
pass # Queue.get() timed out, fall through to REST polling
|
|
979
|
+
|
|
980
|
+
# Poll the server for status as a health check and fallback.
|
|
981
|
+
# This catches ERROR/STUCK states that need immediate attention,
|
|
982
|
+
# and provides a fallback if WebSocket is delayed/disconnected.
|
|
913
983
|
try:
|
|
914
984
|
status = self._poll_status_once()
|
|
915
985
|
except Exception as exc:
|
|
916
986
|
self._handle_poll_exception(exc)
|
|
987
|
+
consecutive_terminal_polls = 0 # Reset on error
|
|
917
988
|
else:
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
989
|
+
# Raises ConversationRunError for ERROR/STUCK states
|
|
990
|
+
self._handle_conversation_status(status)
|
|
991
|
+
|
|
992
|
+
# Track consecutive terminal polls as a fallback for WS issues.
|
|
993
|
+
# If WebSocket is delayed/disconnected, we return after multiple
|
|
994
|
+
# consecutive polls confirm the terminal status.
|
|
995
|
+
if status and ConversationExecutionStatus(status).is_terminal():
|
|
996
|
+
consecutive_terminal_polls += 1
|
|
997
|
+
if consecutive_terminal_polls >= TERMINAL_POLL_THRESHOLD:
|
|
998
|
+
logger.info(
|
|
999
|
+
"Run completed via REST fallback after %d consecutive "
|
|
1000
|
+
"terminal polls (status: %s, elapsed: %.1fs). "
|
|
1001
|
+
"Reconciling events...",
|
|
1002
|
+
consecutive_terminal_polls,
|
|
1003
|
+
status,
|
|
1004
|
+
elapsed,
|
|
1005
|
+
)
|
|
1006
|
+
# Reconcile events to catch any that were missed via WS.
|
|
1007
|
+
# This is only called in the fallback path, so it doesn't
|
|
1008
|
+
# add overhead in the common case where WS works.
|
|
1009
|
+
self._state.events.reconcile()
|
|
1010
|
+
return
|
|
1011
|
+
else:
|
|
1012
|
+
consecutive_terminal_polls = 0
|
|
927
1013
|
|
|
928
1014
|
def _poll_status_once(self) -> str | None:
|
|
929
1015
|
"""Fetch the current execution status from the remote conversation."""
|
|
@@ -1113,6 +1199,28 @@ class RemoteConversation(BaseConversation):
|
|
|
1113
1199
|
"""
|
|
1114
1200
|
_send_request(self._client, "POST", f"/api/conversations/{self._id}/condense")
|
|
1115
1201
|
|
|
1202
|
+
def execute_tool(self, tool_name: str, action: "Action") -> "Observation":
    """Reject direct tool execution, which remote conversations do not offer.

    A remote conversation runs its tools on the server as part of the
    normal agent loop, so there is no client-side path for invoking a
    single tool. This method exists only to report that clearly.

    Args:
        tool_name: The name of the tool to execute (unused).
        action: The action to pass to the tool executor (unused).

    Raises:
        NotImplementedError: Always; the feature is not yet supported for
            remote conversations.
    """
    # Assemble the explanation once so the raise statement stays short.
    reason = (
        "execute_tool is not yet supported for RemoteConversation. "
        "Tool execution for remote conversations happens on the server side "
        "during the normal agent loop. Use LocalConversation for direct "
        "tool execution."
    )
    raise NotImplementedError(reason)
|
|
1223
|
+
|
|
1116
1224
|
def close(self) -> None:
|
|
1117
1225
|
"""Close the conversation and clean up resources.
|
|
1118
1226
|
|
|
@@ -1134,6 +1242,13 @@ class RemoteConversation(BaseConversation):
|
|
|
1134
1242
|
pass
|
|
1135
1243
|
|
|
1136
1244
|
self._end_observability_span()
|
|
1245
|
+
if self.delete_on_close:
|
|
1246
|
+
try:
|
|
1247
|
+
# trigger server-side delete_conversation to release resources
|
|
1248
|
+
# like tmux sessions
|
|
1249
|
+
_send_request(self._client, "DELETE", f"/api/conversations/{self.id}")
|
|
1250
|
+
except Exception:
|
|
1251
|
+
pass
|
|
1137
1252
|
|
|
1138
1253
|
def __del__(self) -> None:
|
|
1139
1254
|
try:
|
|
@@ -45,6 +45,25 @@ class ConversationExecutionStatus(str, Enum):
|
|
|
45
45
|
STUCK = "stuck" # Conversation is stuck in a loop or unable to proceed
|
|
46
46
|
DELETING = "deleting" # Conversation is in the process of being deleted
|
|
47
47
|
|
|
48
|
+
def is_terminal(self) -> bool:
    """Tell whether this status marks the end of a run.

    FINISHED, ERROR and STUCK are the terminal statuses: once one of them
    is reached the agent is no longer actively processing.

    IDLE is deliberately left out. It is the state of a conversation before
    any run has started, so counting it would produce false positives when
    the WebSocket delivers the initial state update during connection.

    Returns:
        True if this is a terminal status, False otherwise.
    """
    terminal_statuses = (
        ConversationExecutionStatus.FINISHED,
        ConversationExecutionStatus.ERROR,
        ConversationExecutionStatus.STUCK,
    )
    return self in terminal_statuses
|
|
66
|
+
|
|
48
67
|
|
|
49
68
|
class ConversationState(OpenHandsModel):
|
|
50
69
|
# ===== Public, validated fields =====
|
|
@@ -15,6 +15,12 @@ from openhands.sdk.logger import get_logger
|
|
|
15
15
|
logger = get_logger(__name__)
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
# Maximum recent events to scan for stuck detection.
|
|
19
|
+
# This window should be large enough to capture repetitive patterns
|
|
20
|
+
# (4 repeats × 2 events per cycle = 8 events minimum, plus buffer for user messages)
|
|
21
|
+
MAX_EVENTS_TO_SCAN_FOR_STUCK_DETECTION: int = 20
|
|
22
|
+
|
|
23
|
+
|
|
18
24
|
class StuckDetector:
|
|
19
25
|
"""Detects when an agent is stuck in repetitive or unproductive patterns.
|
|
20
26
|
|
|
@@ -54,8 +60,14 @@ class StuckDetector:
|
|
|
54
60
|
return self.thresholds.alternating_pattern
|
|
55
61
|
|
|
56
62
|
def is_stuck(self) -> bool:
|
|
57
|
-
"""Check if the agent is currently stuck.
|
|
58
|
-
|
|
63
|
+
"""Check if the agent is currently stuck.
|
|
64
|
+
|
|
65
|
+
Note: To avoid materializing potentially large file-backed event histories,
|
|
66
|
+
only the last MAX_EVENTS_TO_SCAN_FOR_STUCK_DETECTION events are analyzed.
|
|
67
|
+
If a user message exists within this window, only events after it are checked.
|
|
68
|
+
Otherwise, all events in the window are analyzed.
|
|
69
|
+
"""
|
|
70
|
+
events = list(self.state.events[-MAX_EVENTS_TO_SCAN_FOR_STUCK_DETECTION:])
|
|
59
71
|
|
|
60
72
|
# Only look at history after the last user message
|
|
61
73
|
last_user_msg_index = next(
|
|
@@ -66,11 +78,8 @@ class StuckDetector:
|
|
|
66
78
|
),
|
|
67
79
|
-1, # Default to -1 if no user message found
|
|
68
80
|
)
|
|
69
|
-
if last_user_msg_index
|
|
70
|
-
|
|
71
|
-
return False
|
|
72
|
-
|
|
73
|
-
events = events[last_user_msg_index + 1 :]
|
|
81
|
+
if last_user_msg_index != -1:
|
|
82
|
+
events = events[last_user_msg_index + 1 :]
|
|
74
83
|
|
|
75
84
|
# Determine minimum events needed
|
|
76
85
|
min_threshold = min(
|
|
@@ -253,10 +262,10 @@ class StuckDetector:
|
|
|
253
262
|
return False
|
|
254
263
|
|
|
255
264
|
def _is_stuck_context_window_error(self, _events: list[Event]) -> bool:
|
|
256
|
-
"""Detects if we
|
|
265
|
+
"""Detects if we are stuck in a loop of context window errors.
|
|
257
266
|
|
|
258
267
|
This happens when we repeatedly get context window errors and try to trim,
|
|
259
|
-
but the trimming
|
|
268
|
+
but the trimming does not work, causing us to get more context window errors.
|
|
260
269
|
The pattern is repeated AgentCondensationObservation events without any other
|
|
261
270
|
events between them.
|
|
262
271
|
"""
|
openhands/sdk/llm/__init__.py
CHANGED
|
@@ -1,3 +1,9 @@
|
|
|
1
|
+
from openhands.sdk.llm.auth import (
|
|
2
|
+
OPENAI_CODEX_MODELS,
|
|
3
|
+
CredentialStore,
|
|
4
|
+
OAuthCredentials,
|
|
5
|
+
OpenAISubscriptionAuth,
|
|
6
|
+
)
|
|
1
7
|
from openhands.sdk.llm.llm import LLM
|
|
2
8
|
from openhands.sdk.llm.llm_registry import LLMRegistry, RegistryEvent
|
|
3
9
|
from openhands.sdk.llm.llm_response import LLMResponse
|
|
@@ -22,11 +28,18 @@ from openhands.sdk.llm.utils.verified_models import VERIFIED_MODELS
|
|
|
22
28
|
|
|
23
29
|
|
|
24
30
|
__all__ = [
|
|
31
|
+
# Auth
|
|
32
|
+
"CredentialStore",
|
|
33
|
+
"OAuthCredentials",
|
|
34
|
+
"OpenAISubscriptionAuth",
|
|
35
|
+
"OPENAI_CODEX_MODELS",
|
|
36
|
+
# Core
|
|
25
37
|
"LLMResponse",
|
|
26
38
|
"LLM",
|
|
27
39
|
"LLMRegistry",
|
|
28
40
|
"RouterLLM",
|
|
29
41
|
"RegistryEvent",
|
|
42
|
+
# Messages
|
|
30
43
|
"Message",
|
|
31
44
|
"MessageToolCall",
|
|
32
45
|
"TextContent",
|
|
@@ -35,10 +48,13 @@ __all__ = [
|
|
|
35
48
|
"RedactedThinkingBlock",
|
|
36
49
|
"ReasoningItemModel",
|
|
37
50
|
"content_to_str",
|
|
51
|
+
# Streaming
|
|
38
52
|
"LLMStreamChunk",
|
|
39
53
|
"TokenCallbackType",
|
|
54
|
+
# Metrics
|
|
40
55
|
"Metrics",
|
|
41
56
|
"MetricsSnapshot",
|
|
57
|
+
# Models
|
|
42
58
|
"VERIFIED_MODELS",
|
|
43
59
|
"UNVERIFIED_MODELS_EXCLUDING_BEDROCK",
|
|
44
60
|
"get_unverified_models",
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Authentication module for LLM subscription-based access.
|
|
2
|
+
|
|
3
|
+
This module provides OAuth-based authentication for LLM providers that support
|
|
4
|
+
subscription-based access (e.g., ChatGPT Plus/Pro for OpenAI Codex models).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from openhands.sdk.llm.auth.credentials import (
|
|
8
|
+
CredentialStore,
|
|
9
|
+
OAuthCredentials,
|
|
10
|
+
)
|
|
11
|
+
from openhands.sdk.llm.auth.openai import (
|
|
12
|
+
OPENAI_CODEX_MODELS,
|
|
13
|
+
OpenAISubscriptionAuth,
|
|
14
|
+
SupportedVendor,
|
|
15
|
+
inject_system_prefix,
|
|
16
|
+
transform_for_subscription,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
__all__ = [
|
|
21
|
+
"CredentialStore",
|
|
22
|
+
"OAuthCredentials",
|
|
23
|
+
"OpenAISubscriptionAuth",
|
|
24
|
+
"OPENAI_CODEX_MODELS",
|
|
25
|
+
"SupportedVendor",
|
|
26
|
+
"inject_system_prefix",
|
|
27
|
+
"transform_for_subscription",
|
|
28
|
+
]
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
"""Credential storage and retrieval for OAuth-based LLM authentication."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
import time
|
|
8
|
+
import warnings
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Literal
|
|
11
|
+
|
|
12
|
+
from pydantic import BaseModel, Field
|
|
13
|
+
|
|
14
|
+
from openhands.sdk.logger import get_logger
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
logger = get_logger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_credentials_dir() -> Path:
    """Get the default directory for storing credentials.

    The location is fixed relative to the user's home directory; this
    function does not read XDG_DATA_HOME. The path is only computed here,
    it is not created.

    Returns:
        The path ``~/.openhands/auth`` with ``~`` resolved via ``Path.home()``.
    """
    return Path.home() / ".openhands" / "auth"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class OAuthCredentials(BaseModel):
    """OAuth token bundle used for subscription-based LLM access."""

    type: Literal["oauth"] = "oauth"
    vendor: str = Field(description="The vendor/provider (e.g., 'openai')")
    access_token: str = Field(description="The OAuth access token")
    refresh_token: str = Field(description="The OAuth refresh token")
    expires_at: int = Field(
        description="Unix timestamp (ms) when the access token expires"
    )

    def is_expired(self) -> bool:
        """Report whether the access token should be treated as expired.

        Returns:
            True when ``expires_at`` lies before the current time plus a
            60-second safety margin, False otherwise.
        """
        now_ms = int(time.time() * 1000)
        # Expire 60 seconds early so a token cannot lapse during a request.
        cutoff_ms = now_ms + 60_000
        return self.expires_at < cutoff_ms
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class CredentialStore:
    """Store and retrieve OAuth credentials for LLM providers.

    Credentials are kept as one JSON file per vendor, named
    ``<vendor>_oauth.json``, inside the credentials directory.
    """

    def __init__(self, credentials_dir: Path | None = None):
        """Initialize the credential store.

        Args:
            credentials_dir: Optional custom directory for storing credentials.
                Defaults to the directory returned by ``get_credentials_dir()``.
        """
        self._credentials_dir = credentials_dir or get_credentials_dir()
        logger.info("Using credentials directory: %s", self._credentials_dir)

    @property
    def credentials_dir(self) -> Path:
        """Get the credentials directory, creating it if necessary."""
        # Request owner-only access at creation time; the chmod below also
        # tightens a directory that already existed with wider permissions.
        self._credentials_dir.mkdir(mode=0o700, parents=True, exist_ok=True)
        # Set directory permissions to owner-only (rwx------)
        if os.name != "nt":
            self._credentials_dir.chmod(0o700)
        return self._credentials_dir

    def _get_credentials_file(self, vendor: str) -> Path:
        """Get the path to the credentials file for a vendor."""
        return self.credentials_dir / f"{vendor}_oauth.json"

    def get(self, vendor: str) -> OAuthCredentials | None:
        """Get stored credentials for a vendor.

        Args:
            vendor: The vendor/provider name (e.g., 'openai')

        Returns:
            OAuthCredentials if found and valid, None otherwise. A file that
            cannot be parsed or validated is removed before returning None.
        """
        creds_file = self._get_credentials_file(vendor)
        try:
            # Open directly instead of checking exists() first, so a file
            # removed between the check and the open cannot raise.
            with open(creds_file, encoding="utf-8") as f:
                data = json.load(f)
            return OAuthCredentials.model_validate(data)
        except FileNotFoundError:
            return None
        except ValueError:
            # Covers malformed JSON, undecodable bytes and schema validation
            # failures. Only the path is logged, never the file contents.
            logger.warning("Removing invalid credentials file: %s", creds_file)
            creds_file.unlink(missing_ok=True)
            return None

    def save(self, credentials: OAuthCredentials) -> None:
        """Save credentials for a vendor.

        The payload is serialized before the file is opened, so a
        serialization failure cannot truncate previously stored credentials.
        The file is created with owner-only permissions rather than being
        created with default permissions and restricted afterwards.

        Args:
            credentials: The OAuth credentials to save
        """
        creds_file = self._get_credentials_file(credentials.vendor)
        payload = json.dumps(credentials.model_dump(), indent=2)

        # Mode 0o600 applies when the file is newly created.
        fd = os.open(creds_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            # Set restrictive permissions (owner read/write only) before any
            # token is written; this also fixes up a pre-existing file.
            # Note: On Windows, NTFS ACLs should be used instead
            if os.name != "nt":  # Not Windows
                creds_file.chmod(0o600)
            else:
                warnings.warn(
                    "File permissions on Windows should be manually restricted",
                    stacklevel=2,
                )
            f.write(payload)

    def delete(self, vendor: str) -> bool:
        """Delete stored credentials for a vendor.

        Args:
            vendor: The vendor/provider name

        Returns:
            True if credentials were deleted, False if they didn't exist
        """
        creds_file = self._get_credentials_file(vendor)
        try:
            creds_file.unlink()
        except FileNotFoundError:
            return False
        return True

    def update_tokens(
        self,
        vendor: str,
        access_token: str,
        refresh_token: str | None,
        expires_in: int,
    ) -> OAuthCredentials | None:
        """Update tokens for an existing credential.

        Args:
            vendor: The vendor/provider name
            access_token: New access token
            refresh_token: New refresh token; when empty or None the stored
                refresh token is kept
            expires_in: Token expiry in seconds

        Returns:
            Updated credentials, or None if no existing credentials found
        """
        existing = self.get(vendor)
        if existing is None:
            return None

        updated = OAuthCredentials(
            vendor=vendor,
            access_token=access_token,
            refresh_token=refresh_token or existing.refresh_token,
            # expires_at is stored in milliseconds; expires_in is in seconds.
            expires_at=int(time.time() * 1000) + (expires_in * 1000),
        )
        self.save(updated)
        return updated
|