websocket-proxy 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,13 @@
1
+ """Fox Bridge - Foxglove WebSocket Proxy with Message Transformations"""
2
+
3
+ from websocket_proxy.proxy import ProxyBridge
4
+ from websocket_proxy.transformers import TransformerRegistry
5
+ from websocket_proxy.transformers.image_to_video import ImageToVideoTransformer
6
+ from websocket_proxy.transformers.pointcloud_voxel import PointCloudVoxelTransformer
7
+
8
# Public API re-exported at package level.
__all__ = [
    # Message transformers
    "ImageToVideoTransformer",
    "PointCloudVoxelTransformer",
    # Core proxy machinery
    "ProxyBridge",
    "TransformerRegistry",
]
@@ -0,0 +1,228 @@
1
+ import argparse
2
+ import asyncio
3
+ import contextlib
4
+ import logging
5
+ import signal
6
+ import sys
7
+
8
+ from rich.console import Console
9
+ from rich.logging import RichHandler
10
+
11
+ from websocket_proxy.dashboard import DashboardRenderer
12
+ from websocket_proxy.proxy import ProxyBridge
13
+ from websocket_proxy.transformers import TransformerRegistry
14
+ from websocket_proxy.transformers.image_to_video import ImageToVideoTransformer
15
+ from websocket_proxy.transformers.pointcloud_voxel import PointCloudVoxelTransformer
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def parse_args() -> argparse.Namespace:
21
+ """Parse command line arguments."""
22
+ parser = argparse.ArgumentParser(
23
+ description="Foxglove WebSocket proxy - forwards topics with optional transformations"
24
+ )
25
+ parser.add_argument(
26
+ "source_ws",
27
+ help="WebSocket URL of the upstream Foxglove bridge (e.g., ws://localhost:8765)",
28
+ )
29
+ parser.add_argument(
30
+ "--port",
31
+ type=int,
32
+ default=8766,
33
+ help="Port to listen on for downstream clients (default: 8766)",
34
+ )
35
+ parser.add_argument(
36
+ "--host",
37
+ default="0.0.0.0", # noqa: S104
38
+ help="Host to listen on for downstream clients (default: 0.0.0.0)",
39
+ )
40
+ parser.add_argument(
41
+ "--verbose",
42
+ "-v",
43
+ action="store_true",
44
+ help="Enable verbose debug logging",
45
+ )
46
+ parser.add_argument(
47
+ "--throttle-hz",
48
+ type=float,
49
+ default=1.0,
50
+ help="Topic throttle rate in Hz (default: 1.0; set to 0 to disable)",
51
+ )
52
+ parser.add_argument(
53
+ "--max-message-size",
54
+ type=int,
55
+ default=0,
56
+ help="Maximum websocket message size in bytes (<=0 disables limit, default: unlimited)",
57
+ )
58
+
59
+ parser.add_argument(
60
+ "--image-codec",
61
+ default="h264",
62
+ help="Video codec to use for image compression (default: h264)",
63
+ )
64
+ parser.add_argument(
65
+ "--image-quality",
66
+ type=int,
67
+ default=23,
68
+ help="CRF/quality value for image compression (lower is higher quality, default: 23)",
69
+ )
70
+ parser.add_argument(
71
+ "--image-preset",
72
+ default="fast",
73
+ help="Encoder preset for image compression (default: fast)",
74
+ )
75
+ parser.add_argument(
76
+ "--image-max-dimension",
77
+ type=int,
78
+ default=480,
79
+ help="Maximum width/height used when downscaling images before encoding (default: 480)",
80
+ )
81
+ parser.add_argument(
82
+ "--image-disable-hw",
83
+ dest="image_use_hardware",
84
+ action="store_false",
85
+ help="Disable hardware acceleration for image compression",
86
+ )
87
+
88
+ parser.add_argument(
89
+ "--pointcloud-voxel-size",
90
+ type=float,
91
+ default=0.1,
92
+ help="Voxel size (in meters) for point cloud downsampling (default: 0.1)",
93
+ )
94
+ parser.add_argument(
95
+ "--pointcloud-keep-nans",
96
+ dest="pointcloud_skip_nans",
97
+ action="store_false",
98
+ help="Keep NaN points when voxelizing point clouds (default: drop NaNs)",
99
+ )
100
+
101
+ parser.add_argument(
102
+ "--no-dashboard",
103
+ action="store_true",
104
+ help="Disable the live dashboard display",
105
+ )
106
+ parser.add_argument(
107
+ "--dashboard-refresh-rate",
108
+ type=float,
109
+ default=1.0,
110
+ help="Dashboard refresh rate in seconds (default: 1.0)",
111
+ )
112
+
113
+ parser.set_defaults(image_use_hardware=True, pointcloud_skip_nans=True)
114
+ return parser.parse_args()
115
+
116
+
117
def _build_registry(args: argparse.Namespace) -> TransformerRegistry:
    """Create a registry holding the image and point-cloud transformers configured from *args*."""
    registry = TransformerRegistry()

    # Image -> video compression, configured by the --image-* options.
    registry.register(
        ImageToVideoTransformer(
            codec=args.image_codec,
            quality=args.image_quality,
            preset=args.image_preset,
            use_hardware=args.image_use_hardware,
            max_dimension=args.image_max_dimension,
        )
    )
    # Point-cloud voxel downsampling, configured by the --pointcloud-* options.
    registry.register(
        PointCloudVoxelTransformer(
            voxel_size=args.pointcloud_voxel_size,
            skip_nans=args.pointcloud_skip_nans,
        )
    )
    return registry


