antioch-py 2.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of antioch-py might be problematic; see the package registry's advisory page for more details.
- antioch/__init__.py +0 -0
- antioch/message.py +87 -0
- antioch/module/__init__.py +53 -0
- antioch/module/clock.py +62 -0
- antioch/module/execution.py +278 -0
- antioch/module/input.py +127 -0
- antioch/module/module.py +218 -0
- antioch/module/node.py +357 -0
- antioch/module/token.py +42 -0
- antioch/session/__init__.py +150 -0
- antioch/session/ark.py +504 -0
- antioch/session/asset.py +65 -0
- antioch/session/error.py +80 -0
- antioch/session/record.py +158 -0
- antioch/session/scene.py +1521 -0
- antioch/session/session.py +220 -0
- antioch/session/task.py +323 -0
- antioch/session/views/__init__.py +40 -0
- antioch/session/views/animation.py +189 -0
- antioch/session/views/articulation.py +245 -0
- antioch/session/views/basis_curve.py +186 -0
- antioch/session/views/camera.py +92 -0
- antioch/session/views/collision.py +75 -0
- antioch/session/views/geometry.py +74 -0
- antioch/session/views/ground_plane.py +63 -0
- antioch/session/views/imu.py +73 -0
- antioch/session/views/joint.py +64 -0
- antioch/session/views/light.py +175 -0
- antioch/session/views/pir_sensor.py +140 -0
- antioch/session/views/radar.py +73 -0
- antioch/session/views/rigid_body.py +282 -0
- antioch/session/views/xform.py +119 -0
- antioch_py-2.0.6.dist-info/METADATA +115 -0
- antioch_py-2.0.6.dist-info/RECORD +99 -0
- antioch_py-2.0.6.dist-info/WHEEL +5 -0
- antioch_py-2.0.6.dist-info/entry_points.txt +2 -0
- antioch_py-2.0.6.dist-info/top_level.txt +2 -0
- common/__init__.py +0 -0
- common/ark/__init__.py +60 -0
- common/ark/ark.py +128 -0
- common/ark/hardware.py +121 -0
- common/ark/kinematics.py +31 -0
- common/ark/module.py +85 -0
- common/ark/node.py +94 -0
- common/ark/scheduler.py +439 -0
- common/ark/sim.py +33 -0
- common/assets/__init__.py +3 -0
- common/constants.py +47 -0
- common/core/__init__.py +52 -0
- common/core/agent.py +296 -0
- common/core/auth.py +305 -0
- common/core/registry.py +331 -0
- common/core/task.py +36 -0
- common/message/__init__.py +59 -0
- common/message/annotation.py +89 -0
- common/message/array.py +500 -0
- common/message/base.py +517 -0
- common/message/camera.py +91 -0
- common/message/color.py +139 -0
- common/message/frame.py +50 -0
- common/message/image.py +171 -0
- common/message/imu.py +14 -0
- common/message/joint.py +47 -0
- common/message/log.py +31 -0
- common/message/pir.py +16 -0
- common/message/point.py +109 -0
- common/message/point_cloud.py +63 -0
- common/message/pose.py +148 -0
- common/message/quaternion.py +273 -0
- common/message/radar.py +58 -0
- common/message/types.py +37 -0
- common/message/vector.py +786 -0
- common/rome/__init__.py +9 -0
- common/rome/client.py +430 -0
- common/rome/error.py +16 -0
- common/session/__init__.py +54 -0
- common/session/environment.py +31 -0
- common/session/sim.py +240 -0
- common/session/views/__init__.py +263 -0
- common/session/views/animation.py +73 -0
- common/session/views/articulation.py +184 -0
- common/session/views/basis_curve.py +102 -0
- common/session/views/camera.py +147 -0
- common/session/views/collision.py +59 -0
- common/session/views/geometry.py +102 -0
- common/session/views/ground_plane.py +41 -0
- common/session/views/imu.py +66 -0
- common/session/views/joint.py +81 -0
- common/session/views/light.py +96 -0
- common/session/views/pir_sensor.py +115 -0
- common/session/views/radar.py +82 -0
- common/session/views/rigid_body.py +236 -0
- common/session/views/viewport.py +21 -0
- common/session/views/xform.py +39 -0
- common/utils/__init__.py +4 -0
- common/utils/comms.py +571 -0
- common/utils/logger.py +123 -0
- common/utils/time.py +42 -0
- common/utils/usd.py +12 -0
common/core/agent.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import Literal, TypeVar
|
|
3
|
+
|
|
4
|
+
from common.ark import Ark as ArkDefinition, Environment
|
|
5
|
+
from common.message import Message
|
|
6
|
+
from common.utils.comms import CommsSession
|
|
7
|
+
|
|
8
|
+
# Response message type for Agent._query_agent; restricted to Message subclasses.
T = TypeVar("T", bound=Message)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ContainerSource(str, Enum):
    """
    Where module container images are obtained from.

    Members are also ``str`` instances, so they compare equal to their values
    (e.g. ``ContainerSource.LOCAL == "Local"``).
    """

    # Use images available locally
    LOCAL = "Local"
    # Use images from a remote source
    REMOTE = "Remote"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ContainerState(Message):
    """
    State of a container with metadata.

    Appears as the element type of ``ArkStateResponse.containers``.
    """

    # Type tag for this message schema (presumably consumed by the Message base — confirm)
    _type = "antioch/agent/container_state"

    # Name of the module whose container this entry describes
    module_name: str
    # Whether that container is currently running
    running: bool
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class StartArkRequest(Message):
    """
    Request to start an Ark.

    Sent by ``Agent.start_ark`` to ``_agent/start_ark``.
    """

    # Type tag for this message schema (presumably consumed by the Message base — confirm)
    _type = "antioch/agent/start_ark_request"

    # Ark definition to launch
    ark: ArkDefinition
    # Container image source (local or remote)
    source: ContainerSource
    # Environment to run in (sim or real)
    environment: Environment
    # Enable debug mode
    debug: bool
    # Seconds the agent allows for modules to become ready
    timeout: float
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class StopArkRequest(Message):
    """
    Request to stop an Ark.

    Sent by ``Agent.stop_ark`` to ``_agent/stop_ark``.
    """

    # Type tag for this message schema (presumably consumed by the Message base — confirm)
    _type = "antioch/agent/stop_ark_request"

    # Seconds the agent allows for stopping containers
    timeout: float
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class RecordTelemetryRequest(Message):
    """
    Request to start recording telemetry.

    Sent by ``Agent.record_telemetry`` to ``_agent/record_telemetry``.
    """

    # Type tag for this message schema (presumably consumed by the Message base — confirm)
    _type = "antioch/agent/record_telemetry_request"

    # Optional path where the MCAP file will be saved
    mcap_path: str | None = None
    # Never set by Agent.record_telemetry; presumably selects the WebSocket port
    # on the agent side — TODO confirm against the agent implementation.
    websocket_port: int | None = None
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class AgentResponse(Message):
    """
    Generic response for agent operations.

    ``Agent`` raises ``AgentError`` carrying ``error`` when ``success`` is False.
    """

    # Type tag for this message schema (presumably consumed by the Message base — confirm)
    _type = "antioch/agent/response"

    # Whether the requested operation succeeded
    success: bool
    # Error description reported by the agent when success is False
    error: str | None = None
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class AgentStateResponse(Message):
    """
    Agent state response.

    Returned by ``_agent/get_state``; ``Agent.connected`` uses it as a reachability probe.
    """

    # Type tag for this message schema (presumably consumed by the Message base — confirm)
    _type = "antioch/agent/state_response"

    # Presumably whether the agent itself is running — confirm with agent
    running: bool
    # Presumably whether an Ark is currently active on the agent — confirm with agent
    ark_active: bool
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class ArkStateResponse(Message):
    """
    Ark state response.

    Returned by ``Agent.get_ark_state`` (query path ``_agent/ark_state``).
    All fields other than ``state`` are optional.
    """

    # Type tag for this message schema (presumably consumed by the Message base — confirm)
    _type = "antioch/agent/ark_state_response"

    # Whether an Ark is started or stopped
    state: Literal["started", "stopped"]
    # Name of the running Ark, if any
    ark_name: str | None = None
    # NOTE(review): plain string literals here, whereas StartArkRequest uses the
    # Environment enum — confirm the two stay in sync.
    environment: Literal["sim", "real"] | None = None
    # Whether the Ark was started in debug mode
    debug: bool | None = None
    # Start time of the Ark; the "_us" suffix suggests microseconds — confirm units/epoch
    global_start_time_us: int | None = None
    # Per-module container states
    containers: list[ContainerState] | None = None
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class AgentError(Exception):
    """
    Raised when the agent reports that an operation (start, stop, telemetry) failed.
    """
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class AgentValidationError(Exception):
    """
    Raised when an agent-related value fails validation.

    Independent of ``AgentError``: it derives directly from ``Exception``.
    """
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class Agent:
    """
    Client for interacting with the agent that manages Ark containers.

    The agent is a long-lived container that can start, stop, and manage Arks.
    This class provides a simple interface for all agent operations and works
    across all environments (sim/real, local/remote).

    Every operation is sent as a query on a path under ``_agent/`` through
    :meth:`_query_agent`, which delegates to ``CommsSession.query``.

    Example:
        agent = Agent()
        agent.start_ark(ark_def, source=ContainerSource.LOCAL)
        state = agent.get_ark_state()
        agent.stop_ark()
    """

    def __init__(self):
        """
        Initialize the agent client.

        Creates the ``CommsSession`` used for all agent queries.
        """

        self.comms = CommsSession()

    @property
    def connected(self) -> bool:
        """
        Check if the agent is reachable.

        Sends a ``_agent/get_state`` query with a 1 second timeout. Any failure
        is deliberately reported as "not connected" instead of being raised.

        :return: True if connected, False otherwise.
        """

        try:
            self._query_agent(
                path="_agent/get_state",
                response_type=AgentStateResponse,
                timeout=1.0,
            )
        except Exception:
            # Best-effort probe: any error (timeout, transport, decoding) means unreachable
            return False
        return True

    def start_ark(
        self,
        ark: ArkDefinition,
        source: ContainerSource = ContainerSource.LOCAL,
        environment: Environment = Environment.SIM,
        debug: bool = False,
        timeout: float = 30.0,
    ) -> None:
        """
        Start an Ark on the agent by launching all module containers.

        This operation is idempotent. If an Ark is already running, it will be
        gracefully stopped before starting the new one.

        :param ark: Ark definition to start.
        :param source: Container image source (local or remote).
        :param environment: Environment to run in (sim or real).
        :param debug: Enable debug mode.
        :param timeout: Timeout in seconds for modules to become ready (default: 30.0).
        :raises AgentError: If the agent fails to start the Ark.
        """

        response = self._query_agent(
            path="_agent/start_ark",
            response_type=AgentResponse,
            request=StartArkRequest(
                ark=ark,
                source=source,
                environment=environment,
                debug=debug,
                timeout=timeout,
            ),
            # Query timeout exceeds the agent-side timeout by 10 s, presumably so the
            # agent can report its own timeout before the query itself expires.
            timeout=timeout + 10.0,
        )

        if not response.success:
            raise AgentError(f"Failed to start Ark: {response.error}")

    def stop_ark(
        self,
        timeout: float = 30.0,
    ) -> None:
        """
        Stop the currently running Ark on the agent.

        Removes all module containers. The agent continues running and can
        accept requests to start a new Ark.

        :param timeout: Timeout in seconds for stopping containers (default: 30.0).
        :raises AgentError: If the agent fails to stop the Ark.
        """

        response = self._query_agent(
            path="_agent/stop_ark",
            response_type=AgentResponse,
            request=StopArkRequest(timeout=timeout),
            # Same 10 s headroom over the agent-side timeout as start_ark
            timeout=timeout + 10.0,
        )

        if not response.success:
            raise AgentError(f"Failed to stop Ark: {response.error}")

    def get_ark_state(self) -> ArkStateResponse:
        """
        Get the current state of the Ark running on the agent.

        Returns the current state including all container statuses.

        :return: Current Ark state with container information.
        """

        return self._query_agent(
            path="_agent/ark_state",
            response_type=ArkStateResponse,
            timeout=10.0,
        )

    def record_telemetry(self, mcap_path: str | None = None) -> None:
        """
        Start recording telemetry to an MCAP file.

        Creates an MCAP writer at the specified path. The WebSocket server (port 8765)
        and subscriber task are always active, streaming telemetry continuously.
        If already recording, finalizes the current recording before starting a new one.

        :param mcap_path: Optional path where the MCAP file will be saved.
        :raises AgentError: If the agent fails to start recording telemetry.
        """

        response = self._query_agent(
            path="_agent/record_telemetry",
            response_type=AgentResponse,
            request=RecordTelemetryRequest(mcap_path=mcap_path),
            timeout=5.0,
        )

        if not response.success:
            raise AgentError(f"Failed to start recording telemetry: {response.error}")

    def reset_telemetry(self) -> None:
        """
        Reset telemetry session completely.

        Finalizes any active MCAP recording, resets time tracking, and clears the websocket
        session causing all clients to reset their state. This is useful when clearing the
        scene and starting a new Ark to ensure LET times start from 0 again.

        :raises AgentError: If the agent fails to reset telemetry.
        """

        response = self._query_agent(
            path="_agent/reset_telemetry",
            response_type=AgentResponse,
            timeout=5.0,
        )

        if not response.success:
            raise AgentError(f"Failed to reset telemetry: {response.error}")

    def _query_agent(
        self,
        path: str,
        response_type: type[T],
        request: Message | None = None,
        timeout: float = 10.0,
    ) -> T:
        """
        Execute an agent query.

        Single entry point for all agent traffic; forwards to ``self.comms.query``.

        :param path: The agent query path.
        :param response_type: Expected response type.
        :param request: Optional request message (None for requests without a payload).
        :param timeout: Query timeout in seconds.
        :return: The response message.
        """

        return self.comms.query(
            path=path,
            response_type=response_type,
            request=request,
            timeout=timeout,
        )
