reactor_sdk-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reactor_sdk/__init__.py +57 -0
- reactor_sdk/coordinator/__init__.py +13 -0
- reactor_sdk/coordinator/client.py +362 -0
- reactor_sdk/coordinator/local_client.py +163 -0
- reactor_sdk/interface.py +203 -0
- reactor_sdk/model/__init__.py +11 -0
- reactor_sdk/model/client.py +647 -0
- reactor_sdk/py.typed +2 -0
- reactor_sdk/reactor.py +739 -0
- reactor_sdk/types.py +255 -0
- reactor_sdk/utils/__init__.py +25 -0
- reactor_sdk/utils/tokens.py +64 -0
- reactor_sdk/utils/webrtc.py +315 -0
- reactor_sdk-0.1.0.dist-info/METADATA +204 -0
- reactor_sdk-0.1.0.dist-info/RECORD +17 -0
- reactor_sdk-0.1.0.dist-info/WHEEL +4 -0
- reactor_sdk-0.1.0.dist-info/licenses/LICENSE +21 -0
reactor_sdk/model/client.py
ADDED
@@ -0,0 +1,647 @@
"""
ModelClient for the Reactor SDK.

This module handles the WebRTC connection to the model,
including video streaming and data channel messaging.
"""

from __future__ import annotations

import asyncio
import json
import logging
from typing import Any, Callable, Optional, Set

import numpy as np
from aiortc import (
    MediaStreamTrack,
    RTCDataChannel,
    RTCIceServer,
    RTCPeerConnection,
    RTCRtpSender,
    RTCRtpTransceiver,
)
from aiortc.codecs import h264
from av import VideoFrame
from numpy.typing import NDArray

from reactor_sdk.types import FrameCallback, GPUMachineEvent, GPUMachineStatus
from reactor_sdk.utils.webrtc import (
    WebRTCConfig,
    create_data_channel,
    create_offer,
    create_peer_connection,
    parse_message,
    send_message,
    set_remote_description,
)

logger = logging.getLogger(__name__)


# Type for event handlers
GPUEventHandler = Callable[..., None]


class ModelClient:
    """
    Manages the WebRTC connection to the model.

    Handles video streaming (both receiving and sending), data channel
    messaging, and connection lifecycle.
    """

    def __init__(self, config: WebRTCConfig) -> None:
        """
        Initialize the ModelClient.

        Args:
            config: WebRTC configuration including ICE servers.
        """
        self._config = config
        self._peer_connection: Optional[RTCPeerConnection] = None
        self._data_channel: Optional[RTCDataChannel] = None
        self._status: GPUMachineStatus = GPUMachineStatus.DISCONNECTED
        self._published_track: Optional[MediaStreamTrack] = None
        self._video_transceiver: Optional[RTCRtpTransceiver] = None
        self._remote_track: Optional[MediaStreamTrack] = None

        # Event system
        self._event_listeners: dict[GPUMachineEvent, Set[GPUEventHandler]] = {}

        # Frame callback for single-frame access
        self._frame_callback: Optional[FrameCallback] = None
        self._frame_task: Optional[asyncio.Task[None]] = None

        # Stop event for cooperative shutdown
        self._stop_event = asyncio.Event()

        # Connection state tracking - both must be true to be "connected"
        self._peer_connection_ready = False
        self._data_channel_open = False
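
    # Construction sketch (illustrative, not part of the published module): the only
    # constructor argument is a WebRTCConfig from reactor_sdk.utils.webrtc; the fields
    # this file relies on are its ICE server settings and data_channel_label.
    #
    #     config = WebRTCConfig(...)  # fields are defined in reactor_sdk.utils.webrtc
    #     client = ModelClient(config)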

    # =========================================================================
    # Event Emitter API
    # =========================================================================

    def on(self, event: GPUMachineEvent, handler: GPUEventHandler) -> None:
        """
        Register an event handler.

        Args:
            event: The event name.
            handler: The callback function.
        """
        if event not in self._event_listeners:
            self._event_listeners[event] = set()
        self._event_listeners[event].add(handler)

    def off(self, event: GPUMachineEvent, handler: GPUEventHandler) -> None:
        """
        Unregister an event handler.

        Args:
            event: The event name.
            handler: The callback function to remove.
        """
        if event in self._event_listeners:
            self._event_listeners[event].discard(handler)

    def _emit(self, event: GPUMachineEvent, *args: Any) -> None:
        """
        Emit an event to all registered handlers.

        Args:
            event: The event name.
            *args: Arguments to pass to handlers.
        """
        if event in self._event_listeners:
            for handler in self._event_listeners[event]:
                try:
                    handler(*args)
                except Exception as e:
                    logger.exception(f"Error in event handler for '{event}': {e}")

    # =========================================================================
    # Frame Callback
    # =========================================================================

    def set_frame_callback(self, callback: Optional[FrameCallback]) -> None:
        """
        Set a callback to receive individual video frames.

        The callback will be called with each received frame as a numpy array
        in RGB format with shape (H, W, 3).

        Args:
            callback: The callback function, or None to clear.
        """
        self._frame_callback = callback
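
    # Usage sketch (illustrative, not part of the published module): handlers and the
    # frame callback are plain callables. "status_changed", "track_received",
    # "track_removed" and "application" are the event names emitted elsewhere in this
    # file; the frame callback receives each decoded frame as an (H, W, 3) uint8 RGB
    # numpy array.
    #
    #     client.on("status_changed", lambda status: print("status:", status))
    #     client.set_frame_callback(lambda frame: print("frame shape:", frame.shape))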

    # =========================================================================
    # SDP & Connection
    # =========================================================================

    async def create_offer(self) -> str:
        """
        Create an SDP offer for initiating a connection.

        Must be called before connect().

        Returns:
            The SDP offer string.
        """
        # Create peer connection if not exists
        if self._peer_connection is None:
            self._peer_connection = create_peer_connection(self._config)
            self._setup_peer_connection_handlers()

        # Create data channel before offer (offerer creates the channel)
        self._data_channel = create_data_channel(
            self._peer_connection,
            self._config.data_channel_label,
        )
        self._setup_data_channel_handlers()

        # Add sendrecv video transceiver for bidirectional video
        self._video_transceiver = self._peer_connection.addTransceiver(
            "video",
            direction="sendrecv",
        )

        # Set codec preferences to prefer H.264 over VP8
        # This helps ensure codec compatibility with the server
        self._set_codec_preferences()

        offer = await create_offer(self._peer_connection)
        logger.debug("Created SDP offer")
        return offer

    def _set_codec_preferences(self) -> None:
        """
        Set codec preferences to prefer H.264 over VP8.

        H.264 is more widely supported and often provides better compatibility.
        """
        if self._video_transceiver is None:
            return

        try:
            # Get available video codecs
            capabilities = RTCRtpSender.getCapabilities("video")
            if capabilities is None:
                logger.debug("No video capabilities available")
                return

            # Sort codecs to prefer H.264, then VP8, then others
            preferred_codecs = []
            other_codecs = []

            for codec in capabilities.codecs:
                if codec.mimeType.lower() == "video/h264":
                    preferred_codecs.insert(0, codec)  # H.264 first
                elif codec.mimeType.lower() == "video/vp8":
                    preferred_codecs.append(codec)  # VP8 second
                else:
                    other_codecs.append(codec)

            # Combine: H.264 first, then VP8, then others
            all_codecs = preferred_codecs + other_codecs

            if all_codecs:
                self._video_transceiver.setCodecPreferences(all_codecs)
                codec_names = [c.mimeType for c in all_codecs[:3]]
                logger.debug(f"Set codec preferences: {codec_names}...")

        except Exception as e:
            # Don't fail if codec preferences can't be set
            logger.debug(f"Could not set codec preferences: {e}")

    async def connect(self, sdp_answer: str) -> None:
        """
        Connect to the GPU machine using the provided SDP answer.

        create_offer() must be called first.

        Args:
            sdp_answer: The SDP answer from the GPU machine.

        Raises:
            RuntimeError: If create_offer() was not called first.
        """
        if self._peer_connection is None:
            raise RuntimeError("Cannot connect - call create_offer() first")

        if self._peer_connection.signalingState != "have-local-offer":
            raise RuntimeError(
                f"Invalid signaling state: {self._peer_connection.signalingState}"
            )

        self._set_status(GPUMachineStatus.CONNECTING)

        try:
            await set_remote_description(self._peer_connection, sdp_answer)
            logger.debug("Remote description set")
        except Exception as e:
            logger.error(f"Failed to connect: {e}")
            self._set_status(GPUMachineStatus.ERROR)
            raise
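
    # Handshake sketch (illustrative, not part of the published module): this client is
    # the WebRTC offerer, so callers create the offer, exchange SDP over their signaling
    # path (for example the coordinator client elsewhere in this SDK), and only then call
    # connect() with the answer.
    #
    #     offer = await client.create_offer()
    #     answer = await exchange_sdp_somehow(offer)  # hypothetical signaling step
    #     await client.connect(answer)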

    async def disconnect(self) -> None:
        """
        Disconnect from the GPU machine and clean up resources.
        """
        # Signal stop to frame processing task
        self._stop_event.set()

        # Cancel frame processing task
        if self._frame_task is not None:
            self._frame_task.cancel()
            try:
                await self._frame_task
            except asyncio.CancelledError:
                pass
            self._frame_task = None

        # Unpublish any published track
        if self._published_track is not None:
            await self.unpublish_track()

        # Close data channel
        if self._data_channel is not None:
            self._data_channel.close()
            self._data_channel = None

        # Close peer connection
        if self._peer_connection is not None:
            await self._peer_connection.close()
            self._peer_connection = None

        self._video_transceiver = None
        self._remote_track = None
        self._peer_connection_ready = False
        self._data_channel_open = False
        self._set_status(GPUMachineStatus.DISCONNECTED)
        logger.debug("Disconnected from GPU machine")

    def get_status(self) -> GPUMachineStatus:
        """
        Get the current connection status.

        Returns:
            The current GPUMachineStatus.
        """
        return self._status

    def get_local_sdp(self) -> Optional[str]:
        """
        Get the current local SDP description.

        Returns:
            The local SDP string, or None if not set.
        """
        if self._peer_connection is None:
            return None
        desc = self._peer_connection.localDescription
        return desc.sdp if desc else None

    def is_offer_still_valid(self) -> bool:
        """
        Check if the current offer is still valid.

        Returns:
            True if the offer is valid.
        """
        if self._peer_connection is None:
            return False
        return self._peer_connection.signalingState == "have-local-offer"

    # =========================================================================
    # Messaging
    # =========================================================================

    def send_command(self, command: str, data: Any) -> None:
        """
        Send a command to the GPU machine via the data channel.

        Args:
            command: The command type.
            data: The data to send with the command.

        Raises:
            RuntimeError: If the data channel is not available.
        """
        if self._data_channel is None:
            raise RuntimeError("Data channel not available")

        try:
            send_message(self._data_channel, command, data)
        except Exception as e:
            logger.warning(f"Failed to send message: {e}")
            raise
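
    # Usage sketch (illustrative; the command name and payload below are made up, not
    # commands defined by this package). The payload is handed to send_message from
    # reactor_sdk.utils.webrtc, so a JSON-friendly dict is a safe choice, and commands
    # should only be sent once the status is CONNECTED (see _check_fully_connected):
    #
    #     client.send_command("example_command", {"key": "value"})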

    # =========================================================================
    # Track Publishing
    # =========================================================================

    async def publish_track(self, track: MediaStreamTrack) -> None:
        """
        Publish a track to the GPU machine.

        Only one track can be published at a time.
        Uses the existing transceiver's sender to replace the track.

        Args:
            track: The MediaStreamTrack to publish.

        Raises:
            RuntimeError: If not connected or no video transceiver.
        """
        if self._peer_connection is None:
            raise RuntimeError("Cannot publish track - not initialized")

        if self._status != GPUMachineStatus.CONNECTED:
            raise RuntimeError("Cannot publish track - not connected")

        if self._video_transceiver is None:
            raise RuntimeError("Cannot publish track - no video transceiver")

        try:
            # Use replaceTrack on the existing transceiver's sender
            # This doesn't require renegotiation
            await self._video_transceiver.sender.replaceTrack(track)
            self._published_track = track
            logger.debug(f"Track published successfully: {track.kind}")
        except Exception as e:
            logger.error(f"Failed to publish track: {e}")
            raise

    async def unpublish_track(self) -> None:
        """
        Unpublish the currently published track.
        """
        if self._video_transceiver is None or self._published_track is None:
            return

        try:
            # Replace with None to stop sending without renegotiation
            await self._video_transceiver.sender.replaceTrack(None)
            logger.debug("Track unpublished successfully")
        except Exception as e:
            logger.error(f"Failed to unpublish track: {e}")
            raise
        finally:
            self._published_track = None

    def get_published_track(self) -> Optional[MediaStreamTrack]:
        """
        Get the currently published track.

        Returns:
            The published MediaStreamTrack, or None.
        """
        return self._published_track
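
    # Publishing sketch (illustrative, not part of the published module): because the
    # sendrecv transceiver is created in create_offer(), swapping the outgoing track via
    # replaceTrack needs no renegotiation. FrameVideoTrack, defined at the end of this
    # file, is one suitable track source:
    #
    #     track = FrameVideoTrack(fps=30.0)
    #     await client.publish_track(track)
    #     ...
    #     await client.unpublish_track()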

    # =========================================================================
    # Remote Stream Access
    # =========================================================================

    def get_remote_track(self) -> Optional[MediaStreamTrack]:
        """
        Get the remote video track from the GPU machine.

        Returns:
            The remote MediaStreamTrack, or None if not available.
        """
        return self._remote_track
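
    # Consumption sketch (illustrative, not part of the published module): the remote
    # track is a regular aiortc MediaStreamTrack, so besides the frame callback it can
    # be fed to standard sinks such as aiortc.contrib.media.MediaRecorder, assuming a
    # track has already been received:
    #
    #     from aiortc.contrib.media import MediaRecorder
    #     recorder = MediaRecorder("remote.mp4")
    #     recorder.addTrack(client.get_remote_track())
    #     await recorder.start()
    #     ...
    #     await recorder.stop()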

    # =========================================================================
    # Private Helpers
    # =========================================================================

    def _set_status(self, new_status: GPUMachineStatus) -> None:
        """Set the connection status and emit event if changed."""
        if self._status != new_status:
            self._status = new_status
            self._emit("status_changed", new_status)

    def _check_fully_connected(self) -> None:
        """
        Check if both peer connection and data channel are ready.

        Only transitions to CONNECTED status when both conditions are met.
        This prevents sending messages before the data channel is open.
        """
        if self._peer_connection_ready and self._data_channel_open:
            logger.debug("Both peer connection and data channel ready - fully connected")
            self._set_status(GPUMachineStatus.CONNECTED)

    def _setup_peer_connection_handlers(self) -> None:
        """Set up event handlers for the peer connection."""
        if self._peer_connection is None:
            return

        @self._peer_connection.on("connectionstatechange")
        async def on_connection_state_change() -> None:
            if self._peer_connection is None:
                return

            state = self._peer_connection.connectionState
            logger.debug(f"Peer connection state: {state}")

            if state == "connected":
                self._peer_connection_ready = True
                self._check_fully_connected()
            elif state in ("disconnected", "closed"):
                self._peer_connection_ready = False
                self._data_channel_open = False
                self._set_status(GPUMachineStatus.DISCONNECTED)
            elif state == "failed":
                self._peer_connection_ready = False
                self._set_status(GPUMachineStatus.ERROR)

        @self._peer_connection.on("track")
        def on_track(track: MediaStreamTrack) -> None:
            logger.debug(f"Track received: {track.kind}")
            if track.kind == "video":
                self._remote_track = track
                self._emit("track_received", track)
                # Start frame processing if callback is set
                if self._frame_callback is not None:
                    self._start_frame_processing(track)

        @self._peer_connection.on("icecandidate")
        def on_ice_candidate(candidate: Any) -> None:
            if candidate:
                logger.debug(f"ICE candidate: {candidate}")

        @self._peer_connection.on("datachannel")
        def on_data_channel(channel: RTCDataChannel) -> None:
            logger.debug(f"Data channel received from remote: {channel.label}")
            self._data_channel = channel
            self._setup_data_channel_handlers()

    def _setup_data_channel_handlers(self) -> None:
        """Set up event handlers for the data channel."""
        if self._data_channel is None:
            return

        @self._data_channel.on("open")
        def on_open() -> None:
            logger.debug("Data channel open")
            self._data_channel_open = True
            self._check_fully_connected()

        @self._data_channel.on("close")
        def on_close() -> None:
            logger.debug("Data channel closed")
            self._data_channel_open = False

        @self._data_channel.on("message")
        def on_message(message: str) -> None:
            data = parse_message(message)
            logger.debug(f"Received message: {data}")
            try:
                self._emit("application", data)
            except Exception as e:
                logger.error(f"Failed to handle message: {e}")

    def _start_frame_processing(self, track: MediaStreamTrack) -> None:
        """Start the frame processing task for the given track."""
        if self._frame_task is not None:
            self._frame_task.cancel()

        self._stop_event.clear()
        self._frame_task = asyncio.create_task(self._process_frames(track))

    async def _process_frames(self, track: MediaStreamTrack) -> None:
        """
        Process incoming video frames from a track.

        Args:
            track: The MediaStreamTrack to process frames from.
        """
        try:
            while not self._stop_event.is_set():
                try:
                    # Receive frame with timeout to allow stop checks
                    frame: VideoFrame = await asyncio.wait_for(
                        track.recv(),
                        timeout=0.1,
                    )

                    # Convert to numpy RGB array
                    numpy_frame = self._video_frame_to_numpy(frame)

                    # Call the frame callback
                    if self._frame_callback is not None:
                        try:
                            self._frame_callback(numpy_frame)
                        except Exception as e:
                            logger.error(f"Error in frame callback: {e}")

                except asyncio.TimeoutError:
                    continue
                except Exception as e:
                    if "MediaStreamError" in str(type(e).__name__):
                        logger.debug("Video track ended")
                        break
                    logger.warning(f"Error processing video frame: {e}")
                    break

        except asyncio.CancelledError:
            logger.debug("Frame processing cancelled")
        except Exception as e:
            logger.warning(f"Frame processing stopped: {e}")
        finally:
            self._emit("track_removed")

    @staticmethod
    def _video_frame_to_numpy(frame: VideoFrame) -> NDArray[np.uint8]:
        """
        Convert an av.VideoFrame to a numpy array (H, W, 3) RGB.

        Args:
            frame: The VideoFrame to convert.

        Returns:
            Numpy array in RGB format with shape (H, W, 3).
        """
        if frame.format.name != "rgb24":
            frame = frame.reformat(format="rgb24")
        return frame.to_ndarray()


# =============================================================================
# Custom Video Track for Sending Frames
# =============================================================================


class FrameVideoTrack(MediaStreamTrack):
    """
    A video track that sends frames from a queue.

    Use this to send custom video frames to the GPU machine.

    Example:
        track = FrameVideoTrack()
        await reactor.publish_track(track)

        # Push frames in a loop
        while True:
            frame = get_next_frame()  # Your frame source
            await track.push_frame(frame)
    """

    kind = "video"

    def __init__(self, fps: float = 30.0) -> None:
        """
        Initialize the FrameVideoTrack.

        Args:
            fps: Target frames per second.
        """
        super().__init__()
        self._queue: asyncio.Queue[VideoFrame] = asyncio.Queue(maxsize=2)
        self._pts = 0
        self._fps = fps
        self._time_base = 1 / fps

    async def push_frame(self, frame: NDArray[np.uint8]) -> None:
        """
        Push a frame to be sent.

        Args:
            frame: Numpy array in RGB format with shape (H, W, 3).
        """
        video_frame = VideoFrame.from_ndarray(frame, format="rgb24")
        video_frame = video_frame.reformat(format="yuv420p")
        video_frame.pts = self._pts
        video_frame.time_base = self._time_base
        self._pts += 1

        # Non-blocking put, drop old frames if queue is full
        try:
            self._queue.put_nowait(video_frame)
        except asyncio.QueueFull:
            # Drop oldest frame and add new one
            try:
                self._queue.get_nowait()
            except asyncio.QueueEmpty:
                pass
            await self._queue.put(video_frame)

    async def recv(self) -> VideoFrame:
        """
        Receive the next frame to send.

        Returns:
            The next VideoFrame.
        """
        frame = await self._queue.get()
        return frame

    def stop(self) -> None:
        """Stop the track."""
        super().stop()
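

# End-to-end sketch (illustrative, not part of the published module): publish a
# FrameVideoTrack once the client reports CONNECTED, then push RGB frames paced to
# the track's target fps. push_frame drops the oldest queued frame when the small
# queue is full, so pacing with asyncio.sleep is one simple way to keep up with the
# consumer. `next_rgb_frame()` is a placeholder for whatever produces (H, W, 3)
# uint8 arrays.
#
#     async def stream_frames(client: ModelClient) -> None:
#         track = FrameVideoTrack(fps=30.0)
#         await client.publish_track(track)  # requires GPUMachineStatus.CONNECTED
#         try:
#             while True:
#                 frame = next_rgb_frame()       # placeholder frame source
#                 await track.push_frame(frame)
#                 await asyncio.sleep(1 / 30.0)  # pace pushes to roughly the target fps
#         finally:
#             await client.unpublish_track()
#             track.stop()
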
reactor_sdk/py.typed
ADDED