antioch-py 2.2.3__py3-none-any.whl → 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of antioch-py might be problematic. See the registry's advisory page for this release for more details.
- antioch/__init__.py +101 -0
- antioch/{module/execution.py → execution.py} +1 -1
- antioch/{module/input.py → input.py} +2 -4
- antioch/{module/module.py → module.py} +17 -34
- antioch/{module/node.py → node.py} +17 -16
- {antioch_py-2.2.3.dist-info → antioch_py-3.0.0.dist-info}/METADATA +8 -11
- antioch_py-3.0.0.dist-info/RECORD +61 -0
- {antioch_py-2.2.3.dist-info → antioch_py-3.0.0.dist-info}/WHEEL +1 -1
- antioch_py-3.0.0.dist-info/licenses/LICENSE +21 -0
- common/ark/__init__.py +6 -16
- common/ark/ark.py +23 -62
- common/ark/hardware.py +1 -1
- common/ark/kinematics.py +1 -1
- common/ark/module.py +22 -0
- common/ark/node.py +46 -3
- common/ark/scheduler.py +2 -29
- common/ark/sim.py +1 -1
- {antioch/module → common/ark}/token.py +17 -0
- common/assets/rigging.usd +0 -0
- common/constants.py +63 -5
- common/core/__init__.py +37 -24
- common/core/auth.py +87 -112
- common/core/container.py +261 -0
- common/core/registry.py +131 -152
- common/core/rome.py +251 -0
- common/core/telemetry.py +176 -0
- common/core/types.py +219 -0
- common/message/__init__.py +19 -5
- common/message/annotation.py +174 -23
- common/message/array.py +25 -1
- common/message/camera.py +23 -1
- common/message/color.py +32 -6
- common/message/detection.py +40 -0
- common/message/foxglove.py +20 -0
- common/message/frame.py +71 -7
- common/message/image.py +58 -9
- common/message/imu.py +24 -4
- common/message/joint.py +69 -10
- common/message/log.py +52 -7
- common/message/pir.py +23 -8
- common/message/plot.py +57 -0
- common/message/point.py +55 -6
- common/message/point_cloud.py +55 -19
- common/message/pose.py +59 -19
- common/message/quaternion.py +105 -92
- common/message/radar.py +195 -29
- common/message/twist.py +34 -0
- common/message/types.py +40 -5
- common/message/vector.py +180 -245
- common/sim/__init__.py +49 -0
- common/{session/config.py → sim/objects.py} +97 -27
- common/sim/state.py +11 -0
- common/utils/comms.py +30 -12
- common/utils/logger.py +26 -7
- antioch/message.py +0 -87
- antioch/module/__init__.py +0 -53
- antioch/session/__init__.py +0 -152
- antioch/session/ark.py +0 -500
- antioch/session/asset.py +0 -65
- antioch/session/error.py +0 -80
- antioch/session/objects/__init__.py +0 -40
- antioch/session/objects/animation.py +0 -162
- antioch/session/objects/articulation.py +0 -180
- antioch/session/objects/basis_curve.py +0 -180
- antioch/session/objects/camera.py +0 -65
- antioch/session/objects/collision.py +0 -46
- antioch/session/objects/geometry.py +0 -58
- antioch/session/objects/ground_plane.py +0 -48
- antioch/session/objects/imu.py +0 -53
- antioch/session/objects/joint.py +0 -49
- antioch/session/objects/light.py +0 -123
- antioch/session/objects/pir_sensor.py +0 -98
- antioch/session/objects/radar.py +0 -62
- antioch/session/objects/rigid_body.py +0 -197
- antioch/session/objects/xform.py +0 -119
- antioch/session/record.py +0 -158
- antioch/session/scene.py +0 -1544
- antioch/session/session.py +0 -211
- antioch/session/task.py +0 -309
- antioch_py-2.2.3.dist-info/RECORD +0 -85
- antioch_py-2.2.3.dist-info/entry_points.txt +0 -2
- common/core/agent.py +0 -324
- common/core/task.py +0 -36
- common/message/velocity.py +0 -11
- common/rome/__init__.py +0 -9
- common/rome/client.py +0 -430
- common/rome/error.py +0 -16
- common/session/__init__.py +0 -31
- common/session/environment.py +0 -31
- common/session/sim.py +0 -129
- common/utils/usd.py +0 -12
- /antioch/{module/clock.py → clock.py} +0 -0
- {antioch_py-2.2.3.dist-info → antioch_py-3.0.0.dist-info}/top_level.txt +0 -0
- /common/message/{base.py → message.py} +0 -0
common/ark/node.py
CHANGED
|
@@ -3,6 +3,7 @@ from enum import Enum
|
|
|
3
3
|
|
|
4
4
|
from pydantic import Field
|
|
5
5
|
|
|
6
|
+
from common.ark.token import InputToken
|
|
6
7
|
from common.message import Message
|
|
7
8
|
|
|
8
9
|
|
|
@@ -42,10 +43,7 @@ class NodeTimer(Message):
|
|
|
42
43
|
else:
|
|
43
44
|
raise ValueError("Timer must specify frequency or period")
|
|
44
45
|
|
|
45
|
-
# Convert to microseconds
|
|
46
46
|
period_us = int(period_ms * 1000)
|
|
47
|
-
|
|
48
|
-
# Round to nearest millisecond
|
|
49
47
|
rounded_us = ((period_us + 500) // 1000) * 1000
|
|
50
48
|
if rounded_us == 0:
|
|
51
49
|
raise ValueError("Timer frequency is too high (sub-millisecond period)")
|
|
@@ -92,3 +90,48 @@ class Node(Message):
|
|
|
92
90
|
inputs: dict[str, NodeInput] = Field(default_factory=dict)
|
|
93
91
|
outputs: dict[str, NodeOutput] = Field(default_factory=dict)
|
|
94
92
|
hardware_access: dict[str, HardwareAccessMode] = Field(default_factory=dict)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class NodeEdge(Message):
|
|
96
|
+
"""
|
|
97
|
+
Directed edge representing data flow between nodes.
|
|
98
|
+
"""
|
|
99
|
+
|
|
100
|
+
_type = "antioch/ark/node_edge"
|
|
101
|
+
source_module: str
|
|
102
|
+
source_node: str
|
|
103
|
+
source_output_name: str
|
|
104
|
+
target_module: str
|
|
105
|
+
target_node: str
|
|
106
|
+
target_input_name: str
|
|
107
|
+
type: str
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class SimNodeStart(Message):
|
|
111
|
+
"""
|
|
112
|
+
Ark signals node to start execution (sim mode).
|
|
113
|
+
|
|
114
|
+
Sent from Ark to node via publisher to trigger node start with hardware reads.
|
|
115
|
+
"""
|
|
116
|
+
|
|
117
|
+
_type = "antioch/ark/sim_node_start"
|
|
118
|
+
module_name: str
|
|
119
|
+
node_name: str
|
|
120
|
+
start_let_us: int
|
|
121
|
+
start_timestamp_us: int
|
|
122
|
+
input_tokens: list[InputToken]
|
|
123
|
+
hardware_reads: dict[str, bytes]
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class SimNodeComplete(Message):
|
|
127
|
+
"""
|
|
128
|
+
Node signals completion to Ark (sim mode).
|
|
129
|
+
|
|
130
|
+
Sent from node to Ark to indicate completion with optional hardware writes.
|
|
131
|
+
"""
|
|
132
|
+
|
|
133
|
+
_type = "antioch/ark/sim_node_complete"
|
|
134
|
+
module_name: str
|
|
135
|
+
node_name: str
|
|
136
|
+
completion_let_us: int
|
|
137
|
+
hardware_writes: dict[str, bytes] | None = None
|
common/ark/scheduler.py
CHANGED
|
@@ -3,38 +3,11 @@ from abc import ABC, abstractmethod
|
|
|
3
3
|
from sortedcontainers import SortedDict
|
|
4
4
|
|
|
5
5
|
from common.ark.module import Module
|
|
6
|
+
from common.ark.node import NodeEdge
|
|
7
|
+
from common.ark.token import InputToken
|
|
6
8
|
from common.message import Message
|
|
7
9
|
|
|
8
10
|
|
|
9
|
-
class NodeEdge(Message):
|
|
10
|
-
"""
|
|
11
|
-
Directed edge representing data flow between nodes.
|
|
12
|
-
"""
|
|
13
|
-
|
|
14
|
-
_type = "antioch/ark/node_edge"
|
|
15
|
-
source_module: str
|
|
16
|
-
source_node: str
|
|
17
|
-
source_output_name: str
|
|
18
|
-
target_module: str
|
|
19
|
-
target_node: str
|
|
20
|
-
target_input_name: str
|
|
21
|
-
type: str
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
class InputToken(Message):
|
|
25
|
-
"""
|
|
26
|
-
Input token representing data flow to a node.
|
|
27
|
-
"""
|
|
28
|
-
|
|
29
|
-
_type = "antioch/ark/input_token"
|
|
30
|
-
source_module: str
|
|
31
|
-
source_node: str
|
|
32
|
-
source_output_name: str
|
|
33
|
-
target_input_name: str
|
|
34
|
-
let_us: int
|
|
35
|
-
budget_us: int
|
|
36
|
-
|
|
37
|
-
|
|
38
11
|
class ScheduleEvent(Message, ABC):
|
|
39
12
|
"""
|
|
40
13
|
Base class for schedule events.
|
common/ark/token.py (renamed from antioch/module/token.py)
CHANGED
|
@@ -3,6 +3,23 @@ from enum import Enum
|
|
|
3
3
|
from common.message import Message
|
|
4
4
|
from common.utils.time import now_us
|
|
5
5
|
|
|
6
|
+
# Synchronization path for token communication
|
|
7
|
+
ARK_TOKEN_PATH = "_ark/token/{path}"
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class InputToken(Message):
|
|
11
|
+
"""
|
|
12
|
+
Input token representing data flow to a node.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
_type = "antioch/ark/input_token"
|
|
16
|
+
source_module: str
|
|
17
|
+
source_node: str
|
|
18
|
+
source_output_name: str
|
|
19
|
+
target_input_name: str
|
|
20
|
+
let_us: int
|
|
21
|
+
budget_us: int
|
|
22
|
+
|
|
6
23
|
|
|
7
24
|
class TokenType(str, Enum):
|
|
8
25
|
"""
|
|
common/assets/rigging.usd — binary file (no text diff shown)
|
common/constants.py
CHANGED
|
@@ -1,11 +1,20 @@
|
|
|
1
1
|
import os
|
|
2
2
|
from pathlib import Path
|
|
3
3
|
|
|
4
|
+
# =============================================================================
|
|
5
|
+
# Environment
|
|
6
|
+
# =============================================================================
|
|
7
|
+
|
|
4
8
|
ANTIOCH_ENV = os.environ.get("ANTIOCH_ENV", "prod").lower()
|
|
5
|
-
if ANTIOCH_ENV not in ("prod", "staging"):
|
|
9
|
+
if ANTIOCH_ENV not in ("prod", "staging", "local"):
|
|
6
10
|
raise ValueError(f"Invalid ANTIOCH_ENV: {ANTIOCH_ENV}")
|
|
7
11
|
|
|
8
|
-
|
|
12
|
+
# =============================================================================
|
|
13
|
+
# API URLs
|
|
14
|
+
# =============================================================================
|
|
15
|
+
|
|
16
|
+
# Local dev uses staging APIs
|
|
17
|
+
if ANTIOCH_ENV in ("staging", "local"):
|
|
9
18
|
ANTIOCH_API_URL = "https://staging.api.antioch.com"
|
|
10
19
|
AUTH_DOMAIN = "https://staging.auth.antioch.com"
|
|
11
20
|
AUTH_CLIENT_ID = "x0aOquV43Xe76ehqAm6Zir80O0MWpqTV"
|
|
@@ -14,9 +23,58 @@ else:
|
|
|
14
23
|
AUTH_DOMAIN = "https://auth.antioch.com"
|
|
15
24
|
AUTH_CLIENT_ID = "8RLoPEgMP3ih10sfJsGPkwbUWGilsoyX"
|
|
16
25
|
|
|
26
|
+
# Allow environment variable overrides
|
|
17
27
|
ANTIOCH_API_URL = os.environ.get("ANTIOCH_API_URL", ANTIOCH_API_URL)
|
|
18
28
|
AUTH_DOMAIN = os.environ.get("AUTH_DOMAIN", AUTH_DOMAIN)
|
|
19
|
-
|
|
29
|
+
|
|
30
|
+
# =============================================================================
|
|
31
|
+
# Local Storage Directories
|
|
32
|
+
# =============================================================================
|
|
33
|
+
|
|
34
|
+
ANTIOCH_DIR = os.environ.get("ANTIOCH_DIR", f"{os.environ.get('HOME', '.')}/.antioch/{ANTIOCH_ENV}")
|
|
35
|
+
ANTIOCH_ARKS_DIR = os.environ.get("ANTIOCH_ARKS_DIR", f"{ANTIOCH_DIR}/arks")
|
|
36
|
+
ANTIOCH_ASSETS_DIR = os.environ.get("ANTIOCH_ASSETS_DIR", f"{ANTIOCH_DIR}/assets")
|
|
37
|
+
|
|
38
|
+
# =============================================================================
|
|
39
|
+
# Auth0 Configuration
|
|
40
|
+
# =============================================================================
|
|
41
|
+
|
|
42
|
+
AUTH_TOKEN_URL = f"{AUTH_DOMAIN}/oauth/token"
|
|
43
|
+
DEVICE_CODE_URL = f"{AUTH_DOMAIN}/oauth/device/code"
|
|
44
|
+
|
|
45
|
+
AUTH_SCOPE = "openid profile email"
|
|
46
|
+
AUDIENCE = "https://sessions.antioch.com"
|
|
47
|
+
AUTH_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:device_code"
|
|
48
|
+
AUTH_TIMEOUT_SECONDS = 120
|
|
49
|
+
|
|
50
|
+
# JWT claim names (namespaced for Auth0)
|
|
51
|
+
AUTH_ORG_ID_CLAIM = "https://antioch.com/org_id"
|
|
52
|
+
AUTH_ORG_NAME_CLAIM = "https://antioch.com/org_name"
|
|
53
|
+
|
|
54
|
+
# =============================================================================
|
|
55
|
+
# Telemetry Configuration
|
|
56
|
+
# =============================================================================
|
|
57
|
+
|
|
58
|
+
FOXGLOVE_WEBSOCKET_PORT = 8765
|
|
59
|
+
|
|
60
|
+
# =============================================================================
|
|
61
|
+
# Zenoh Shared Memory Configuration
|
|
62
|
+
# =============================================================================
|
|
63
|
+
|
|
64
|
+
# Enable shared memory transport for high-performance IPC between processes
|
|
65
|
+
# When enabled, large messages use shared memory instead of TCP, providing
|
|
66
|
+
# 4-8x latency improvement for messages > 64 KB. Falls back to TCP when SHM
|
|
67
|
+
# is unavailable (cross-machine, pool exhausted, etc.)
|
|
68
|
+
SHM_ENABLED = True
|
|
69
|
+
|
|
70
|
+
# Shared memory pool size in bytes (256 MB)
|
|
71
|
+
# This is the total amount of shared memory allocated for message transport
|
|
72
|
+
SHM_POOL_SIZE_BYTES = 256 * 1024 * 1024
|
|
73
|
+
|
|
74
|
+
# Message size threshold in bytes for SHM transport (64 KB)
|
|
75
|
+
# Messages larger than this use SHM; smaller messages use TCP
|
|
76
|
+
# Based on benchmarks showing SHM wins for messages >= 64 KB
|
|
77
|
+
SHM_MESSAGE_SIZE_THRESHOLD_BYTES = 64 * 1024
|
|
20
78
|
|
|
21
79
|
|
|
22
80
|
def get_auth_dir() -> Path:
|
|
@@ -42,7 +100,7 @@ def get_ark_dir() -> Path:
|
|
|
42
100
|
:return: Path to the arks directory.
|
|
43
101
|
"""
|
|
44
102
|
|
|
45
|
-
ark_dir = Path(
|
|
103
|
+
ark_dir = Path(ANTIOCH_ARKS_DIR)
|
|
46
104
|
ark_dir.mkdir(parents=True, exist_ok=True)
|
|
47
105
|
return ark_dir
|
|
48
106
|
|
|
@@ -56,6 +114,6 @@ def get_asset_dir() -> Path:
|
|
|
56
114
|
:return: Path to the assets directory.
|
|
57
115
|
"""
|
|
58
116
|
|
|
59
|
-
asset_dir = Path(
|
|
117
|
+
asset_dir = Path(ANTIOCH_ASSETS_DIR)
|
|
60
118
|
asset_dir.mkdir(parents=True, exist_ok=True)
|
|
61
119
|
return asset_dir
|
common/core/__init__.py
CHANGED
|
@@ -1,16 +1,5 @@
|
|
|
1
|
-
from common.core.agent import (
|
|
2
|
-
Agent,
|
|
3
|
-
AgentError,
|
|
4
|
-
AgentResponse,
|
|
5
|
-
AgentStateResponse,
|
|
6
|
-
AgentValidationError,
|
|
7
|
-
ArkStateResponse,
|
|
8
|
-
ContainerSource,
|
|
9
|
-
ContainerState,
|
|
10
|
-
RecordTelemetryRequest,
|
|
11
|
-
StartArkRequest,
|
|
12
|
-
)
|
|
13
1
|
from common.core.auth import AuthError, AuthHandler, Organization
|
|
2
|
+
from common.core.container import ContainerManager, ContainerManagerError, ContainerSource
|
|
14
3
|
from common.core.registry import (
|
|
15
4
|
get_ark_version_reference,
|
|
16
5
|
get_asset_path,
|
|
@@ -22,24 +11,41 @@ from common.core.registry import (
|
|
|
22
11
|
pull_remote_ark,
|
|
23
12
|
pull_remote_asset,
|
|
24
13
|
)
|
|
14
|
+
from common.core.rome import RomeAuthError, RomeClient, RomeError, RomeNetworkError
|
|
15
|
+
from common.core.telemetry import TelemetryManager
|
|
16
|
+
from common.core.types import (
|
|
17
|
+
ArkReference,
|
|
18
|
+
ArkRegistryMetadata,
|
|
19
|
+
ArkVersionReference,
|
|
20
|
+
AssetReference,
|
|
21
|
+
AssetVersionReference,
|
|
22
|
+
TaskOutcome,
|
|
23
|
+
TaskRun,
|
|
24
|
+
TaskRunner,
|
|
25
|
+
TaskTriggerSource,
|
|
26
|
+
)
|
|
25
27
|
|
|
26
28
|
__all__ = [
|
|
27
|
-
# Agent
|
|
28
|
-
"Agent",
|
|
29
|
-
"AgentError",
|
|
30
|
-
"AgentResponse",
|
|
31
|
-
"AgentStateResponse",
|
|
32
|
-
"AgentValidationError",
|
|
33
|
-
"ArkStateResponse",
|
|
34
|
-
"ContainerSource",
|
|
35
|
-
"ContainerState",
|
|
36
|
-
"RecordTelemetryRequest",
|
|
37
|
-
"StartArkRequest",
|
|
38
29
|
# Auth
|
|
39
30
|
"AuthError",
|
|
40
31
|
"AuthHandler",
|
|
41
32
|
"Organization",
|
|
42
|
-
#
|
|
33
|
+
# Containers
|
|
34
|
+
"ContainerManager",
|
|
35
|
+
"ContainerManagerError",
|
|
36
|
+
"ContainerSource",
|
|
37
|
+
# Registry types
|
|
38
|
+
"ArkReference",
|
|
39
|
+
"ArkRegistryMetadata",
|
|
40
|
+
"ArkVersionReference",
|
|
41
|
+
"AssetReference",
|
|
42
|
+
"AssetVersionReference",
|
|
43
|
+
# Task types
|
|
44
|
+
"TaskOutcome",
|
|
45
|
+
"TaskRun",
|
|
46
|
+
"TaskRunner",
|
|
47
|
+
"TaskTriggerSource",
|
|
48
|
+
# Registry functions
|
|
43
49
|
"get_ark_version_reference",
|
|
44
50
|
"get_asset_path",
|
|
45
51
|
"list_local_arks",
|
|
@@ -49,4 +55,11 @@ __all__ = [
|
|
|
49
55
|
"load_local_ark",
|
|
50
56
|
"pull_remote_ark",
|
|
51
57
|
"pull_remote_asset",
|
|
58
|
+
# Rome
|
|
59
|
+
"RomeAuthError",
|
|
60
|
+
"RomeClient",
|
|
61
|
+
"RomeError",
|
|
62
|
+
"RomeNetworkError",
|
|
63
|
+
# Telemetry
|
|
64
|
+
"TelemetryManager",
|
|
52
65
|
]
|
common/core/auth.py
CHANGED
|
@@ -7,23 +7,18 @@ from pathlib import Path
|
|
|
7
7
|
import requests
|
|
8
8
|
from pydantic import BaseModel
|
|
9
9
|
|
|
10
|
-
from common.constants import
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
# Authentication claims
|
|
24
|
-
AUTH_ORG_ID_CLAIM = "https://antioch.com/org_id"
|
|
25
|
-
AUTH_ORG_NAME_CLAIM = "https://antioch.com/org_name"
|
|
26
|
-
AUTH_ORGANIZATIONS_CLAIM = "https://antioch.com/organizations"
|
|
10
|
+
from common.constants import (
|
|
11
|
+
AUDIENCE,
|
|
12
|
+
AUTH_CLIENT_ID,
|
|
13
|
+
AUTH_GRANT_TYPE,
|
|
14
|
+
AUTH_ORG_ID_CLAIM,
|
|
15
|
+
AUTH_ORG_NAME_CLAIM,
|
|
16
|
+
AUTH_SCOPE,
|
|
17
|
+
AUTH_TIMEOUT_SECONDS,
|
|
18
|
+
AUTH_TOKEN_URL,
|
|
19
|
+
DEVICE_CODE_URL,
|
|
20
|
+
get_auth_dir,
|
|
21
|
+
)
|
|
27
22
|
|
|
28
23
|
|
|
29
24
|
class AuthError(Exception):
|
|
@@ -34,38 +29,51 @@ class AuthError(Exception):
|
|
|
34
29
|
|
|
35
30
|
class Organization(BaseModel):
|
|
36
31
|
"""
|
|
37
|
-
Organization information.
|
|
32
|
+
Organization information extracted from JWT token.
|
|
38
33
|
"""
|
|
39
34
|
|
|
40
35
|
org_id: str
|
|
41
36
|
org_name: str
|
|
42
37
|
|
|
43
38
|
|
|
39
|
+
class UserInfo(BaseModel):
|
|
40
|
+
"""
|
|
41
|
+
User information extracted from JWT token.
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
user_id: str
|
|
45
|
+
name: str | None = None
|
|
46
|
+
email: str | None = None
|
|
47
|
+
|
|
48
|
+
|
|
44
49
|
class AuthHandler:
|
|
45
50
|
"""
|
|
46
|
-
Client for handling authentication.
|
|
51
|
+
Client for handling authentication via OAuth2 device code flow.
|
|
47
52
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
- Pulling assets from the remote asset registry
|
|
53
|
+
Manages authentication tokens and organization context for interacting
|
|
54
|
+
with Antioch services.
|
|
51
55
|
"""
|
|
52
56
|
|
|
53
57
|
def __init__(self):
|
|
54
58
|
"""
|
|
55
59
|
Initialize the auth handler.
|
|
60
|
+
|
|
61
|
+
Automatically loads any existing token from disk.
|
|
56
62
|
"""
|
|
57
63
|
|
|
58
64
|
self._token: str | None = None
|
|
59
|
-
self.
|
|
60
|
-
self.
|
|
61
|
-
self._available_orgs: list[Organization] = []
|
|
65
|
+
self._org: Organization | None = None
|
|
66
|
+
self._user: UserInfo | None = None
|
|
62
67
|
self._load_local_token()
|
|
63
68
|
|
|
64
69
|
def login(self) -> None:
|
|
65
70
|
"""
|
|
66
|
-
Authenticate the user via device code flow.
|
|
71
|
+
Authenticate the user via OAuth2 device code flow.
|
|
67
72
|
|
|
68
|
-
|
|
73
|
+
Initiates the device code flow, prompts the user to authenticate
|
|
74
|
+
in their browser, and saves the token to disk on success.
|
|
75
|
+
|
|
76
|
+
:raises AuthError: If authentication fails or times out.
|
|
69
77
|
"""
|
|
70
78
|
|
|
71
79
|
if self.is_authenticated():
|
|
@@ -95,26 +103,27 @@ class AuthHandler:
|
|
|
95
103
|
"client_id": AUTH_CLIENT_ID,
|
|
96
104
|
}
|
|
97
105
|
|
|
98
|
-
authenticated = False
|
|
99
106
|
start_time = time.time()
|
|
100
|
-
while
|
|
107
|
+
while True:
|
|
101
108
|
token_response = requests.post(AUTH_TOKEN_URL, data=token_payload)
|
|
102
109
|
token_data = token_response.json()
|
|
103
110
|
if token_response.status_code == 200:
|
|
104
111
|
print("Authenticated!")
|
|
105
|
-
|
|
106
|
-
|
|
112
|
+
self._token = token_data["access_token"]
|
|
113
|
+
break
|
|
114
|
+
|
|
115
|
+
if token_data["error"] not in ("authorization_pending", "slow_down"):
|
|
107
116
|
print(token_data["error_description"])
|
|
108
117
|
raise AuthError("Error authenticating the user") from Exception(token_data)
|
|
109
|
-
else:
|
|
110
|
-
if time.time() - start_time > AUTH_TIMEOUT_SECONDS:
|
|
111
|
-
raise AuthError("Timeout waiting for authentication")
|
|
112
|
-
time.sleep(device_code_data["interval"])
|
|
113
118
|
|
|
114
|
-
|
|
115
|
-
|
|
119
|
+
if time.time() - start_time > AUTH_TIMEOUT_SECONDS:
|
|
120
|
+
raise AuthError("Timeout waiting for authentication")
|
|
121
|
+
|
|
122
|
+
time.sleep(device_code_data["interval"])
|
|
123
|
+
|
|
116
124
|
if self._token is None:
|
|
117
125
|
raise AuthError("No token received")
|
|
126
|
+
|
|
118
127
|
self._validate_token_claims(self._token)
|
|
119
128
|
self.save_token()
|
|
120
129
|
|
|
@@ -122,69 +131,44 @@ class AuthHandler:
|
|
|
122
131
|
"""
|
|
123
132
|
Check if the user is authenticated.
|
|
124
133
|
|
|
125
|
-
:return: True if authenticated, False otherwise.
|
|
134
|
+
:return: True if authenticated with a valid token, False otherwise.
|
|
126
135
|
"""
|
|
127
136
|
|
|
128
|
-
return self.
|
|
137
|
+
return self._org is not None
|
|
129
138
|
|
|
130
|
-
def
|
|
139
|
+
def get_org(self) -> Organization:
|
|
131
140
|
"""
|
|
132
|
-
|
|
141
|
+
Get the current organization.
|
|
133
142
|
|
|
134
|
-
:
|
|
143
|
+
:return: The current organization.
|
|
135
144
|
:raises AuthError: If the user is not authenticated.
|
|
136
145
|
"""
|
|
137
146
|
|
|
138
|
-
if not self.is_authenticated():
|
|
147
|
+
if not self.is_authenticated() or self._org is None:
|
|
139
148
|
raise AuthError("Not authenticated. Please login first")
|
|
149
|
+
return self._org
|
|
140
150
|
|
|
141
|
-
|
|
142
|
-
if org.org_id == org_id:
|
|
143
|
-
self._current_org = org
|
|
144
|
-
return
|
|
145
|
-
|
|
146
|
-
raise AuthError(f"Organization '{org_id}' is not in your available organizations")
|
|
147
|
-
|
|
148
|
-
def get_current_org(self) -> Organization | None:
|
|
151
|
+
def get_user_info(self) -> UserInfo | None:
|
|
149
152
|
"""
|
|
150
|
-
Get the current
|
|
153
|
+
Get the current user information.
|
|
151
154
|
|
|
152
|
-
:return: The current
|
|
155
|
+
:return: The current user info, or None if not available.
|
|
153
156
|
:raises AuthError: If the user is not authenticated.
|
|
154
157
|
"""
|
|
155
158
|
|
|
156
159
|
if not self.is_authenticated():
|
|
157
160
|
raise AuthError("Not authenticated. Please login first")
|
|
161
|
+
return self._user
|
|
158
162
|
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
def get_user_id(self) -> str | None:
|
|
162
|
-
"""
|
|
163
|
-
Get the user ID.
|
|
164
|
-
|
|
165
|
-
:return: The user ID.
|
|
163
|
+
def get_token(self) -> str:
|
|
166
164
|
"""
|
|
165
|
+
Get the current authentication token.
|
|
167
166
|
|
|
168
|
-
return
|
|
169
|
-
|
|
170
|
-
def get_available_orgs(self) -> list[Organization]:
|
|
171
|
-
"""
|
|
172
|
-
Get the available organizations.
|
|
173
|
-
|
|
174
|
-
:return: The available organizations.
|
|
175
|
-
"""
|
|
176
|
-
|
|
177
|
-
return self._available_orgs
|
|
178
|
-
|
|
179
|
-
def get_token(self) -> str | None:
|
|
180
|
-
"""
|
|
181
|
-
Get the token.
|
|
182
|
-
|
|
183
|
-
:return: The token.
|
|
167
|
+
:return: The JWT access token.
|
|
184
168
|
:raises AuthError: If the user is not authenticated.
|
|
185
169
|
"""
|
|
186
170
|
|
|
187
|
-
if not self.is_authenticated():
|
|
171
|
+
if not self.is_authenticated() or self._token is None:
|
|
188
172
|
raise AuthError("Not authenticated. Please login first")
|
|
189
173
|
return self._token
|
|
190
174
|
|
|
@@ -192,22 +176,16 @@ class AuthHandler:
|
|
|
192
176
|
"""
|
|
193
177
|
Save the authentication token and organization data to disk.
|
|
194
178
|
|
|
179
|
+
Creates the token file with restrictive permissions (0600).
|
|
180
|
+
|
|
195
181
|
:raises AuthError: If not authenticated.
|
|
196
182
|
"""
|
|
197
183
|
|
|
198
184
|
if not self.is_authenticated():
|
|
199
185
|
raise AuthError("Not authenticated. Please login first")
|
|
200
186
|
|
|
201
|
-
stored_data = {
|
|
202
|
-
"token": self._token,
|
|
203
|
-
"current_org": self._current_org.model_dump() if self._current_org else None,
|
|
204
|
-
"available_orgs": [org.model_dump() for org in self._available_orgs],
|
|
205
|
-
}
|
|
206
|
-
|
|
187
|
+
stored_data = {"token": self._token, "org": self._org.model_dump() if self._org else None}
|
|
207
188
|
token_path = self._get_token_path()
|
|
208
|
-
|
|
209
|
-
# Create file with restrictive permissions (owner read/write only)
|
|
210
|
-
# Use os.open to atomically create file with 0o600 permissions
|
|
211
189
|
fd = os.open(token_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
|
212
190
|
with os.fdopen(fd, "w") as f:
|
|
213
191
|
json.dump(stored_data, f, indent=2)
|
|
@@ -223,9 +201,9 @@ class AuthHandler:
|
|
|
223
201
|
|
|
224
202
|
def _load_local_token(self) -> None:
|
|
225
203
|
"""
|
|
226
|
-
Load the authentication token
|
|
204
|
+
Load the authentication token from disk.
|
|
227
205
|
|
|
228
|
-
Silently returns if no token exists
|
|
206
|
+
Silently returns if no token exists. Clears invalid or expired tokens.
|
|
229
207
|
"""
|
|
230
208
|
|
|
231
209
|
token_path = self._get_token_path()
|
|
@@ -242,17 +220,16 @@ class AuthHandler:
|
|
|
242
220
|
|
|
243
221
|
# Validate and extract all claims from token
|
|
244
222
|
self._validate_token_claims(self._token)
|
|
223
|
+
if stored_data.get("org"):
|
|
224
|
+
self._org = Organization(**stored_data["org"])
|
|
245
225
|
except Exception as e:
|
|
246
226
|
print(f"Error loading local token: {e}")
|
|
247
|
-
|
|
248
|
-
# Clear invalid or expired tokens
|
|
249
227
|
self._token = None
|
|
250
228
|
self.clear_token()
|
|
251
|
-
return
|
|
252
229
|
|
|
253
|
-
def _validate_token_claims(self, token: str):
|
|
230
|
+
def _validate_token_claims(self, token: str) -> None:
|
|
254
231
|
"""
|
|
255
|
-
Validate the token and extract
|
|
232
|
+
Validate the token and extract organization and user information.
|
|
256
233
|
|
|
257
234
|
:param token: The JWT token to validate.
|
|
258
235
|
:raises AuthError: If the token is invalid, expired, or missing required claims.
|
|
@@ -262,7 +239,7 @@ class AuthHandler:
|
|
|
262
239
|
if len(parts) != 3:
|
|
263
240
|
raise AuthError("Invalid token format")
|
|
264
241
|
|
|
265
|
-
# Decode the payload
|
|
242
|
+
# Decode the payload
|
|
266
243
|
payload_encoded = parts[1]
|
|
267
244
|
padding = len(payload_encoded) % 4
|
|
268
245
|
if padding:
|
|
@@ -275,29 +252,27 @@ class AuthHandler:
|
|
|
275
252
|
if exp and time.time() > exp:
|
|
276
253
|
raise AuthError("Token has expired")
|
|
277
254
|
|
|
278
|
-
# Extract
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
for org in payload.get(AUTH_ORGANIZATIONS_CLAIM, [])
|
|
294
|
-
]
|
|
255
|
+
# Extract organization
|
|
256
|
+
org_id = payload.get(AUTH_ORG_ID_CLAIM)
|
|
257
|
+
org_name = payload.get(AUTH_ORG_NAME_CLAIM)
|
|
258
|
+
if not org_id or not org_name:
|
|
259
|
+
raise AuthError("Organization information not found in token claims")
|
|
260
|
+
self._org = Organization(org_id=org_id, org_name=org_name)
|
|
261
|
+
|
|
262
|
+
# Extract user info (optional claims)
|
|
263
|
+
user_id = payload.get("sub")
|
|
264
|
+
if user_id:
|
|
265
|
+
self._user = UserInfo(
|
|
266
|
+
user_id=user_id,
|
|
267
|
+
name=payload.get("name") or payload.get("nickname"),
|
|
268
|
+
email=payload.get("email"),
|
|
269
|
+
)
|
|
295
270
|
|
|
296
271
|
def _get_token_path(self) -> Path:
|
|
297
272
|
"""
|
|
298
273
|
Get the token file path.
|
|
299
274
|
|
|
300
|
-
:return: Path to the token file.
|
|
275
|
+
:return: Path to the token.json file.
|
|
301
276
|
"""
|
|
302
277
|
|
|
303
278
|
return get_auth_dir() / "token.json"
|