reactor-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,647 @@
1
+ """
2
+ ModelClient for the Reactor SDK.
3
+
4
+ This module handles the WebRTC connection to the model,
5
+ including video streaming and data channel messaging.
6
+ """
7
+
8
from __future__ import annotations

import asyncio
import json
import logging
from fractions import Fraction
from typing import Any, Callable, Optional, Set

import numpy as np
from aiortc import (
    MediaStreamTrack,
    RTCDataChannel,
    RTCIceServer,
    RTCPeerConnection,
    RTCRtpSender,
    RTCRtpTransceiver,
)
from aiortc.codecs import h264
from av import VideoFrame
from numpy.typing import NDArray

from reactor_sdk.types import FrameCallback, GPUMachineEvent, GPUMachineStatus
from reactor_sdk.utils.webrtc import (
    WebRTCConfig,
    create_data_channel,
    create_offer,
    create_peer_connection,
    parse_message,
    send_message,
    set_remote_description,
)
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
+ # Type for event handlers
43
+ GPUEventHandler = Callable[..., None]
44
+
45
+
46
class ModelClient:
    """
    Manages the WebRTC connection to the model.

    Handles video streaming (both receiving and sending), data channel
    messaging, and connection lifecycle.
    """

    def __init__(self, config: WebRTCConfig) -> None:
        """
        Initialize the ModelClient.

        Args:
            config: WebRTC configuration including ICE servers.
        """
        self._config = config
        self._peer_connection: Optional[RTCPeerConnection] = None
        self._data_channel: Optional[RTCDataChannel] = None
        self._status: GPUMachineStatus = GPUMachineStatus.DISCONNECTED
        self._published_track: Optional[MediaStreamTrack] = None
        self._video_transceiver: Optional[RTCRtpTransceiver] = None
        self._remote_track: Optional[MediaStreamTrack] = None

        # Event system: event name -> set of registered handlers
        self._event_listeners: dict[GPUMachineEvent, Set[GPUEventHandler]] = {}

        # Frame callback for single-frame access
        self._frame_callback: Optional[FrameCallback] = None
        self._frame_task: Optional[asyncio.Task[None]] = None

        # Stop event for cooperative shutdown of the frame-processing task
        self._stop_event = asyncio.Event()

        # Connection state tracking - both must be true to be "connected"
        self._peer_connection_ready = False
        self._data_channel_open = False

    # =========================================================================
    # Event Emitter API
    # =========================================================================

    def on(self, event: GPUMachineEvent, handler: GPUEventHandler) -> None:
        """
        Register an event handler.

        Registering the same handler twice is a no-op (handlers are a set).

        Args:
            event: The event name.
            handler: The callback function.
        """
        if event not in self._event_listeners:
            self._event_listeners[event] = set()
        self._event_listeners[event].add(handler)

    def off(self, event: GPUMachineEvent, handler: GPUEventHandler) -> None:
        """
        Unregister an event handler.

        Unknown events or handlers are ignored silently.

        Args:
            event: The event name.
            handler: The callback function to remove.
        """
        if event in self._event_listeners:
            self._event_listeners[event].discard(handler)

    def _emit(self, event: GPUMachineEvent, *args: Any) -> None:
        """
        Emit an event to all registered handlers.

        Handler exceptions are logged and swallowed so one faulty handler
        cannot prevent the others from running.

        Args:
            event: The event name.
            *args: Arguments to pass to handlers.
        """
        if event in self._event_listeners:
            # Iterate a snapshot so a handler may call on()/off() for this
            # event during dispatch without raising "set changed size
            # during iteration".
            for handler in list(self._event_listeners[event]):
                try:
                    handler(*args)
                except Exception as e:
                    logger.exception(f"Error in event handler for '{event}': {e}")

    # =========================================================================
    # Frame Callback
    # =========================================================================

    def set_frame_callback(self, callback: Optional[FrameCallback]) -> None:
        """
        Set a callback to receive individual video frames.

        The callback will be called with each received frame as a numpy array
        in RGB format with shape (H, W, 3).

        NOTE(review): frame processing starts only when a remote track
        arrives; setting the callback after the track was received does not
        retroactively start processing.

        Args:
            callback: The callback function, or None to clear.
        """
        self._frame_callback = callback

    # =========================================================================
    # SDP & Connection
    # =========================================================================

    async def create_offer(self) -> str:
        """
        Create an SDP offer for initiating a connection.

        Must be called before connect(). Sets up the peer connection, the
        data channel (the offerer creates it), and a bidirectional video
        transceiver.

        Returns:
            The SDP offer string.
        """
        # Create peer connection if not exists
        if self._peer_connection is None:
            self._peer_connection = create_peer_connection(self._config)
            self._setup_peer_connection_handlers()

        # Create data channel before offer (offerer creates the channel)
        self._data_channel = create_data_channel(
            self._peer_connection,
            self._config.data_channel_label,
        )
        self._setup_data_channel_handlers()

        # Add sendrecv video transceiver for bidirectional video
        self._video_transceiver = self._peer_connection.addTransceiver(
            "video",
            direction="sendrecv",
        )

        # Set codec preferences to prefer H.264 over VP8
        # This helps ensure codec compatibility with the server
        self._set_codec_preferences()

        offer = await create_offer(self._peer_connection)
        logger.debug("Created SDP offer")
        return offer

    def _set_codec_preferences(self) -> None:
        """
        Set codec preferences to prefer H.264 over VP8.

        H.264 is more widely supported and often provides better compatibility.
        Failure to set preferences is non-fatal and only logged.
        """
        if self._video_transceiver is None:
            return

        try:
            # Get available video codecs
            capabilities = RTCRtpSender.getCapabilities("video")
            if capabilities is None:
                logger.debug("No video capabilities available")
                return

            # Sort codecs to prefer H.264, then VP8, then others
            preferred_codecs = []
            other_codecs = []

            for codec in capabilities.codecs:
                if codec.mimeType.lower() == "video/h264":
                    preferred_codecs.insert(0, codec)  # H.264 first
                elif codec.mimeType.lower() == "video/vp8":
                    preferred_codecs.append(codec)  # VP8 second
                else:
                    other_codecs.append(codec)

            # Combine: H.264 first, then VP8, then others
            all_codecs = preferred_codecs + other_codecs

            if all_codecs:
                self._video_transceiver.setCodecPreferences(all_codecs)
                codec_names = [c.mimeType for c in all_codecs[:3]]
                logger.debug(f"Set codec preferences: {codec_names}...")

        except Exception as e:
            # Don't fail if codec preferences can't be set
            logger.debug(f"Could not set codec preferences: {e}")

    async def connect(self, sdp_answer: str) -> None:
        """
        Connect to the GPU machine using the provided SDP answer.

        create_offer() must be called first.

        Args:
            sdp_answer: The SDP answer from the GPU machine.

        Raises:
            RuntimeError: If create_offer() was not called first, or the
                signaling state is not "have-local-offer".
        """
        if self._peer_connection is None:
            raise RuntimeError("Cannot connect - call create_offer() first")

        if self._peer_connection.signalingState != "have-local-offer":
            raise RuntimeError(
                f"Invalid signaling state: {self._peer_connection.signalingState}"
            )

        self._set_status(GPUMachineStatus.CONNECTING)

        try:
            await set_remote_description(self._peer_connection, sdp_answer)
            logger.debug("Remote description set")
        except Exception as e:
            logger.error(f"Failed to connect: {e}")
            self._set_status(GPUMachineStatus.ERROR)
            raise

    async def disconnect(self) -> None:
        """
        Disconnect from the GPU machine and clean up resources.

        Safe to call multiple times; all teardown steps are idempotent.
        """
        # Signal stop to frame processing task
        self._stop_event.set()

        # Cancel frame processing task
        if self._frame_task is not None:
            self._frame_task.cancel()
            try:
                await self._frame_task
            except asyncio.CancelledError:
                pass
            self._frame_task = None

        # Unpublish any published track
        if self._published_track is not None:
            await self.unpublish_track()

        # Close data channel
        if self._data_channel is not None:
            self._data_channel.close()
            self._data_channel = None

        # Close peer connection
        if self._peer_connection is not None:
            await self._peer_connection.close()
            self._peer_connection = None

        self._video_transceiver = None
        self._remote_track = None
        self._peer_connection_ready = False
        self._data_channel_open = False
        self._set_status(GPUMachineStatus.DISCONNECTED)
        logger.debug("Disconnected from GPU machine")

    def get_status(self) -> GPUMachineStatus:
        """
        Get the current connection status.

        Returns:
            The current GPUMachineStatus.
        """
        return self._status

    def get_local_sdp(self) -> Optional[str]:
        """
        Get the current local SDP description.

        Returns:
            The local SDP string, or None if not set.
        """
        if self._peer_connection is None:
            return None
        desc = self._peer_connection.localDescription
        return desc.sdp if desc else None

    def is_offer_still_valid(self) -> bool:
        """
        Check if the current offer is still valid.

        An offer is valid while the peer connection exists and is still in
        the "have-local-offer" signaling state.

        Returns:
            True if the offer is valid.
        """
        if self._peer_connection is None:
            return False
        return self._peer_connection.signalingState == "have-local-offer"

    # =========================================================================
    # Messaging
    # =========================================================================

    def send_command(self, command: str, data: Any) -> None:
        """
        Send a command to the GPU machine via the data channel.

        Args:
            command: The command type.
            data: The data to send with the command.

        Raises:
            RuntimeError: If the data channel is not available.
        """
        if self._data_channel is None:
            raise RuntimeError("Data channel not available")

        try:
            send_message(self._data_channel, command, data)
        except Exception as e:
            logger.warning(f"Failed to send message: {e}")
            raise

    # =========================================================================
    # Track Publishing
    # =========================================================================

    async def publish_track(self, track: MediaStreamTrack) -> None:
        """
        Publish a track to the GPU machine.

        Only one track can be published at a time.
        Uses the existing transceiver's sender to replace the track.

        Args:
            track: The MediaStreamTrack to publish.

        Raises:
            RuntimeError: If not connected or no video transceiver.
        """
        if self._peer_connection is None:
            raise RuntimeError("Cannot publish track - not initialized")

        if self._status != GPUMachineStatus.CONNECTED:
            raise RuntimeError("Cannot publish track - not connected")

        if self._video_transceiver is None:
            raise RuntimeError("Cannot publish track - no video transceiver")

        try:
            # Use replaceTrack on the existing transceiver's sender
            # This doesn't require renegotiation
            await self._video_transceiver.sender.replaceTrack(track)
            self._published_track = track
            logger.debug(f"Track published successfully: {track.kind}")
        except Exception as e:
            logger.error(f"Failed to publish track: {e}")
            raise

    async def unpublish_track(self) -> None:
        """
        Unpublish the currently published track.

        No-op if nothing is published. The published-track reference is
        cleared even if replaceTrack fails.
        """
        if self._video_transceiver is None or self._published_track is None:
            return

        try:
            # Replace with None to stop sending without renegotiation
            await self._video_transceiver.sender.replaceTrack(None)
            logger.debug("Track unpublished successfully")
        except Exception as e:
            logger.error(f"Failed to unpublish track: {e}")
            raise
        finally:
            self._published_track = None

    def get_published_track(self) -> Optional[MediaStreamTrack]:
        """
        Get the currently published track.

        Returns:
            The published MediaStreamTrack, or None.
        """
        return self._published_track

    # =========================================================================
    # Remote Stream Access
    # =========================================================================

    def get_remote_track(self) -> Optional[MediaStreamTrack]:
        """
        Get the remote video track from the GPU machine.

        Returns:
            The remote MediaStreamTrack, or None if not available.
        """
        return self._remote_track

    # =========================================================================
    # Private Helpers
    # =========================================================================

    def _set_status(self, new_status: GPUMachineStatus) -> None:
        """Set the connection status and emit event if changed."""
        if self._status != new_status:
            self._status = new_status
            self._emit("status_changed", new_status)

    def _check_fully_connected(self) -> None:
        """
        Check if both peer connection and data channel are ready.

        Only transitions to CONNECTED status when both conditions are met.
        This prevents sending messages before the data channel is open.
        """
        if self._peer_connection_ready and self._data_channel_open:
            logger.debug("Both peer connection and data channel ready - fully connected")
            self._set_status(GPUMachineStatus.CONNECTED)

    def _setup_peer_connection_handlers(self) -> None:
        """Set up event handlers for the peer connection."""
        if self._peer_connection is None:
            return

        @self._peer_connection.on("connectionstatechange")
        async def on_connection_state_change() -> None:
            if self._peer_connection is None:
                return

            state = self._peer_connection.connectionState
            logger.debug(f"Peer connection state: {state}")

            if state == "connected":
                self._peer_connection_ready = True
                self._check_fully_connected()
            elif state in ("disconnected", "closed"):
                self._peer_connection_ready = False
                self._data_channel_open = False
                self._set_status(GPUMachineStatus.DISCONNECTED)
            elif state == "failed":
                self._peer_connection_ready = False
                self._set_status(GPUMachineStatus.ERROR)

        @self._peer_connection.on("track")
        def on_track(track: MediaStreamTrack) -> None:
            logger.debug(f"Track received: {track.kind}")
            if track.kind == "video":
                self._remote_track = track
                self._emit("track_received", track)
                # Start frame processing if callback is set
                if self._frame_callback is not None:
                    self._start_frame_processing(track)

        @self._peer_connection.on("icecandidate")
        def on_ice_candidate(candidate: Any) -> None:
            if candidate:
                logger.debug(f"ICE candidate: {candidate}")

        @self._peer_connection.on("datachannel")
        def on_data_channel(channel: RTCDataChannel) -> None:
            # Remote-created channel replaces the locally created one.
            logger.debug(f"Data channel received from remote: {channel.label}")
            self._data_channel = channel
            self._setup_data_channel_handlers()

    def _setup_data_channel_handlers(self) -> None:
        """Set up event handlers for the data channel."""
        if self._data_channel is None:
            return

        @self._data_channel.on("open")
        def on_open() -> None:
            logger.debug("Data channel open")
            self._data_channel_open = True
            self._check_fully_connected()

        @self._data_channel.on("close")
        def on_close() -> None:
            logger.debug("Data channel closed")
            self._data_channel_open = False

        @self._data_channel.on("message")
        def on_message(message: str) -> None:
            data = parse_message(message)
            logger.debug(f"Received message: {data}")
            try:
                self._emit("application", data)
            except Exception as e:
                logger.error(f"Failed to handle message: {e}")

    def _start_frame_processing(self, track: MediaStreamTrack) -> None:
        """Start the frame processing task for the given track."""
        if self._frame_task is not None:
            self._frame_task.cancel()

        self._stop_event.clear()
        self._frame_task = asyncio.create_task(self._process_frames(track))

    async def _process_frames(self, track: MediaStreamTrack) -> None:
        """
        Process incoming video frames from a track.

        Runs until the stop event is set, the task is cancelled, or the
        track ends. Emits "track_removed" on exit.

        Args:
            track: The MediaStreamTrack to process frames from.
        """
        try:
            while not self._stop_event.is_set():
                try:
                    # Receive frame with timeout to allow stop checks
                    frame: VideoFrame = await asyncio.wait_for(
                        track.recv(),
                        timeout=0.1,
                    )

                    # Convert to numpy RGB array
                    numpy_frame = self._video_frame_to_numpy(frame)

                    # Call the frame callback
                    if self._frame_callback is not None:
                        try:
                            self._frame_callback(numpy_frame)
                        except Exception as e:
                            logger.error(f"Error in frame callback: {e}")

                except asyncio.TimeoutError:
                    continue
                except Exception as e:
                    # aiortc raises MediaStreamError when the track ends;
                    # matched by name to avoid importing the exception type.
                    if "MediaStreamError" in type(e).__name__:
                        logger.debug("Video track ended")
                        break
                    logger.warning(f"Error processing video frame: {e}")
                    break

        except asyncio.CancelledError:
            logger.debug("Frame processing cancelled")
        except Exception as e:
            logger.warning(f"Frame processing stopped: {e}")
        finally:
            self._emit("track_removed")

    @staticmethod
    def _video_frame_to_numpy(frame: VideoFrame) -> NDArray[np.uint8]:
        """
        Convert an av.VideoFrame to a numpy array (H, W, 3) RGB.

        Args:
            frame: The VideoFrame to convert.

        Returns:
            Numpy array in RGB format with shape (H, W, 3).
        """
        if frame.format.name != "rgb24":
            frame = frame.reformat(format="rgb24")
        return frame.to_ndarray()