|
common/core/auth.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import requests
|
|
8
|
+
from pydantic import BaseModel
|
|
9
|
+
|
|
10
|
+
from common.constants import get_auth_dir
|
|
11
|
+
|
|
12
|
+
# Authentication routes
# The auth host can be overridden via the AUTH_DOMAIN environment variable.
# NOTE(review): the fallback is the *staging* auth host — confirm this default is
# intended for released builds.
AUTH_DOMAIN = os.environ.get("AUTH_DOMAIN", "https://staging.auth.antioch.com")
AUTH_TOKEN_URL = f"{AUTH_DOMAIN}/oauth/token"
DEVICE_CODE_URL = f"{AUTH_DOMAIN}/oauth/device/code"

# Authentication constants
AUTH_CLIENT_ID = "x0aOquV43Xe76ehqAm6Zir80O0MWpqTV"
# Not referenced elsewhere in this module; JWT signatures are not verified here.
ALGORITHMS = ["RS256"]
# Audience and scopes sent with the device code request
AUDIENCE = "https://sessions.antioch.com"
AUTH_SCOPE = "openid profile email"
# OAuth 2.0 device authorization grant type (RFC 8628)
AUTH_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:device_code"
# Total seconds AuthHandler.login waits for the user to finish the device flow
AUTH_TIMEOUT_SECONDS = 120

