reactor-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
reactor_sdk/types.py ADDED
@@ -0,0 +1,255 @@
1
+ """
2
+ Type definitions for the Reactor SDK.
3
+
4
+ This module contains all the type definitions, enums, and data classes
5
+ used throughout the SDK.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass, field
11
+ from enum import Enum
12
+ from typing import Any, Callable, Literal, Optional, TypedDict
13
+
14
+ import numpy as np
15
+ from numpy.typing import NDArray
16
+
17
+
18
+ # =============================================================================
19
+ # Status Enums
20
+ # =============================================================================
21
+
22
+
23
class ReactorStatus(Enum):
    """Lifecycle states of the Reactor connection."""

    # No connection of any kind is established.
    DISCONNECTED = "disconnected"
    # A connection to the coordinator is being established.
    CONNECTING = "connecting"
    # Connected to the coordinator; still waiting for a GPU assignment.
    WAITING = "waiting"
    # Connected to a GPU machine; messages can be sent and received.
    READY = "ready"
30
+
31
+
32
class GPUMachineStatus(Enum):
    """States of the WebRTC connection to the GPU machine."""

    DISCONNECTED = "disconnected"
    CONNECTING = "connecting"
    CONNECTED = "connected"
    ERROR = "error"
39
+
40
+
41
+ # =============================================================================
42
+ # Event Types
43
+ # =============================================================================
44
+
45
# Names of the events a Reactor can emit.
ReactorEvent = Literal[
    # The reactor status was updated.
    "status_changed",
    # The session ID was updated.
    "session_id_changed",
    # A new message arrived from the machine.
    "new_message",
    # The video stream changed.
    "stream_changed",
    # An error occurred; carries ReactorError details.
    "error",
    # The session expiration changed.
    "session_expiration_changed",
]
54
+
55
# Names of the events emitted for the GPU machine connection.
GPUMachineEvent = Literal[
    # The connection state changed.
    "status_changed",
    # A remote track was received.
    "track_received",
    # A remote track was removed.
    "track_removed",
    # A message arrived on the data channel.
    "application",
]
62
+
63
+
64
+ # =============================================================================
65
+ # Error Types
66
+ # =============================================================================
67
+
68
+
69
@dataclass
class ReactorError:
    """Details of an error that occurred in the Reactor.

    ``str()`` renders the error as ``[component:code] message``.
    """

    code: str
    message: str
    timestamp: float
    recoverable: bool
    # Which side reported the error.
    component: Literal["coordinator", "gpu"]
    # Optional retry hint; None when no hint is available.
    retry_after: Optional[float] = None

    def __str__(self) -> str:
        origin = f"{self.component}:{self.code}"
        return f"[{origin}] {self.message}"
82
+
83
+
84
class ConflictError(Exception):
    """Signals a connection conflict (e.g., superseded by a newer request)."""
88
+
89
+
90
+ # =============================================================================
91
+ # State Types
92
+ # =============================================================================
93
+
94
+
95
@dataclass
class ReactorState:
    """Snapshot of the Reactor: its status plus the most recent error, if any."""

    status: ReactorStatus
    # None when no error has been recorded.
    last_error: Optional[ReactorError] = None
101
+
102
+
103
+ # =============================================================================
104
+ # Callback Types
105
+ # =============================================================================
106
+
107
# Signature of a frame callback: it is handed one uint8 numpy array holding
# an RGB frame of shape (H, W, 3) and returns nothing.
FrameCallback = Callable[[NDArray[np.uint8]], None]

# Signature of an event handler: any arguments, no return value.
EventHandler = Callable[..., None]
112
+
113
+
114
+ # =============================================================================
115
+ # API Response Types
116
+ # =============================================================================
117
+
118
+
119
class IceServerCredentials(TypedDict, total=False):
    """Login credentials for an ICE server; both keys are optional."""

    username: str
    password: str
124
+
125
+
126
class IceServerConfig(TypedDict):
    """One ICE server entry: its URIs and, optionally, credentials."""

    uris: list[str]
    # None when the server needs no credentials.
    credentials: Optional[IceServerCredentials]
131
+
132
+
133
class IceServersResponse(TypedDict):
    """Payload returned by the ICE servers endpoint."""

    ice_servers: list[IceServerConfig]
137
+
138
+
139
class ModelConfig(TypedDict):
    """Model selection as carried in session requests."""

    name: str
143
+
144
+
145
class CreateSessionRequest(TypedDict):
    """Body of the request that creates a session."""

    model: ModelConfig
    sdp_offer: str
    extra_args: dict[str, Any]
151
+
152
+
153
class CreateSessionResponse(TypedDict):
    """Payload returned after a session has been created."""

    session_id: str
157
+
158
+
159
class SDPParamsRequest(TypedDict):
    """Body of a request to the SDP params endpoint."""

    sdp_offer: str
    extra_args: dict[str, Any]
164
+
165
+
166
class SDPParamsResponse(TypedDict):
    """Payload returned by the SDP params endpoint."""

    sdp_answer: str
    extra_args: dict[str, Any]
171
+
172
+
173
class SessionState(Enum):
    """Session states as reported by the coordinator.

    Each member's value is the upper-case name of the state.
    """

    CREATED = "CREATED"
    PENDING = "PENDING"
    SUSPENDED = "SUSPENDED"
    WAITING = "WAITING"
    ACTIVE = "ACTIVE"
    INACTIVE = "INACTIVE"
    CLOSED = "CLOSED"
183
+
184
+
185
class SessionInfoResponse(TypedDict):
    """Payload returned by the session info endpoint."""

    session_id: str
    # Carried as a plain string.
    # NOTE(review): presumably one of the SessionState values; confirm.
    state: str
190
+
191
+
192
+ # =============================================================================
193
+ # WebRTC Types
194
+ # =============================================================================
195
+
196
+
197
@dataclass
class RTCIceServer:
    """Describes one ICE server for a WebRTC connection."""

    urls: list[str]
    # Both credential fields default to None (no authentication data).
    username: Optional[str] = None
    credential: Optional[str] = None
204
+
205
+
206
@dataclass
class WebRTCConfig:
    """Settings used to set up a WebRTC peer connection."""

    ice_servers: list[RTCIceServer]
    # Label given to the data channel; "data" unless overridden.
    data_channel_label: str = "data"
212
+
213
+
214
+ # =============================================================================
215
+ # Command Schema Types (for capabilities)
216
+ # =============================================================================
217
+
218
+
219
class ParameterSchema(TypedDict, total=False):
    """Describes a single command parameter; every key is optional."""

    description: str
    # One of "number", "integer", "string", "boolean".
    type: str
    minimum: float
    maximum: float
    required: bool
    enum: list[str]
228
+
229
+
230
class CommandSchema(TypedDict):
    """Describes one command: a description plus its parameter schemas."""

    description: str
    # Maps a string key to that parameter's schema.
    schema: dict[str, ParameterSchema]
235
+
236
+
237
class CapabilitiesMessage(TypedDict):
    """Message listing the capabilities/commands of a model."""

    # Maps a string key to that command's schema.
    commands: dict[str, CommandSchema]
241
+
242
+
243
+ # =============================================================================
244
+ # Video Frame Types
245
+ # =============================================================================
246
+
247
+
248
@dataclass
class VideoFrameInfo:
    """Metadata describing a video frame."""

    width: int
    height: int
    # Format name; "rgb24" unless overridden.
    format: str = "rgb24"
    timestamp: float = 0.0
@@ -0,0 +1,25 @@
1
+ """Utility modules for the Reactor SDK."""
2
+
3
+ from reactor_sdk.utils.tokens import fetch_jwt_token
4
+ from reactor_sdk.utils.webrtc import (
5
+ WebRTCConfig,
6
+ create_peer_connection,
7
+ create_data_channel,
8
+ create_offer,
9
+ set_remote_description,
10
+ transform_ice_servers,
11
+ wait_for_ice_gathering,
12
+ )
13
+
14
# Public API of the utils package, grouped by the module each name comes from.
__all__ = [
    # From the tokens module
    "fetch_jwt_token",
    # From the webrtc module
    "WebRTCConfig",
    "create_peer_connection",
    "create_data_channel",
    "create_offer",
    "set_remote_description",
    "transform_ice_servers",
    "wait_for_ice_gathering",
]
@@ -0,0 +1,64 @@
1
+ """
2
+ Token utilities for the Reactor SDK.
3
+
4
+ This module provides functions for fetching JWT tokens from the coordinator.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ from typing import Optional
11
+
12
+ import aiohttp
13
+
14
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Coordinator URL used when the caller does not supply one.
PROD_COORDINATOR_URL = "https://api.reactor.inc"
19
+
20
+
21
async def fetch_jwt_token(
    api_key: str,
    coordinator_url: str = PROD_COORDINATOR_URL,
) -> str:
    """
    Fetch a JWT token from the coordinator using an API key.

    This is safe to use in Python applications (CLI tools, scripts, servers)
    since the API key is not exposed to end users like it would be in
    browser-based JavaScript applications.

    Args:
        api_key: Your Reactor API key. Sent in the ``X-API-Key`` header.
        coordinator_url: Optional coordinator URL, defaults to production.
            A trailing slash is tolerated.

    Returns:
        The JWT token string (the ``"jwt"`` field of the JSON response).

    Raises:
        RuntimeError: If the coordinator answers with a non-success status,
            or if a successful response is not JSON or carries no string
            ``"jwt"`` field.

    Example:
        >>> token = await fetch_jwt_token("your-api-key")
        >>> reactor = Reactor(model_name="my-model")
        >>> await reactor.connect(jwt_token=token)
    """
    url = f"{coordinator_url.rstrip('/')}/tokens"

    async with aiohttp.ClientSession() as session:
        async with session.get(
            url,
            headers={"X-API-Key": api_key},
        ) as response:
            if not response.ok:
                error_text = await response.text()
                raise RuntimeError(
                    f"Failed to fetch JWT token: {response.status} {error_text}"
                )

            # A 2xx response can still be unusable (wrong content type,
            # invalid JSON, missing key). Surface all of these as the
            # documented RuntimeError instead of leaking parser errors.
            try:
                data = await response.json()
                jwt_token = data["jwt"]
            except (aiohttp.ContentTypeError, ValueError, KeyError, TypeError) as exc:
                raise RuntimeError(
                    "Failed to fetch JWT token: malformed response from coordinator"
                ) from exc

    if not isinstance(jwt_token, str):
        raise RuntimeError(
            "Failed to fetch JWT token: malformed response from coordinator"
        )

    logger.debug("Successfully fetched JWT token")
    return jwt_token