573
+
574
+
575
+ # =============================================================================
576
+ # Custom Video Track for Sending Frames
577
+ # =============================================================================
578
+
579
+
580
class FrameVideoTrack(MediaStreamTrack):
    """
    A video track that sends frames from a queue.

    Use this to send custom video frames to the GPU machine.

    Example:
        track = FrameVideoTrack()
        await reactor.publish_track(track)

        # Push frames in a loop
        while True:
            frame = get_next_frame()  # Your frame source
            await track.push_frame(frame)
    """

    kind = "video"

    def __init__(self, fps: float = 30.0) -> None:
        """
        Initialize the FrameVideoTrack.

        Args:
            fps: Target frames per second.
        """
        super().__init__()
        # Small queue: prefer dropping stale frames over building up latency.
        self._queue: asyncio.Queue[VideoFrame] = asyncio.Queue(maxsize=2)
        self._pts = 0
        self._fps = fps
        # PyAV/aiortc expect a rational time base (fractions.Fraction), not a
        # float. limit_denominator keeps fractional rates (e.g. 29.97) as a
        # small exact rational instead of a huge binary-float expansion.
        self._time_base = 1 / Fraction(fps).limit_denominator(1000)

    async def push_frame(self, frame: NDArray[np.uint8]) -> None:
        """
        Push a frame to be sent.

        If the internal queue is full, the oldest pending frame is dropped
        so the stream stays live rather than lagging behind the source.

        Args:
            frame: Numpy array in RGB format with shape (H, W, 3).
        """
        video_frame = VideoFrame.from_ndarray(frame, format="rgb24")
        # Video encoders generally expect yuv420p input.
        video_frame = video_frame.reformat(format="yuv420p")
        video_frame.pts = self._pts
        video_frame.time_base = self._time_base
        self._pts += 1

        # Non-blocking put, drop old frames if queue is full
        try:
            self._queue.put_nowait(video_frame)
        except asyncio.QueueFull:
            # Drop oldest frame and add new one
            try:
                self._queue.get_nowait()
            except asyncio.QueueEmpty:
                pass
            # No await between get_nowait and here, so the slot we just
            # freed cannot be taken by another coroutine; put_nowait keeps
            # this method from ever suspending on a full queue.
            self._queue.put_nowait(video_frame)

    async def recv(self) -> VideoFrame:
        """
        Receive the next frame to send.

        Called by the WebRTC sender; waits until a frame is available.

        Returns:
            The next VideoFrame.
        """
        frame = await self._queue.get()
        return frame

    def stop(self) -> None:
        """Stop the track."""
        super().stop()
reactor_sdk/py.typed ADDED
@@ -0,0 +1,2 @@
1
+ # Marker file for PEP 561
2
+ # This package provides type information