robosystems-client 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of robosystems-client might be problematic. Click here for more details.
- robosystems_client/__init__.py +15 -4
- robosystems_client/api/agent/auto_select_agent.py +25 -0
- robosystems_client/api/agent/batch_process_queries.py +25 -0
- robosystems_client/api/agent/execute_specific_agent.py +25 -0
- robosystems_client/api/agent/get_agent_metadata.py +25 -0
- robosystems_client/api/agent/list_agents.py +20 -0
- robosystems_client/api/agent/recommend_agent.py +25 -0
- robosystems_client/api/backup/create_backup.py +25 -0
- robosystems_client/api/backup/export_backup.py +25 -0
- robosystems_client/api/backup/get_backup_download_url.py +20 -0
- robosystems_client/api/backup/get_backup_stats.py +25 -0
- robosystems_client/api/backup/list_backups.py +20 -0
- robosystems_client/api/backup/restore_backup.py +25 -0
- robosystems_client/api/connections/create_connection.py +25 -0
- robosystems_client/api/connections/create_link_token.py +25 -0
- robosystems_client/api/connections/delete_connection.py +25 -0
- robosystems_client/api/connections/exchange_link_token.py +25 -0
- robosystems_client/api/connections/get_connection.py +25 -0
- robosystems_client/api/connections/get_connection_options.py +25 -0
- robosystems_client/api/connections/init_o_auth.py +25 -0
- robosystems_client/api/connections/list_connections.py +20 -0
- robosystems_client/api/connections/oauth_callback.py +25 -0
- robosystems_client/api/connections/sync_connection.py +25 -0
- robosystems_client/api/copy/copy_data_to_graph.py +25 -0
- robosystems_client/api/create/create_graph.py +25 -0
- robosystems_client/api/graph_analytics/get_graph_metrics.py +25 -0
- robosystems_client/api/graph_analytics/get_graph_usage_stats.py +20 -0
- robosystems_client/api/graph_billing/get_current_graph_bill.py +25 -0
- robosystems_client/api/graph_billing/get_graph_billing_history.py +20 -0
- robosystems_client/api/graph_billing/get_graph_monthly_bill.py +25 -0
- robosystems_client/api/graph_billing/get_graph_usage_details.py +20 -0
- robosystems_client/api/graph_credits/check_credit_balance.py +20 -0
- robosystems_client/api/graph_credits/check_storage_limits.py +25 -0
- robosystems_client/api/graph_credits/get_credit_summary.py +25 -0
- robosystems_client/api/graph_credits/get_storage_usage.py +20 -0
- robosystems_client/api/graph_credits/list_credit_transactions.py +20 -0
- robosystems_client/api/graph_health/get_database_health.py +25 -0
- robosystems_client/api/graph_info/get_database_info.py +25 -0
- robosystems_client/api/graph_limits/get_graph_limits.py +25 -0
- robosystems_client/api/mcp/call_mcp_tool.py +20 -0
- robosystems_client/api/mcp/list_mcp_tools.py +25 -0
- robosystems_client/api/operations/cancel_operation.py +25 -0
- robosystems_client/api/operations/get_operation_status.py +25 -0
- robosystems_client/api/operations/stream_operation_events.py +20 -0
- robosystems_client/api/query/execute_cypher_query.py +20 -0
- robosystems_client/api/schema/export_graph_schema.py +20 -0
- robosystems_client/api/schema/get_graph_schema_info.py +25 -0
- robosystems_client/api/schema/list_schema_extensions.py +25 -0
- robosystems_client/api/schema/validate_schema.py +25 -0
- robosystems_client/api/subgraphs/create_subgraph.py +25 -0
- robosystems_client/api/subgraphs/delete_subgraph.py +25 -0
- robosystems_client/api/subgraphs/get_subgraph_info.py +25 -0
- robosystems_client/api/subgraphs/get_subgraph_quota.py +25 -0
- robosystems_client/api/subgraphs/list_subgraphs.py +25 -0
- robosystems_client/api/user/create_user_api_key.py +25 -0
- robosystems_client/api/user/get_all_credit_summaries.py +25 -0
- robosystems_client/api/user/get_current_user.py +25 -0
- robosystems_client/api/user/get_user_graphs.py +25 -0
- robosystems_client/api/user/list_user_api_keys.py +25 -0
- robosystems_client/api/user/revoke_user_api_key.py +25 -0
- robosystems_client/api/user/select_user_graph.py +25 -0
- robosystems_client/api/user/update_user.py +25 -0
- robosystems_client/api/user/update_user_api_key.py +25 -0
- robosystems_client/api/user/update_user_password.py +25 -0
- robosystems_client/api/user_analytics/get_detailed_user_analytics.py +20 -0
- robosystems_client/api/user_analytics/get_user_usage_overview.py +25 -0
- robosystems_client/api/user_limits/get_all_shared_repository_limits.py +25 -0
- robosystems_client/api/user_limits/get_shared_repository_limits.py +25 -0
- robosystems_client/api/user_limits/get_user_limits.py +25 -0
- robosystems_client/api/user_limits/get_user_usage.py +25 -0
- robosystems_client/api/user_subscriptions/cancel_shared_repository_subscription.py +25 -0
- robosystems_client/api/user_subscriptions/get_repository_credits.py +25 -0
- robosystems_client/api/user_subscriptions/get_shared_repository_credits.py +25 -0
- robosystems_client/api/user_subscriptions/get_user_shared_subscriptions.py +20 -0
- robosystems_client/api/user_subscriptions/subscribe_to_shared_repository.py +25 -0
- robosystems_client/api/user_subscriptions/upgrade_shared_repository_subscription.py +25 -0
- robosystems_client/extensions/__init__.py +70 -0
- robosystems_client/extensions/auth_integration.py +14 -1
- robosystems_client/extensions/copy_client.py +32 -22
- robosystems_client/extensions/dataframe_utils.py +455 -0
- robosystems_client/extensions/extensions.py +16 -0
- robosystems_client/extensions/operation_client.py +43 -21
- robosystems_client/extensions/query_client.py +109 -12
- robosystems_client/extensions/tests/test_dataframe_utils.py +334 -0
- robosystems_client/extensions/tests/test_integration.py +1 -1
- robosystems_client/extensions/tests/test_token_utils.py +274 -0
- robosystems_client/extensions/token_utils.py +417 -0
- robosystems_client/extensions/utils.py +32 -2
- {robosystems_client-0.1.17.dist-info → robosystems_client-0.1.18.dist-info}/METADATA +1 -1
- {robosystems_client-0.1.17.dist-info → robosystems_client-0.1.18.dist-info}/RECORD +92 -88
- {robosystems_client-0.1.17.dist-info → robosystems_client-0.1.18.dist-info}/WHEEL +0 -0
- {robosystems_client-0.1.17.dist-info → robosystems_client-0.1.18.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
"""Tests for JWT token utilities"""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from datetime import datetime, timedelta
|
|
5
|
+
import json
|
|
6
|
+
import base64
|
|
7
|
+
import os
|
|
8
|
+
|
|
9
|
+
from robosystems_client.extensions.token_utils import (
|
|
10
|
+
validate_jwt_format,
|
|
11
|
+
extract_jwt_from_header,
|
|
12
|
+
decode_jwt_payload,
|
|
13
|
+
is_jwt_expired,
|
|
14
|
+
get_jwt_claims,
|
|
15
|
+
get_jwt_expiration,
|
|
16
|
+
extract_token_from_environment,
|
|
17
|
+
extract_token_from_cookie,
|
|
18
|
+
find_valid_token,
|
|
19
|
+
TokenManager,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def create_test_jwt(payload: dict = None, exp_delta_seconds: int = 3600) -> str:
    """Build a structurally valid (but unverifiable) JWT string for tests.

    The header and payload are JSON-serialised, base64url-encoded and stripped
    of trailing ``=`` padding; the signature segment is a fixed placeholder.

    Args:
        payload: Claims to embed. When omitted, a default claim set with
            ``sub``, ``user_id`` and an ``exp`` set ``exp_delta_seconds`` from
            now is used.
        exp_delta_seconds: Offset from the current time used for the default
            ``exp`` claim; ignored when ``payload`` is supplied.

    Returns:
        The token as ``"<header>.<payload>.<signature>"``.
    """

    def _segment(obj: dict) -> str:
        # Unpadded base64url encoding of the JSON form of ``obj``.
        raw = json.dumps(obj).encode()
        return base64.urlsafe_b64encode(raw).decode().rstrip("=")

    if payload is None:
        expires_at = datetime.now() + timedelta(seconds=exp_delta_seconds)
        payload = {
            "sub": "test_user",
            "user_id": "123",
            "exp": int(expires_at.timestamp()),
        }

    header = {"alg": "HS256", "typ": "JWT"}
    return ".".join((_segment(header), _segment(payload), "test_signature_123"))
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class TestJWTValidation:
    """Format checks performed by ``validate_jwt_format``."""

    def test_validate_jwt_format_valid(self):
        """A token produced by the test helper passes format validation."""
        assert validate_jwt_format(create_test_jwt()) is True

    def test_validate_jwt_format_invalid(self):
        """Malformed strings and non-string inputs are rejected."""
        rejected = (
            "header.payload",  # only two segments
            "not-a-jwt",  # no separators at all
            "",  # empty string
            None,  # missing value
            123,  # not a string
        )
        for candidate in rejected:
            assert validate_jwt_format(candidate) is False

    def test_validate_jwt_with_padding(self):
        """Segments whose length is not a multiple of four still validate."""
        needs_padding = (
            "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"
            ".eyJzdWIiOiIxMjM0NTY3ODkwIn0"
            ".signature"
        )
        assert validate_jwt_format(needs_padding) is True
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class TestJWTExtraction:
    """Bearer-token extraction from header strings and header mappings."""

    def test_extract_jwt_from_header_bearer(self):
        """The Bearer prefix is stripped in either case; trailing spaces are ignored."""
        jwt = create_test_jwt()
        header_values = (
            f"Bearer {jwt}",  # standard form
            f"bearer {jwt}",  # lower-case scheme
            f"Bearer {jwt} ",  # trailing whitespace
        )
        for value in header_values:
            assert extract_jwt_from_header(value) == jwt

    def test_extract_jwt_from_header_dict(self):
        """Both capitalisations of the Authorization key are honoured."""
        jwt = create_test_jwt()
        for key in ("Authorization", "authorization"):
            assert extract_jwt_from_header({key: f"Bearer {jwt}"}) == jwt

    def test_extract_jwt_from_header_invalid(self):
        """Missing, empty, or non-Bearer inputs yield None."""
        for value in (None, "", "NotBearer token", {"Other": "header"}):
            assert extract_jwt_from_header(value) is None
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class TestJWTDecoding:
    """Payload decoding and claim access."""

    def test_decode_jwt_payload(self):
        """Claims survive an encode/decode round trip unchanged."""
        claims_in = {"sub": "test_user", "user_id": "123", "roles": ["admin", "user"]}
        claims_out = decode_jwt_payload(create_test_jwt(claims_in))
        for name in ("sub", "user_id", "roles"):
            assert claims_out[name] == claims_in[name]

    def test_decode_jwt_payload_invalid(self):
        """Inputs that are not three-segment tokens decode to None."""
        for bad in ("invalid.token", ""):
            assert decode_jwt_payload(bad) is None

    def test_get_jwt_claims(self):
        """get_jwt_claims exposes every custom claim in the payload."""
        expected = {"claim1": "value1", "claim2": "value2"}
        claims = get_jwt_claims(create_test_jwt(expected))
        for name, value in expected.items():
            assert claims[name] == value
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class TestJWTExpiration:
    """Expiry detection and expiration-timestamp extraction."""

    def test_is_jwt_expired_not_expired(self):
        """A token valid for another hour is not expired."""
        one_hour_ahead = create_test_jwt(exp_delta_seconds=3600)
        assert is_jwt_expired(one_hour_ahead) is False

    def test_is_jwt_expired_expired(self):
        """A token whose exp lies an hour in the past is expired."""
        one_hour_ago = create_test_jwt(exp_delta_seconds=-3600)
        assert is_jwt_expired(one_hour_ago) is True

    def test_is_jwt_expired_with_buffer(self):
        """The buffer widens the window in which a token counts as expired."""
        soon = create_test_jwt(exp_delta_seconds=30)
        # ~30 s of validity left falls inside a 60 s buffer ...
        assert is_jwt_expired(soon, buffer_seconds=60) is True
        # ... but not inside a zero buffer.
        assert is_jwt_expired(soon, buffer_seconds=0) is False

    def test_get_jwt_expiration(self):
        """The exp claim is converted back into a datetime."""
        expected = datetime.now() + timedelta(hours=1)
        token = create_test_jwt({"exp": int(expected.timestamp())})

        actual = get_jwt_expiration(token)
        assert actual is not None
        # exp is truncated to whole seconds, so allow under one second of drift.
        assert abs((actual - expected).total_seconds()) < 1
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class TestTokenExtraction:
    """Test token extraction from various sources"""

    def test_extract_token_from_environment(self, monkeypatch):
        """Test extracting token from environment variable.

        Uses pytest's ``monkeypatch`` fixture so that any pre-existing
        ``ROBOSYSTEMS_TOKEN`` / ``CUSTOM_TOKEN`` values in the caller's
        environment are restored after the test rather than deleted.
        """
        token = create_test_jwt()

        # Unset variable yields None
        monkeypatch.delenv("ROBOSYSTEMS_TOKEN", raising=False)
        assert extract_token_from_environment() is None

        monkeypatch.setenv("ROBOSYSTEMS_TOKEN", token)
        assert extract_token_from_environment() == token

        # Custom env var
        monkeypatch.setenv("CUSTOM_TOKEN", token)
        assert extract_token_from_environment("CUSTOM_TOKEN") == token

    def test_extract_token_from_cookie(self):
        """Test extracting token from cookies"""
        token = create_test_jwt()
        cookies = {"auth-token": token}

        assert extract_token_from_cookie(cookies) == token
        # Custom cookie name
        cookies = {"session_token": token}
        assert extract_token_from_cookie(cookies, "session_token") == token
        # Missing cookie
        assert extract_token_from_cookie({}) is None

    def test_find_valid_token(self):
        """Test finding first valid token from multiple sources"""
        token = create_test_jwt()

        # None and a malformed string are skipped; the third argument is
        # the first valid token, so it wins over the later invalid one.
        result = find_valid_token(None, "invalid-token", token, "another-invalid")
        assert result == token

        # Found in headers dict
        headers = {"Authorization": f"Bearer {token}"}
        result = find_valid_token(None, headers)
        assert result == token

        # Not found
        result = find_valid_token(None, "", "invalid")
        assert result is None
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
class TestTokenManager:
    """Behaviour of the TokenManager wrapper."""

    def test_token_manager_basic(self):
        """A fresh, well-formed token is exposed, valid, and decodable."""
        jwt = create_test_jwt()
        mgr = TokenManager(jwt)

        assert mgr.token == jwt
        assert mgr.is_valid() is True
        assert mgr.get_claims()["sub"] == "test_user"

    def test_token_manager_refresh(self):
        """Reading ``token`` triggers the callback inside the refresh buffer."""
        expiring = create_test_jwt(exp_delta_seconds=30)
        replacement = create_test_jwt(exp_delta_seconds=3600)

        mgr = TokenManager(
            expiring,
            refresh_callback=lambda: replacement,
            auto_refresh=True,
            refresh_buffer=60,
        )

        # ~30 s of validity left is inside the 60 s buffer, so the getter refreshes.
        assert mgr.token == replacement

    def test_token_manager_invalid_token(self):
        """Assigning a malformed token raises and leaves the manager empty."""
        mgr = TokenManager()

        with pytest.raises(ValueError):
            mgr.token = "invalid-token"

        assert mgr.is_valid() is False
        assert mgr.get_claims() is None
        assert mgr.get_expiration() is None

    def test_token_manager_manual_refresh(self):
        """refresh() swaps in the callback's token when auto refresh is off."""
        initial = create_test_jwt()
        later = create_test_jwt(exp_delta_seconds=7200)

        mgr = TokenManager(initial, refresh_callback=lambda: later, auto_refresh=False)

        assert mgr.token == initial
        assert mgr.refresh() == later
        assert mgr.token == later
|
|
@@ -0,0 +1,417 @@
|
|
|
1
|
+
"""JWT Token validation and management utilities for RoboSystems SDK
|
|
2
|
+
|
|
3
|
+
Provides comprehensive JWT handling, validation, and extraction utilities.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import base64
import json
import logging
import math
import os
import time
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Callable, Dict, Optional, Union
|
|
13
|
+
|
|
14
|
+
# Module-level logger: used for debug output on payload decode failures and
# for info/error records around TokenManager refresh attempts.
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TokenSource(Enum):
    """Enumerates the kinds of places a JWT may be looked up from.

    Member values are the lower-case source names. Nothing in this module
    dispatches on these members; they are available for callers that want
    to record where a token came from.
    """

    HEADER = "header"  # an Authorization header
    COOKIE = "cookie"  # a cookie mapping
    ENVIRONMENT = "environment"  # a process environment variable
    CONFIG = "config"  # a client/config object
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def validate_jwt_format(token: Optional[str]) -> bool:
    """Validate JWT token format (basic validation without cryptographic verification).

    A token passes when all of the following hold:

    - It is a non-empty string of exactly three dot-separated segments.
    - The first two segments (header and payload) are non-empty.
    - Those two segments consist only of base64url characters and decode
      cleanly once the stripped ``=`` padding is restored.

    The third (signature) segment is not inspected, so unsecured tokens with
    an empty signature are accepted. Segment contents are not parsed as JSON.

    Args:
        token: JWT token string to validate. Non-string values are rejected.

    Returns:
        True if token appears to be valid JWT format, False otherwise.

    Example:
        >>> validate_jwt_format("eyJhbGc.eyJzdWI.SflKxwRJSM")
        True
        >>> validate_jwt_format("invalid-token")
        False
    """
    if not token or not isinstance(token, str):
        return False

    # JWT should have exactly 3 parts: header.payload.signature
    parts = token.split(".")
    if len(parts) != 3:
        return False

    for part in parts[:2]:  # Check header and payload only
        # An empty header/payload decodes to b"" and would otherwise pass.
        if not part:
            return False
        # JWT segments are unpadded; restore padding to a multiple of four.
        padded = part + "=" * (-len(part) % 4)
        try:
            # validate=True rejects characters outside the alphabet instead of
            # silently discarding them; altchars maps "-"/"_" (base64url).
            base64.b64decode(padded, altchars=b"-_", validate=True)
        except ValueError:  # binascii.Error is a subclass of ValueError
            return False

    return True
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def extract_jwt_from_header(
    auth_header: Optional[Union[str, Dict[str, str]]],
) -> Optional[str]:
    """Extract JWT token from Authorization header

    The auth scheme is matched case-insensitively (``Bearer``, ``bearer``,
    ``BEARER`` ...). When a dict of headers is given, ``"Authorization"`` and
    ``"authorization"`` are tried first, then any other spelling of that key.
    The extracted value is not format-validated.

    Args:
        auth_header: Authorization header value (e.g., "Bearer token123") or headers dict

    Returns:
        JWT token if found, None otherwise (including for non-Bearer schemes
        and an empty credential).

    Example:
        >>> extract_jwt_from_header("Bearer eyJhbGc.eyJzdWI.SflKxwRJSM")
        "eyJhbGc.eyJzdWI.SflKxwRJSM"
        >>> extract_jwt_from_header({"Authorization": "Bearer token123"})
        "token123"
    """
    if not auth_header:
        return None

    # Handle dict of headers
    if isinstance(auth_header, dict):
        auth_value = auth_header.get("Authorization") or auth_header.get("authorization")
        if not auth_value:
            # Header field names are case-insensitive, so fall back to a scan
            # for other spellings such as "AUTHORIZATION".
            auth_value = next(
                (
                    value
                    for key, value in auth_header.items()
                    if isinstance(key, str) and key.lower() == "authorization" and value
                ),
                None,
            )
        if not auth_value:
            return None
        auth_header = auth_value

    if not isinstance(auth_header, str):
        return None

    # Split "<scheme> <credentials>"; the scheme is case-insensitive.
    scheme, _, credentials = auth_header.strip().partition(" ")
    if scheme.lower() != "bearer":
        return None

    token = credentials.strip()
    return token or None
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def decode_jwt_payload(token: str, verify: bool = False) -> Optional[Dict[str, Any]]:
    """Decode JWT payload without verification (for reading claims only)

    Args:
        token: JWT token to decode
        verify: If True, will validate format first (default: False)

    Returns:
        Decoded payload as dictionary. Returns None if the token does not
        have three segments, if the payload is not valid base64url/UTF-8/JSON,
        or if the decoded JSON is not an object (e.g. a list or a string).

    Note:
        This does NOT verify the signature. Use only for reading non-sensitive claims.

    Example:
        >>> payload = decode_jwt_payload("eyJhbGc.eyJzdWI.SflKxwRJSM")
        >>> payload.get("sub")  # Get subject claim
    """
    if verify and not validate_jwt_format(token):
        return None

    try:
        # Split token and get payload (second part)
        parts = token.split(".")
        if len(parts) != 3:
            return None

        payload_part = parts[1]
        # JWT segments are unpadded base64url; restore padding for the decoder.
        payload_part += "=" * (-len(payload_part) % 4)

        payload_bytes = base64.urlsafe_b64decode(payload_part)
        payload = json.loads(payload_bytes.decode("utf-8"))
    except Exception as e:
        logger.debug("Failed to decode JWT payload: %s", e)
        return None

    # A JWT claims set must be a JSON object; any other JSON value would
    # break callers that treat the result as a dict (e.g. ``payload.get``).
    if not isinstance(payload, dict):
        return None

    return payload
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def is_jwt_expired(token: str, buffer_seconds: int = 60) -> bool:
    """Check if JWT token is expired based on exp claim

    ``exp`` is interpreted as a POSIX timestamp (seconds since the epoch) and
    compared with ``time.time()``.

    Args:
        token: JWT token to check
        buffer_seconds: Consider expired if expiring within this many seconds (default: 60)

    Returns:
        True if the token is expired or expires within ``buffer_seconds``.
        Also True when the payload cannot be decoded to a non-empty dict, or
        when ``exp`` is not a finite number.
        False when the payload has no ``exp`` claim.

    Example:
        >>> is_jwt_expired("eyJhbGc.eyJleHAiOjE2MzA0MjU2MDB9.SflKxwRJSM")
        True  # If current time is past exp claim
    """
    payload = decode_jwt_payload(token)
    if not isinstance(payload, dict) or not payload:
        return True

    exp = payload.get("exp")
    if exp is None:
        # No expiration claim, consider as non-expiring. (An ``exp`` of 0 is
        # a real, long-past expiry and must not take this branch.)
        return False

    # ``bool`` subclasses ``int`` but is never a meaningful timestamp.
    if isinstance(exp, bool) or not isinstance(exp, (int, float)):
        return True
    # json.loads accepts NaN/Infinity literals; treat them as unusable.
    if isinstance(exp, float) and not math.isfinite(exp):
        return True

    # Compare raw timestamps: naive local datetimes mis-order instants around
    # DST transitions, and far-future values would overflow fromtimestamp().
    return time.time() >= exp - buffer_seconds
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def get_jwt_claims(token: str) -> Optional[Dict[str, Any]]:
    """Return the claims carried in a JWT's payload segment.

    Delegates to ``decode_jwt_payload`` without format verification; the
    signature is never checked.

    Args:
        token: JWT token

    Returns:
        The decoded payload, or None if it cannot be decoded.

    Example:
        >>> claims = get_jwt_claims(token)
        >>> user_id = claims.get("user_id")
        >>> roles = claims.get("roles", [])
    """
    claims = decode_jwt_payload(token)
    return claims
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def get_jwt_expiration(token: str) -> Optional[datetime]:
    """Get expiration datetime from JWT token

    The ``exp`` claim is converted with ``datetime.fromtimestamp``, so the
    result is a naive datetime in the local timezone.

    Args:
        token: JWT token

    Returns:
        Expiration datetime. None if the payload cannot be decoded to a
        non-empty dict, has no ``exp`` claim, or ``exp`` cannot be converted.

    Example:
        >>> exp = get_jwt_expiration(token)
        >>> if exp and exp > datetime.now():
        ...     print(f"Token valid until {exp}")
    """
    payload = decode_jwt_payload(token)
    if not isinstance(payload, dict) or not payload:
        return None

    exp = payload.get("exp")
    # ``exp == 0`` is a legitimate (epoch) timestamp, so test for absence
    # explicitly rather than relying on truthiness.
    if exp is None:
        return None

    try:
        return datetime.fromtimestamp(exp)
    except (TypeError, ValueError, OverflowError, OSError):
        # Non-numeric, NaN, or out-of-range values cannot be represented.
        return None
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def extract_token_from_environment(env_var: str = "ROBOSYSTEMS_TOKEN") -> Optional[str]:
    """Read a JWT from an environment variable.

    The variable's value is returned only when it is set, non-empty, and
    passes ``validate_jwt_format``.

    Args:
        env_var: Environment variable name (default: ROBOSYSTEMS_TOKEN)

    Returns:
        JWT token if found and valid format, None otherwise

    Example:
        >>> os.environ["ROBOSYSTEMS_TOKEN"] = "eyJhbGc..."
        >>> token = extract_token_from_environment()
    """
    candidate = os.environ.get(env_var)
    if not candidate:
        return None
    return candidate if validate_jwt_format(candidate) else None
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def extract_token_from_cookie(
    cookies: Dict[str, str], cookie_name: str = "auth-token"
) -> Optional[str]:
    """Look up a JWT in a cookie mapping.

    The cookie value is returned only when present, non-empty, and accepted
    by ``validate_jwt_format``.

    Args:
        cookies: Dictionary of cookies
        cookie_name: Name of cookie containing token (default: auth-token)

    Returns:
        JWT token if found, None otherwise

    Example:
        >>> cookies = {"auth-token": "eyJhbGc..."}
        >>> token = extract_token_from_cookie(cookies)
    """
    value = cookies.get(cookie_name)
    if not value or not validate_jwt_format(value):
        return None
    return value
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def find_valid_token(*sources: Union[str, Dict[str, str], None]) -> Optional[str]:
    """Return the first format-valid JWT found among ``sources``, in order.

    Falsy sources are skipped. Strings are checked directly with
    ``validate_jwt_format``. A dict is first read as HTTP headers (via
    ``extract_jwt_from_header``). It is then read as cookies, trying the keys
    ``auth-token``, ``auth_token``, ``token`` and ``jwt`` in that order.
    Any other type is ignored.

    Args:
        *sources: Variable number of potential token sources
            (strings, dicts with Authorization header, etc.)

    Returns:
        First valid JWT token found, None if none found

    Example:
        >>> token = find_valid_token(
        ...     os.environ.get("TOKEN"),
        ...     headers,
        ...     cookies.get("auth-token"),
        ...     config.get("token")
        ... )
    """
    for candidate_source in sources:
        if not candidate_source:
            continue

        # Direct token string
        if isinstance(candidate_source, str):
            if validate_jwt_format(candidate_source):
                return candidate_source
            continue

        if not isinstance(candidate_source, dict):
            continue

        # Headers dict first ...
        header_token = extract_jwt_from_header(candidate_source)
        if header_token and validate_jwt_format(header_token):
            return header_token

        # ... then the same dict interpreted as cookies.
        for cookie_key in ("auth-token", "auth_token", "token", "jwt"):
            cookie_token = candidate_source.get(cookie_key)
            if cookie_token and validate_jwt_format(cookie_token):
                return cookie_token

    return None
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
class TokenManager:
    """Manages JWT tokens with automatic refresh and validation.

    Reading :attr:`token` calls the refresh callback when all of these hold:

    - auto refresh is enabled;
    - a token and a callback are set;
    - ``is_jwt_expired`` reports the current token as expiring within
      ``refresh_buffer`` seconds (or its payload cannot be decoded).
    """

    def __init__(
        self,
        token: Optional[str] = None,
        refresh_callback: Optional[Callable[[], Optional[str]]] = None,
        auto_refresh: bool = True,
        refresh_buffer: int = 300,
    ):
        """Initialize token manager

        Args:
            token: Initial JWT token. Unlike assignments through the ``token``
                setter, this value is stored without format validation.
            refresh_callback: Zero-argument callable returning a new token;
                a falsy or malformed return value is ignored.
            auto_refresh: Automatically refresh before expiration
            refresh_buffer: Seconds before expiration to trigger refresh (default: 300)
        """
        self._token = token
        self._refresh_callback = refresh_callback
        self._auto_refresh = auto_refresh
        self._refresh_buffer = refresh_buffer

    @property
    def token(self) -> Optional[str]:
        """Get current token, refreshing first if it is about to expire."""
        if (
            self._auto_refresh
            and self._token
            and self._refresh_callback
            and is_jwt_expired(self._token, self._refresh_buffer)
        ):
            self.refresh()
        return self._token

    @token.setter
    def token(self, value: Optional[str]) -> None:
        """Set new token.

        Raises:
            ValueError: If ``value`` is non-empty and not in JWT format.
        """
        if value and not validate_jwt_format(value):
            raise ValueError("Invalid JWT token format")
        self._token = value

    def refresh(self) -> Optional[str]:
        """Refresh token using callback.

        Returns:
            The new token when the callback returns a format-valid token,
            otherwise None (the current token is then kept).

        Raises:
            RuntimeError: If no refresh callback is configured.
        """
        if not self._refresh_callback:
            raise RuntimeError("No refresh callback configured")

        try:
            new_token = self._refresh_callback()
        except Exception as e:
            # Best effort: a failing callback must not break token reads, so
            # the error is logged and the current token is kept.
            logger.error("Token refresh failed: %s", e)
            return None

        if new_token and validate_jwt_format(new_token):
            self._token = new_token
            logger.info("Token refreshed successfully")
            return new_token

        # Previously this case returned None without any trace.
        logger.warning(
            "Token refresh callback returned an empty or malformed token; "
            "keeping current token"
        )
        return None

    def is_valid(self) -> bool:
        """Check if current token is set, well-formed and not yet expired.

        Uses a zero-second buffer and reads the stored token directly, so no
        refresh is attempted.
        """
        return bool(
            self._token
            and validate_jwt_format(self._token)
            and not is_jwt_expired(self._token, 0)
        )

    def get_claims(self) -> Optional[Dict[str, Any]]:
        """Get claims from current token (None when no token is set)."""
        if self._token:
            return get_jwt_claims(self._token)
        return None

    def get_expiration(self) -> Optional[datetime]:
        """Get expiration time of current token (None when no token is set)."""
        if self._token:
            return get_jwt_expiration(self._token)
        return None
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
# Convenience function for quick token extraction from client config
def extract_token_from_client(client) -> Optional[str]:
    """Extract JWT token from RoboSystems client configuration

    Sources are tried in order; the first non-empty one wins:

    1. ``client.token``
    2. The Authorization header in ``client._headers``
    3. ``client.config["token"]`` (when ``config`` is a dict)
    4. The Authorization header in ``client.config["headers"]``

    An attribute that exists but is None/empty no longer stops the search.

    Args:
        client: RoboSystems client instance

    Returns:
        JWT token if found, None otherwise. Values are not format-validated.
    """
    # Try to get from authenticated client
    token = getattr(client, "token", None)
    if token:
        return token

    # Try from headers
    headers = getattr(client, "_headers", None)
    if headers:
        token = extract_jwt_from_header(headers)
        if token:
            return token

    # Try from config
    config = getattr(client, "config", None)
    if isinstance(config, dict):
        # Direct token
        if config.get("token"):
            return config["token"]
        # From headers in config
        if config.get("headers"):
            return extract_jwt_from_header(config["headers"])

    return None
|