def _configure_logging(console: Console, verbose: bool) -> None:
    """Install a RichHandler on the root logger that writes through *console*."""
    logging.basicConfig(
        level=logging.DEBUG if verbose else logging.INFO,
        format="%(message)s",
        datefmt="[%X]",
        handlers=[
            RichHandler(
                console=console,
                rich_tracebacks=True,
                tracebacks_show_locals=verbose,
            )
        ],
    )


async def main_async(args: argparse.Namespace) -> None:
    """Run the proxy bridge, plus the optional live dashboard, until it stops.

    The dashboard and the logging handler share one Rich console. The
    registry and the dashboard are created before logging is configured
    (the original start-up order). Everything after the dashboard starts runs
    inside ``try/finally``, so the Live display is always stopped.

    SIGTERM/SIGINT schedule ``bridge.stop()``. On event loops without
    ``add_signal_handler`` support (e.g. Windows), installation is skipped and
    Ctrl+C is left to asyncio's default handling.

    Args:
        args: Namespace produced by ``parse_args``.

    Raises:
        SystemExit: With status 1 after logging any unexpected exception.
    """
    console = Console()

    registry = _build_registry(args)

    bridge = ProxyBridge(
        upstream_url=args.source_ws,
        listen_host=args.host,
        listen_port=args.port,
        transformer_registry=registry,
        default_throttle_hz=args.throttle_hz,
        max_message_size=args.max_message_size if args.max_message_size > 0 else None,
    )

    dashboard = None
    if not args.no_dashboard:
        dashboard = DashboardRenderer(
            bridge, refresh_rate=args.dashboard_refresh_rate, console=console
        )
        # Started before logging is configured (original ordering).
        dashboard.start_sync()

    loop = asyncio.get_running_loop()
    installed_signals: list[signal.Signals] = []
    # Strong references keep the shutdown tasks alive until they finish; the
    # event loop itself only holds weak references to tasks.
    shutdown_tasks: set[asyncio.Task[None]] = set()
    dashboard_task: asyncio.Task[None] | None = None

    def signal_handler() -> None:
        logger.info("Received shutdown signal")
        task = asyncio.create_task(bridge.stop())
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    try:
        _configure_logging(console, args.verbose)

        logger.info("Registered transformers:")
        for transformer in registry.get_all_transformers():
            logger.info(
                " %s -> %s", transformer.get_input_schema(), transformer.get_output_schema()
            )

        for sig in (signal.SIGTERM, signal.SIGINT):
            try:
                loop.add_signal_handler(sig, signal_handler)
            except NotImplementedError:
                logger.debug("Signal handlers are not supported by this event loop")
                break
            installed_signals.append(sig)

        if dashboard is not None:
            dashboard_task = asyncio.create_task(dashboard.run_updates())

        # Blocks until bridge.stop() is called.
        await bridge.start()
    except KeyboardInterrupt:
        logger.info("Keyboard interrupt received")
    except Exception:
        logger.exception("Unexpected error in proxy bridge")
        sys.exit(1)
    finally:
        for sig in installed_signals:
            loop.remove_signal_handler(sig)
        if dashboard_task is not None:
            dashboard_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await dashboard_task
        if dashboard is not None:
            await dashboard.stop()
        if shutdown_tasks:
            # Let any signal-triggered stop finish instead of being cancelled at loop close.
            await asyncio.gather(*shutdown_tasks, return_exceptions=True)
        await bridge.stop()
215
+
216
+
217
def main() -> None:
    """Console entry point: parse the CLI arguments and drive ``main_async``.

    A KeyboardInterrupt escaping the event loop is logged and swallowed, so the
    process exits normally.
    """
    cli_args = parse_args()
    try:
        asyncio.run(main_async(cli_args))
    except KeyboardInterrupt:
        logger.info("Exiting")