# Authentication claims
# Namespaced JWT claim names read when decoding the access token
AUTH_ORG_ID_CLAIM = "https://antioch.com/org_id"
AUTH_ORG_NAME_CLAIM = "https://antioch.com/org_name"
AUTH_ORGANIZATIONS_CLAIM = "https://antioch.com/organizations"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class AuthError(Exception):
    """
    Raised when logging in, loading, or validating credentials fails.
    """
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Organization(BaseModel):
    """
    Organization information.

    Built from token claims by ``AuthHandler`` and persisted to disk via
    ``model_dump()``.
    """

    # Organization identifier (from the org ID claim or an organizations-list entry)
    org_id: str
    # Organization name (from the org name claim or an organizations-list entry)
    org_name: str
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class AuthHandler:
    """
    Client for handling authentication.

    Auth is used for:
    - Pulling artifacts from the remote Ark registry
    - Pulling assets from the remote asset registry

    Login uses the OAuth 2.0 device authorization grant (RFC 8628). The access
    token and the organization data decoded from it are written to
    ``token.json`` in the auth directory and reloaded when a handler is created.
    """

    # Polling interval (seconds) used when the device code response omits
    # "interval" (RFC 8628, section 3.2).
    _DEFAULT_POLL_INTERVAL_S = 5

    # Seconds added to the polling interval each time the token endpoint answers
    # "slow_down" (RFC 8628, section 3.5).
    _SLOW_DOWN_INCREMENT_S = 5

    # Per-request HTTP timeout so a stalled auth server cannot hang login forever.
    _HTTP_TIMEOUT_S = 30

    def __init__(self):
        """
        Initialize the auth handler.

        Starts unauthenticated, then tries to restore a previously saved token
        from disk.
        """

        self._token: str | None = None
        self._user_id: str | None = None
        self._current_org: Organization | None = None
        self._available_orgs: list[Organization] = []
        self._load_local_token()

    def login(self) -> None:
        """
        Authenticate the user via device code flow.

        Does nothing (beyond printing a message) if already authenticated.
        On success the token is validated, its claims are loaded into this
        handler, and everything is persisted with ``save_token``.

        :raises AuthError: If authentication fails, is denied, times out, or the
            received token is invalid or missing required claims.
        """

        if self.is_authenticated():
            print("Already authenticated")
            return

        device_code_data = self._request_device_code()

        # "verification_uri_complete" is optional in RFC 8628; fall back to the
        # required "verification_uri" (the user then enters the code manually).
        verification_uri = device_code_data.get("verification_uri_complete") or device_code_data["verification_uri"]
        print(f"You have {AUTH_TIMEOUT_SECONDS} seconds to complete the following:")
        print(f" 1. Navigate to: {verification_uri}")
        print(f" 2. Enter the code: {device_code_data['user_code']}")

        token = self._poll_for_token(device_code_data)
        print("Authenticated!")

        # Validate before storing so a rejected token never becomes this handler's token
        self._validate_token_claims(token)
        self._token = token
        self.save_token()

    def _request_device_code(self) -> dict:
        """
        Request a device code from the authorization server.

        :return: The decoded device code response.
        :raises AuthError: If the server rejects the request or returns no JSON object.
        """

        device_code_payload = {
            "client_id": AUTH_CLIENT_ID,
            "scope": AUTH_SCOPE,
            "audience": AUDIENCE,
        }

        response = requests.post(DEVICE_CODE_URL, data=device_code_payload, timeout=self._HTTP_TIMEOUT_S)
        device_code_data = self._json_or_empty(response)
        if response.status_code != 200 or not device_code_data:
            raise AuthError("Error generating the device code") from Exception(device_code_data or response.text)
        return device_code_data

    def _poll_for_token(self, device_code_data: dict) -> str:
        """
        Poll the token endpoint until the user completes or rejects the device flow.

        Honors the server-provided polling interval and backs off on "slow_down"
        as required by RFC 8628.

        :param device_code_data: Response from the device code endpoint.
        :return: The access token.
        :raises AuthError: If the flow fails, times out, or no token is returned.
        """

        token_payload = {
            "grant_type": AUTH_GRANT_TYPE,
            "device_code": device_code_data["device_code"],
            "client_id": AUTH_CLIENT_ID,
        }
        interval = device_code_data.get("interval", self._DEFAULT_POLL_INTERVAL_S)
        start_time = time.time()

        while True:
            response = requests.post(AUTH_TOKEN_URL, data=token_payload, timeout=self._HTTP_TIMEOUT_S)
            token_data = self._json_or_empty(response)

            if response.status_code == 200:
                token = token_data.get("access_token")
                if not token:
                    raise AuthError("No token received")
                return token

            # Only these two errors mean "keep polling"; anything else is terminal
            error = token_data.get("error")
            if error not in ("authorization_pending", "slow_down"):
                print(token_data.get("error_description") or error or f"HTTP {response.status_code}")
                raise AuthError("Error authenticating the user") from Exception(token_data or response.text)

            if time.time() - start_time > AUTH_TIMEOUT_SECONDS:
                raise AuthError("Timeout waiting for authentication")

            if error == "slow_down":
                interval += self._SLOW_DOWN_INCREMENT_S
            time.sleep(interval)

    @staticmethod
    def _json_or_empty(response: "requests.Response") -> dict:
        """
        Decode a JSON object response body.

        :param response: HTTP response to decode.
        :return: The decoded object, or an empty dict if the body is not a JSON object
            (e.g. an HTML error page from a proxy).
        """

        try:
            data = response.json()
        except ValueError:
            return {}
        return data if isinstance(data, dict) else {}

    def is_authenticated(self) -> bool:
        """
        Check if the user is authenticated.

        :return: True if authenticated, False otherwise.
        """

        return self._current_org is not None

    def select_organization(self, org_id: str):
        """
        Select the organization to use for the session.

        :param org_id: The ID of the organization to select.
        :raises AuthError: If the user is not authenticated, or the organization is
            not among the available organizations.
        """

        if not self.is_authenticated():
            raise AuthError("Not authenticated. Please login first")

        for org in self._available_orgs:
            if org.org_id == org_id:
                self._current_org = org
                return

        raise AuthError(f"Organization '{org_id}' is not in your available organizations")

    def get_current_org(self) -> Organization | None:
        """
        Get the current organization.

        :return: The current organization.
        :raises AuthError: If the user is not authenticated.
        """

        if not self.is_authenticated():
            raise AuthError("Not authenticated. Please login first")

        return self._current_org

    def get_user_id(self) -> str | None:
        """
        Get the user ID.

        :return: The user ID (the token's "sub" claim), or None if not loaded.
        """

        return self._user_id

    def get_available_orgs(self) -> list[Organization]:
        """
        Get the available organizations.

        :return: The available organizations (empty if not loaded).
        """

        return self._available_orgs

    def get_token(self) -> str | None:
        """
        Get the token.

        :return: The token.
        :raises AuthError: If the user is not authenticated.
        """

        if not self.is_authenticated():
            raise AuthError("Not authenticated. Please login first")
        return self._token

    def save_token(self) -> None:
        """
        Save the authentication token and organization data to disk.

        The file is written owner read/write only (0o600).

        :raises AuthError: If not authenticated.
        """

        if not self.is_authenticated():
            raise AuthError("Not authenticated. Please login first")

        stored_data = {
            "token": self._token,
            "current_org": self._current_org.model_dump() if self._current_org else None,
            "available_orgs": [org.model_dump() for org in self._available_orgs],
        }

        token_path = self._get_token_path()

        # The mode given to os.open only applies when the file is newly created, so a
        # pre-existing token file would keep its old (possibly world-readable)
        # permissions. Tighten them before the token is written.
        fd = os.open(token_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w") as f:
            try:
                os.chmod(token_path, 0o600)
            except OSError:
                # Best-effort: some filesystems do not support POSIX permissions
                pass
            json.dump(stored_data, f, indent=2)

    def clear_token(self) -> None:
        """
        Clear the stored authentication token from disk.

        In-memory state is left untouched. A missing file is not an error.
        """

        self._get_token_path().unlink(missing_ok=True)

    def _load_local_token(self) -> None:
        """
        Load the authentication token and organization data from disk.

        Silently returns if no token exists. On any loading or validation failure,
        prints the error, resets in-memory auth state, and deletes the stored file.

        NOTE(review): the stored "current_org" is not restored; the current
        organization always comes from the token's claims — confirm intended.
        """

        token_path = self._get_token_path()
        if not token_path.exists():
            return

        try:
            with open(token_path, "r") as f:
                stored_data = json.load(f)

            self._token = stored_data.get("token")
            if not self._token:
                return

            # Validate and extract all claims from token
            self._validate_token_claims(self._token)
        except Exception as e:
            print(f"Error loading local token: {e}")

            # Clear invalid or expired tokens together with any claim state
            self._token = None
            self._user_id = None
            self._current_org = None
            self._available_orgs = []
            self.clear_token()

    def _validate_token_claims(self, token: str):
        """
        Validate the token and extract all claims including user ID and organization information.

        Only the payload is decoded; the JWT signature is NOT verified here. This is a
        client-side check of format, expiry, and required claims.

        Handler state (user ID, current and available organizations) is updated only
        after every check passes, so a rejected token never leaves partial state behind.

        :param token: The JWT token to validate.
        :raises AuthError: If the token is invalid, expired, or missing required claims.
        """

        parts = token.split(".")
        if len(parts) != 3:
            raise AuthError("Invalid token format")

        # Decode the payload (middle part); JWT segments are unpadded base64url
        payload_encoded = parts[1] + "=" * (-len(parts[1]) % 4)
        try:
            payload = json.loads(base64.urlsafe_b64decode(payload_encoded))
        except ValueError as e:
            # binascii.Error, JSONDecodeError and UnicodeDecodeError are all ValueErrors
            raise AuthError("Invalid token payload") from e
        if not isinstance(payload, dict):
            raise AuthError("Invalid token payload")

        # Check expiration
        exp = payload.get("exp")
        if exp and time.time() > exp:
            raise AuthError("Token has expired")

        # Extract user ID
        user_id = payload.get("sub")
        if user_id is None:
            raise AuthError("User ID not found in token claims")

        try:
            # Extract current organization
            current_org = Organization(
                org_id=payload.get(AUTH_ORG_ID_CLAIM),
                org_name=payload.get(AUTH_ORG_NAME_CLAIM),
            )

            # Extract available organizations
            # Note: Auth0 returns organizations with "id" and "name" keys, not "org_id" and "org_name"
            available_orgs = [
                Organization(org_id=org.get("id") or org.get("org_id"), org_name=org.get("name") or org.get("org_name"))
                for org in payload.get(AUTH_ORGANIZATIONS_CLAIM) or []
            ]
        except (ValueError, TypeError, AttributeError) as e:
            # pydantic ValidationError subclasses ValueError; malformed list entries
            # surface as TypeError/AttributeError
            raise AuthError("Organization information missing or invalid in token claims") from e

        self._user_id = user_id
        self._current_org = current_org
        self._available_orgs = available_orgs

    def _get_token_path(self) -> Path:
        """
        Get the token file path.

        :return: Path to the token file.
        """

        return get_auth_dir() / "token.json"
|