@@ -0,0 +1,315 @@
1
+ """
2
+ WebRTC utility functions for the Reactor SDK.
3
+
4
+ This module provides stateless utility functions for WebRTC operations
5
+ using aiortc.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import asyncio
11
+ import json
12
+ import logging
13
+ from dataclasses import dataclass
14
+ from typing import Any, Optional
15
+
16
+ from aiortc import RTCConfiguration, RTCDataChannel, RTCIceServer, RTCPeerConnection
17
+ from aiortc import RTCSessionDescription
18
+
19
+ from reactor_sdk.types import IceServersResponse
20
+
21
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
22
+
23
+
24
+ # =============================================================================
25
+ # Configuration
26
+ # =============================================================================
27
+
28
+
29
@dataclass
class WebRTCConfig:
    """Settings used to set up a WebRTC peer connection."""

    # ICE servers, as aiortc RTCIceServer objects.
    ice_servers: list[RTCIceServer]
    # Label for the data channel; "data" unless overridden.
    data_channel_label: str = "data"
35
+
36
+
37
# Fallback label for data channels.
DEFAULT_DATA_CHANNEL_LABEL = "data"
# Default ICE gathering timeout, in seconds.
DEFAULT_ICE_GATHERING_TIMEOUT = 5.0
39
+
40
+
41
+ # =============================================================================
42
+ # Peer Connection Creation
43
+ # =============================================================================
44
+
45
+
46
def create_peer_connection(config: WebRTCConfig) -> RTCPeerConnection:
    """
    Build an RTCPeerConnection from the given configuration.

    Args:
        config: WebRTC configuration; only its ``ice_servers`` are used here.

    Returns:
        A freshly constructed RTCPeerConnection.
    """
    return RTCPeerConnection(
        configuration=RTCConfiguration(iceServers=config.ice_servers)
    )
58
+
59
+
60
def create_data_channel(
    pc: RTCPeerConnection,
    label: Optional[str] = None,
) -> RTCDataChannel:
    """
    Open a data channel on the given peer connection.

    Args:
        pc: The RTCPeerConnection.
        label: Channel label. When omitted or empty, the module default
            label is used.

    Returns:
        The RTCDataChannel created by the peer connection.
    """
    channel_label = label if label else DEFAULT_DATA_CHANNEL_LABEL
    return pc.createDataChannel(channel_label)
75
+
76
+
77
+ # =============================================================================
78
+ # SDP Offer/Answer
79
+ # =============================================================================
80
+
81
+
82
async def create_offer(pc: RTCPeerConnection) -> str:
    """
    Produce an SDP offer for the peer connection.

    The offer is installed as the local description, then ICE gathering is
    awaited before the SDP is read back.

    Args:
        pc: The RTCPeerConnection.

    Returns:
        The SDP text of the local description.

    Raises:
        RuntimeError: If no local description is present afterwards.
    """
    await pc.setLocalDescription(await pc.createOffer())
    await wait_for_ice_gathering(pc)

    described = pc.localDescription
    if described is not None:
        return described.sdp
    raise RuntimeError("Failed to create local description")
107
+
108
+
109
async def create_answer(pc: RTCPeerConnection, offer_sdp: str) -> str:
    """
    Produce an SDP answer to an offer received from the remote peer.

    The offer is applied as the remote description, the answer is installed
    as the local description, then ICE gathering is awaited before the SDP
    is read back.

    Args:
        pc: The RTCPeerConnection.
        offer_sdp: SDP offer text from the remote peer.

    Returns:
        The SDP text of the local description.

    Raises:
        RuntimeError: If no local description is present afterwards.
    """
    await set_remote_description(pc, offer_sdp, "offer")
    await pc.setLocalDescription(await pc.createAnswer())
    await wait_for_ice_gathering(pc)

    described = pc.localDescription
    if described is not None:
        return described.sdp
    raise RuntimeError("Failed to create local description")
137
+
138
+
139
async def set_remote_description(
    pc: RTCPeerConnection,
    sdp: str,
    sdp_type: str = "answer",
) -> None:
    """
    Apply a remote SDP to the peer connection.

    Args:
        pc: The RTCPeerConnection.
        sdp: The SDP text.
        sdp_type: Kind of SDP being applied ("offer" or "answer").
    """
    await pc.setRemoteDescription(RTCSessionDescription(sdp=sdp, type=sdp_type))
154
+
155
+
156
def get_local_description(pc: RTCPeerConnection) -> Optional[str]:
    """
    Read the local SDP from the peer connection.

    Args:
        pc: The RTCPeerConnection.

    Returns:
        The SDP text of the local description, or None when the connection
        has no local description.
    """
    local = pc.localDescription
    return None if local is None else local.sdp
170
+
171
+
172
+ # =============================================================================
173
+ # ICE Handling
174
+ # =============================================================================
175
+
176
+
177
def transform_ice_servers(response: IceServersResponse) -> list[RTCIceServer]:
    """
    Convert the coordinator's ICE server list into RTCIceServer objects.

    Entries with non-empty credentials carry their username/password over;
    all other entries are built from their URIs alone.

    Args:
        response: The parsed IceServersResponse from the coordinator.

    Returns:
        One RTCIceServer per entry, in the order given by the response.
    """

    def _convert(entry: Any) -> RTCIceServer:
        # Missing, None, and empty credentials all mean "no authentication".
        credentials = entry.get("credentials")
        if not credentials:
            return RTCIceServer(urls=entry["uris"])
        return RTCIceServer(
            urls=entry["uris"],
            username=credentials.get("username"),
            credential=credentials.get("password"),
        )

    return [_convert(entry) for entry in response["ice_servers"]]
203
+
204
+
205
async def wait_for_ice_gathering(
    pc: RTCPeerConnection,
    timeout: float = DEFAULT_ICE_GATHERING_TIMEOUT,
) -> None:
    """
    Wait for ICE gathering to complete with a timeout.

    Returns immediately when gathering is already complete. Otherwise a
    temporary ``icegatheringstatechange`` listener is attached for the
    duration of the wait and detached again before returning, so repeated
    calls do not accumulate handlers on the peer connection.

    A timeout is not an error: it is logged as a warning and the function
    returns normally.

    Args:
        pc: The RTCPeerConnection.
        timeout: Maximum time to wait in seconds.
    """
    if pc.iceGatheringState == "complete":
        return

    gathering_complete = asyncio.Event()

    def on_ice_gathering_state_change() -> None:
        if pc.iceGatheringState == "complete":
            gathering_complete.set()

    pc.on("icegatheringstatechange", on_ice_gathering_state_change)
    try:
        await asyncio.wait_for(gathering_complete.wait(), timeout=timeout)
    except asyncio.TimeoutError:
        logger.warning(
            f"ICE gathering timed out after {timeout}s, proceeding with current candidates"
        )
    finally:
        # Detach the one-shot listener. It may already be gone if all
        # listeners were cleared while we were waiting, so a missing
        # handler is not an error.
        try:
            pc.remove_listener(
                "icegatheringstatechange", on_ice_gathering_state_change
            )
        except KeyError:
            pass
232
+
233
+
234
+ # =============================================================================
235
+ # Data Channel Messaging
236
+ # =============================================================================
237
+
238
+
239
def send_message(channel: RTCDataChannel, command: str, data: Any) -> None:
    """
    Send a message through a data channel.

    The message is serialized as JSON of the form
    ``{"type": command, "data": <data>}``.

    Args:
        channel: The RTCDataChannel.
        command: The command type.
        data: The data to send with the command. A string containing valid
            JSON is decoded first so it is embedded as structured data; any
            other string is sent as a plain string. Non-string values are
            embedded unchanged.

    Raises:
        RuntimeError: If the data channel is not open.
    """
    if channel.readyState != "open":
        raise RuntimeError(f"Data channel not open: {channel.readyState}")

    payload_data = data
    if isinstance(data, str):
        try:
            payload_data = json.loads(data)
        except ValueError:
            # Not JSON (json.JSONDecodeError is a ValueError): send the text
            # itself instead of failing, mirroring parse_message's fallback.
            payload_data = data

    channel.send(json.dumps({"type": command, "data": payload_data}))
257
+
258
+
259
def parse_message(data: Any) -> Any:
    """
    Decode a received data channel message as JSON where possible.

    Args:
        data: The raw message data.

    Returns:
        The decoded JSON value when ``data`` is a string holding valid JSON;
        otherwise ``data`` itself, unchanged (non-string input is never
        decoded).
    """
    if not isinstance(data, str):
        return data

    try:
        decoded = json.loads(data)
    except json.JSONDecodeError:
        return data
    return decoded
275
+
276
+
277
+ # =============================================================================
278
+ # Connection State
279
+ # =============================================================================
280
+
281
+
282
def is_connected(pc: RTCPeerConnection) -> bool:
    """
    Tell whether the peer connection reports the "connected" state.

    Args:
        pc: The RTCPeerConnection.

    Returns:
        True when ``pc.connectionState`` equals "connected", else False.
    """
    state = pc.connectionState
    return state == "connected"
293
+
294
+
295
def is_closed(pc: RTCPeerConnection) -> bool:
    """
    Tell whether the peer connection has ended, either closed or failed.

    Args:
        pc: The RTCPeerConnection.

    Returns:
        True when ``pc.connectionState`` is "closed" or "failed", else False.
    """
    state = pc.connectionState
    return state == "closed" or state == "failed"
306
+
307
+
308
async def close_peer_connection(pc: RTCPeerConnection) -> None:
    """
    Shut down the peer connection by awaiting its ``close()``.

    Args:
        pc: The RTCPeerConnection.
    """
    closing = pc.close()
    await closing