if __name__ == "__main__":
    main()
@@ -0,0 +1,245 @@
1
+ """Rich-based dashboard for websocket proxy server."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import logging
7
+ from datetime import datetime, timedelta, timezone
8
+ from typing import TYPE_CHECKING
9
+
10
+ from rich.console import Console, Group
11
+ from rich.live import Live
12
+ from rich.panel import Panel
13
+ from rich.table import Table
14
+ from rich.text import Text
15
+
16
+ if TYPE_CHECKING:
17
+ from .metrics import MetricsCollector
18
+ from .proxy import ProxyBridge
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ def _format_duration(seconds: float) -> str:
24
+ """Format duration in seconds to human-readable string."""
25
+ if seconds < 60:
26
+ return f"{seconds:.0f}s"
27
+ if seconds < 3600:
28
+ return f"{seconds // 60:.0f}m {seconds % 60:.0f}s"
29
+ hours = seconds // 3600
30
+ minutes = (seconds % 3600) // 60
31
+ return f"{hours:.0f}h {minutes:.0f}m"
32
+
33
+
34
+ def _format_rate(rate: float) -> str:
35
+ """Format rate (msgs/sec) to human-readable string."""
36
+ if rate < 0.01:
37
+ return "0.00"
38
+ if rate < 1:
39
+ return f"{rate:.2f}"
40
+ if rate < 10:
41
+ return f"{rate:.1f}"
42
+ return f"{rate:.0f}"
43
+
44
+
45
+ def _format_bandwidth(bytes_per_sec: float) -> str:
46
+ """Format bandwidth to human-readable string."""
47
+ if bytes_per_sec < 1024:
48
+ return f"{bytes_per_sec:.0f} B/s"
49
+ if bytes_per_sec < 1024 * 1024:
50
+ return f"{bytes_per_sec / 1024:.1f} KB/s"
51
+ return f"{bytes_per_sec / (1024 * 1024):.2f} MB/s"
52
+
53
+
54
+ def _format_bytes(byte_count: int) -> str:
55
+ """Format byte count to human-readable string."""
56
+ if byte_count < 1024:
57
+ return f"{byte_count} B"
58
+ if byte_count < 1024 * 1024:
59
+ return f"{byte_count / 1024:.1f} KB"
60
+ if byte_count < 1024 * 1024 * 1024:
61
+ return f"{byte_count / (1024 * 1024):.1f} MB"
62
+ return f"{byte_count / (1024 * 1024 * 1024):.2f} GB"
63
+
64
+
65
+ def _format_timestamp(dt: datetime | None) -> str:
66
+ """Format timestamp to relative time string."""
67
+ if dt is None:
68
+ return "never"
69
+
70
+ now = datetime.now(timezone.utc)
71
+ diff = now - dt
72
+
73
+ if diff < timedelta(seconds=1):
74
+ return "just now"
75
+ if diff < timedelta(seconds=60):
76
+ return f"{diff.seconds}s ago"
77
+ if diff < timedelta(minutes=60):
78
+ return f"{diff.seconds // 60}m ago"
79
+ return dt.strftime("%H:%M:%S")
80
+
81
+
82
class DashboardRenderer:
    """Renders a live dashboard for the proxy server using Rich.

    Lifecycle: call ``start_sync()`` to open the Live display, run
    ``run_updates()`` as a background task, cancel that task, then
    ``await stop()``.
    """

    def __init__(self, proxy: ProxyBridge, refresh_rate: float, console: Console) -> None:
        """Bind the dashboard to *proxy* and the metrics it exposes.

        Args:
            proxy: Bridge whose ``metrics`` attribute supplies the displayed data.
            refresh_rate: Seconds between redraws; must be > 0.
            console: Rich console the Live display renders to.

        Raises:
            ValueError: If *refresh_rate* is not a positive number.
        """
        # start_sync() divides by this value and run_updates() sleeps on it,
        # so reject 0, negatives and NaN up front.
        if not refresh_rate > 0:
            raise ValueError(f"refresh_rate must be > 0, got {refresh_rate!r}")
        self.proxy = proxy
        self.metrics: MetricsCollector = proxy.metrics
        self.refresh_rate = refresh_rate
        self.console = console
        self._live: Live | None = None

    def _create_header_panel(self) -> Panel:
        """Create the header panel with uptime, client count and upstream stats."""
        uptime = _format_duration(self.metrics.get_uptime())
        total_clients = len(self.metrics.clients)

        header_text = Text()
        header_text.append("Foxglove WebSocket Proxy Dashboard", style="bold cyan")
        header_text.append("\n\n")
        header_text.append("Uptime: ", style="bold")
        header_text.append(uptime)
        header_text.append(" | Connected Clients: ", style="bold")
        header_text.append(str(total_clients), style="green" if total_clients > 0 else "dim")

        # Upstream metrics
        header_text.append("\n\n")
        header_text.append("Upstream Status: ", style="bold")
        if self.metrics.upstream_connected:
            header_text.append("● Connected", style="green bold")
        else:
            header_text.append("● Disconnected", style="red bold")

        header_text.append(" | Topics: ", style="bold")
        header_text.append(str(self.metrics.upstream_topic_count), style="cyan")

        header_text.append(" | Transformed: ", style="bold")
        header_text.append(str(self.metrics.transformed_channel_count), style="magenta")

        header_text.append("\n")
        header_text.append("Messages Received: ", style="bold")
        header_text.append(str(self.metrics.upstream_messages_received), style="green")

        header_text.append(" | Throttled: ", style="bold")
        throttled_style = "red" if self.metrics.upstream_messages_throttled > 0 else "dim"
        header_text.append(str(self.metrics.upstream_messages_throttled), style=throttled_style)

        return Panel(header_text, border_style="blue")

    def _create_clients_table(self) -> Table:
        """Create the per-client table, oldest connection first."""
        table = Table(
            title="Connected Clients",
            title_style="bold magenta",
            show_header=True,
            header_style="bold",
            show_lines=False,
            expand=False,
        )

        table.add_column("ID", style="cyan", no_wrap=True)
        table.add_column("Remote Address", style="blue", no_wrap=True)
        table.add_column("Connected", style="green", no_wrap=True)
        table.add_column("Msg/s", justify="right", style="yellow")
        table.add_column("Bandwidth", justify="right", style="yellow")
        table.add_column("Msgs", justify="right", style="white")
        table.add_column("Bytes", justify="right", style="yellow")
        table.add_column("Subs", justify="center", style="cyan")
        table.add_column("Errors", justify="center", style="red")
        table.add_column("Last Msg", style="dim", no_wrap=True)

        sorted_clients = sorted(
            self.metrics.clients.values(),
            key=lambda c: c.connected_at,
        )

        for client in sorted_clients:
            # Errors are highlighted only when there are any.
            errors_style = "red bold" if client.errors > 0 else "dim"
            table.add_row(
                client.client_id.replace("client_", "")[:12],
                client.remote_address,
                _format_duration(client.connected_duration),
                _format_rate(client.get_message_rate()),
                _format_bandwidth(client.get_bandwidth()),
                str(client.messages_sent),
                _format_bytes(client.bytes_sent),
                str(client.subscription_count),
                Text(str(client.errors), style=errors_style),
                _format_timestamp(client.last_message_at),
            )

        if not sorted_clients:
            # One placeholder cell plus a blank for every remaining column.
            table.add_row(
                Text("No clients connected", style="dim italic"),
                *([""] * (len(table.columns) - 1)),
            )

        return table

    def _create_layout(self) -> Panel:
        """Create the full dashboard: header panel, spacer, clients table."""
        layout = Group(
            self._create_header_panel(),
            "",  # Spacer
            self._create_clients_table(),
        )
        return Panel(layout, border_style="bright_blue", padding=(1, 2))

    def start_sync(self) -> None:
        """Start the Live display (synchronous).

        Calling this again first stops the previously started display, so it
        is not leaked.
        """
        if self._live is not None:
            self._live.stop()
        self._live = Live(
            self._create_layout(),
            console=self.console,
            refresh_per_second=1 / self.refresh_rate,
            screen=False,
            auto_refresh=True,  # Enable auto-refresh so updates are visible
        )
        self._live.start()

    async def run_updates(self) -> None:
        """Rebuild the dashboard every ``refresh_rate`` seconds (call after start_sync).

        Returns immediately if the display was never started, and exits once
        ``stop()`` clears it. Cancellation propagates to the caller, so the task
        ends in the cancelled state. Rendering errors are logged at DEBUG and
        the loop continues.
        """
        while self._live is not None:
            try:
                self._live.update(self._create_layout())
            except (RuntimeError, ValueError, KeyError, IndexError) as e:
                logger.debug("Dashboard rendering error: %s", e)
            await asyncio.sleep(self.refresh_rate)

    async def stop(self) -> None:
        """Stop the Live display; safe to call more than once."""
        if self._live:
            self._live.stop()
            self._live = None
@@ -0,0 +1,119 @@
1
+ """Metrics collection for websocket proxy server."""
2
+
3
+ import time
4
+ from collections import deque
5
+ from dataclasses import dataclass, field
6
+ from datetime import datetime, timezone
7
+
8
+
9
+ @dataclass
10
+ class ClientMetrics:
11
+ """Metrics for a single client connection."""
12
+
13
+ client_id: str
14
+ remote_address: str
15
+ connected_at: datetime
16
+ user_agent: str | None = None
17
+
18
+ # Traffic metrics
19
+ messages_sent: int = 0
20
+ bytes_sent: int = 0
21
+ errors: int = 0
22
+ last_message_at: datetime | None = None
23
+
24
+ # Subscriptions
25
+ subscription_count: int = 0
26
+ subscribed_topics: set[str] = field(default_factory=set)
27
+
28
+ # Rate tracking (samples stored for windowed calculations)
29
+ _message_samples: deque[tuple[float, int]] = field(default_factory=lambda: deque(maxlen=60))
30
+ _byte_samples: deque[tuple[float, int]] = field(default_factory=lambda: deque(maxlen=60))
31
+
32
+ def record_message(self, byte_count: int) -> None:
33
+ """Record a message sent to this client."""
34
+ now = time.time()
35
+ self.messages_sent += 1
36
+ self.bytes_sent += byte_count
37
+ self.last_message_at = datetime.now(timezone.utc)
38
+ self._message_samples.append((now, 1))
39
+ self._byte_samples.append((now, byte_count))
40
+
41
+ def record_error(self) -> None:
42
+ """Record an error for this client."""
43
+ self.errors += 1
44
+
45
+ def get_message_rate(self, window_seconds: float = 5.0) -> float:
46
+ """Calculate messages per second over the last N seconds."""
47
+ return self._calculate_rate(self._message_samples, window_seconds)
48
+
49
+ def get_bandwidth(self, window_seconds: float = 5.0) -> float:
50
+ """Calculate bytes per second over the last N seconds."""
51
+ return self._calculate_rate(self._byte_samples, window_seconds)
52
+
53
+ @staticmethod
54
+ def _calculate_rate(samples: deque[tuple[float, int]], window_seconds: float) -> float:
55
+ """Calculate rate from time-stamped samples within window."""
56
+ if not samples:
57
+ return 0.0
58
+
59
+ now = time.time()
60
+ cutoff = now - window_seconds
61
+ total = sum(count for ts, count in samples if ts >= cutoff)
62
+
63
+ # Find actual time span of samples in window
64
+ valid_samples = [(ts, count) for ts, count in samples if ts >= cutoff]
65
+ if not valid_samples:
66
+ return 0.0
67
+
68
+ oldest_ts = valid_samples[0][0]
69
+ time_span = now - oldest_ts
70
+
71
+ if time_span <= 0:
72
+ return 0.0
73
+
74
+ return total / time_span
75
+
76
+ @property
77
+ def connected_duration(self) -> float:
78
+ """Get connection duration in seconds."""
79
+ return (datetime.now(timezone.utc) - self.connected_at).total_seconds()
80
+
81
+
82
class MetricsCollector:
    """Hold per-client metrics plus upstream counters for the whole proxy process."""

    def __init__(self) -> None:
        self._start_time = datetime.now(timezone.utc)
        # Tracked downstream clients, keyed by client id.
        self.clients: dict[str, ClientMetrics] = {}

        # Upstream connection state and counters.
        self.upstream_connected: bool = False
        self.upstream_topic_count: int = 0
        self.upstream_messages_received: int = 0
        self.upstream_messages_throttled: int = 0
        self.transformed_channel_count: int = 0

    def add_client(
        self, client_id: str, remote_address: str, user_agent: str | None = None
    ) -> ClientMetrics:
        """Start tracking a client (replacing any entry with the same id) and return it."""
        entry = ClientMetrics(
            client_id=client_id,
            connected_at=datetime.now(timezone.utc),
            remote_address=remote_address,
            user_agent=user_agent,
        )
        self.clients[client_id] = entry
        return entry

    def remove_client(self, client_id: str) -> None:
        """Stop tracking *client_id*; unknown ids are ignored."""
        if client_id in self.clients:
            del self.clients[client_id]

    def get_client(self, client_id: str) -> ClientMetrics | None:
        """Return the metrics for *client_id*, or None if it is not tracked."""
        return self.clients.get(client_id)

    def get_uptime(self) -> float:
        """Return seconds elapsed since this collector was created."""
        elapsed = datetime.now(timezone.utc) - self._start_time
        return elapsed.total_seconds()