phoenix-channels-python-client 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,900 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ import signal
6
+ from asyncio import AbstractEventLoop, Queue
7
+ from types import TracebackType
8
+ from typing import Callable, Type, Awaitable, Dict, Any
9
+
10
+ from websockets import connect
11
+
12
+ from phoenix_channels_python_client.exceptions import PHXConnectionError, PHXTopicError
13
+ from phoenix_channels_python_client.phx_messages import (
14
+ ChannelMessage,
15
+ ChannelEvent,
16
+ Event,
17
+ PHXEvent,
18
+ PHXEventMessage,
19
+ )
20
+ from phoenix_channels_python_client.protocol_handler import (
21
+ PHXProtocolHandler,
22
+ PhoenixChannelsProtocolVersion,
23
+ )
24
+ from phoenix_channels_python_client.topic_subscription import (
25
+ TopicSubscription,
26
+ TopicProcessingState,
27
+ )
28
+ from phoenix_channels_python_client.utils import make_message
29
+
30
+
31
from typing import Optional

# Seconds between "heartbeat" messages sent on the "phoenix" topic.
DEFAULT_HEARTBEAT_INTERVAL_SECS = 30

# Reconnection defaults; a max of 0 attempts means "retry forever".
DEFAULT_RECONNECT_MAX_ATTEMPTS = 10
DEFAULT_RECONNECT_BACKOFF_BASE = 1.0
DEFAULT_RECONNECT_BACKOFF_MAX = 30.0

# Type aliases for reconnection callbacks.
# These are ordinary assignments evaluated at import time (not annotations),
# so ``from __future__ import annotations`` does not defer them. ``Optional``
# is used instead of ``Exception | None`` so the module still imports on
# Python versions where PEP 604 unions are not supported at runtime (< 3.10).
ReconnectCallback = Callable[[], Awaitable[None]]
DisconnectCallback = Callable[[Optional[Exception]], Awaitable[None]]
40
+
41
+
42
class PHXChannelsClient:
    """
    Async Python client for Phoenix Channels WebSocket connections.

    Security Note:
        This client passes the API key as a URL query parameter during the WebSocket
        handshake. Even when ``websocket_url`` uses ``wss://`` (encrypted), the API key
        may appear in server access logs, proxy logs, or network monitoring tools.
        Ensure your infrastructure does not log full URLs in production environments.
        Connection errors raised by this client name only ``websocket_url``, never
        the API key.

        The official Phoenix JS client (v1.8+) supports header-based authentication
        via the `authToken` option, which avoids this issue. This client currently
        uses the older `params` style for compatibility.

    Args:
        websocket_url: The WebSocket URL to connect to. It may already carry a
            query string; ``api_key`` and ``vsn`` are appended URL-encoded.
        api_key: The API key for authentication.
        event_loop: Optional event loop to use. Defaults to the loop running when
            the client is created or, if none is running then, the loop running
            when the client first needs one (e.g. when connecting).
        protocol_version: Phoenix Channels protocol version (default: V2).
        heartbeat_interval_secs: Interval between heartbeat messages in seconds.
            Set to None to disable heartbeat. Default is 30 seconds, matching
            the Phoenix JS client.
        auto_reconnect: Whether to automatically reconnect on connection loss.
            Default is True.
        reconnect_max_attempts: Maximum number of reconnection attempts before
            giving up. Default is 10. Set to 0 for unlimited attempts.
        reconnect_backoff_base: Base delay in seconds for exponential backoff.
            Default is 1.0 second.
        reconnect_backoff_max: Maximum delay in seconds between reconnection
            attempts. Default is 30 seconds.
        on_reconnect: Optional async callback called after successful reconnection.
        on_disconnect: Optional async callback called when disconnection is detected.
            Receives the exception that caused the disconnect (if any).
    """

    def __init__(
        self,
        websocket_url: str,
        api_key: str,
        event_loop: AbstractEventLoop | None = None,
        protocol_version: PhoenixChannelsProtocolVersion = PhoenixChannelsProtocolVersion.V2,
        heartbeat_interval_secs: float | None = DEFAULT_HEARTBEAT_INTERVAL_SECS,
        auto_reconnect: bool = True,
        reconnect_max_attempts: int = DEFAULT_RECONNECT_MAX_ATTEMPTS,
        reconnect_backoff_base: float = DEFAULT_RECONNECT_BACKOFF_BASE,
        reconnect_backoff_max: float = DEFAULT_RECONNECT_BACKOFF_MAX,
        on_reconnect: ReconnectCallback | None = None,
        on_disconnect: DisconnectCallback | None = None,
    ):
        self.logger = logging.getLogger(__name__)

        vsn = (
            "2.0.0"
            if protocol_version == PhoenixChannelsProtocolVersion.V2
            else "1.0.0"
        )
        # Kept separately so error messages can name the server without
        # exposing the API key embedded in channel_socket_url.
        self._websocket_url = websocket_url
        self.channel_socket_url = self._build_socket_url(websocket_url, api_key, vsn)

        self.connection = None
        self._topic_subscriptions: dict[str, TopicSubscription] = {}
        # Use the caller's loop, else the running one. When the client is built
        # outside any running loop, binding is deferred to _get_loop() so tasks
        # are scheduled on the loop that actually drives the client.
        self._loop: AbstractEventLoop | None = event_loop
        if self._loop is None:
            try:
                self._loop = asyncio.get_running_loop()
            except RuntimeError:
                pass
        self._message_routing_task: asyncio.Task | None = None
        self._protocol_handler = PHXProtocolHandler(protocol_version)
        self._ref_counter = 0

        # Heartbeat configuration
        self._heartbeat_interval_secs = heartbeat_interval_secs
        self._heartbeat_task: asyncio.Task | None = None
        self._pending_heartbeat_ref: str | None = None

        # Reconnection configuration
        self._auto_reconnect = auto_reconnect
        self._reconnect_max_attempts = reconnect_max_attempts
        self._reconnect_backoff_base = reconnect_backoff_base
        self._reconnect_backoff_max = reconnect_backoff_max
        self._on_reconnect = on_reconnect
        self._on_disconnect = on_disconnect

        # Reconnection state
        self._reconnect_task: asyncio.Task | None = None
        self._reconnect_attempt = 0
        self._is_reconnecting = False
        self._shutdown_requested = False

        # Subscription info replayed after reconnection (topic -> callback,
        # topic -> {event: handler}).
        self._subscription_callbacks: Dict[
            str, Callable[[ChannelMessage], Awaitable[None]] | None
        ] = {}
        self._subscription_event_handlers: Dict[
            str, Dict[ChannelEvent, Callable[[Dict[str, Any]], Awaitable[None]]]
        ] = {}

    @staticmethod
    def _build_socket_url(websocket_url: str, api_key: str, vsn: str) -> str:
        """Return ``websocket_url`` with URL-encoded ``api_key`` and ``vsn`` appended.

        Encoding matters for keys containing ``+``, ``&``, ``=`` and similar
        characters; an existing query string on ``websocket_url`` is extended
        with ``&`` instead of starting a second ``?``.
        """
        from urllib.parse import urlencode

        query = urlencode({"api_key": api_key, "vsn": vsn})
        if "?" not in websocket_url:
            separator = "?"
        elif websocket_url.endswith(("?", "&")):
            separator = ""
        else:
            separator = "&"
        return f"{websocket_url}{separator}{query}"

    def _get_loop(self) -> AbstractEventLoop:
        """Return the loop used for this client's tasks and futures.

        Binds to the currently running loop on first use when no loop was
        available at construction time; must therefore be called from code
        running inside the event loop.
        """
        if self._loop is None:
            self._loop = asyncio.get_running_loop()
        return self._loop

    async def __aenter__(self) -> "PHXChannelsClient":
        """Connect to the server and return the client."""
        await self._connect()
        return self

    async def _connect(self) -> None:
        """
        Establish the WebSocket connection and start the background tasks.

        Used by both the initial connection and reconnection. On success the
        message-routing task (and the heartbeat task, when enabled) is started
        and the reconnection counters are reset.

        Raises:
            PHXConnectionError: If the WebSocket connection cannot be opened.
        """
        loop = self._get_loop()
        try:
            self.connection = await connect(self.channel_socket_url)
        except Exception as e:
            self.logger.error("Failed to connect to Phoenix WebSocket server: %s", e)
            # Only the base URL is named: channel_socket_url embeds the API key.
            raise PHXConnectionError(
                f"Failed to connect to {self._websocket_url}: {e}"
            ) from e

        self.logger.info("Connected to Phoenix WebSocket server")
        self._message_routing_task = loop.create_task(
            self._start_processing_with_reconnect()
        )
        if self._heartbeat_interval_secs is not None:
            self._heartbeat_task = loop.create_task(self._heartbeat_loop())
            self.logger.debug(
                "Heartbeat enabled with interval of %s seconds",
                self._heartbeat_interval_secs,
            )
        # Reset reconnection state on successful connection
        self._reconnect_attempt = 0
        self._is_reconnecting = False

    async def _start_processing_with_reconnect(self) -> None:
        """
        Route incoming messages until the connection ends, then handle the loss.

        Disconnection is handled exactly once per connection, whether the
        socket failed with an error or closed cleanly. Cancellation (from
        shutdown() or event-loop teardown) is propagated without reconnecting.
        """
        error: Exception | None = None
        try:
            await self._start_processing()
        except asyncio.CancelledError:
            raise
        except Exception as e:
            error = e

        if self._shutdown_requested:
            return
        if error is not None:
            self.logger.warning("Connection lost: %s", error)
        else:
            self.logger.warning("Connection closed unexpectedly")
        await self._handle_disconnection(error)

    async def _handle_disconnection(self, error: Exception | None) -> None:
        """
        Notify the disconnect callback, release the connection and, when
        enabled and not shutting down, start the reconnection loop.

        Args:
            error: The exception that ended the connection, or None for a clean close.
        """
        if self._on_disconnect:
            try:
                await self._on_disconnect(error)
            except Exception as cb_error:
                self.logger.error("Error in on_disconnect callback: %s", cb_error)

        await self._cleanup_connection()

        if self._auto_reconnect and not self._shutdown_requested:
            self._reconnect_task = self._get_loop().create_task(self._reconnect_loop())

    async def _cancel_heartbeat(self) -> None:
        """Cancel the heartbeat task (if running) and forget any pending heartbeat ref."""
        task, self._heartbeat_task = self._heartbeat_task, None
        if task is not None and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        self._pending_heartbeat_ref = None

    async def _cleanup_connection(self) -> None:
        """
        Release connection-related resources without a full shutdown.

        Stops the heartbeat and closes the socket on a best-effort basis;
        topic subscriptions and stored reconnection info are left intact.
        """
        await self._cancel_heartbeat()

        connection, self.connection = self.connection, None
        if connection is not None:
            try:
                await connection.close()
            except Exception as e:
                # The socket is usually already broken here; nothing to recover.
                self.logger.debug("Ignoring error while closing connection: %s", e)

    def _compute_backoff_delay(self, attempt: int) -> float:
        """
        Return the delay before reconnection attempt ``attempt`` (1-based).

        The delay is ``reconnect_backoff_base * 2 ** (attempt - 1)`` capped at
        ``reconnect_backoff_max``. With unlimited attempts ``2 ** n`` eventually
        no longer fits in a float; that case is treated as "past the cap"
        instead of raising OverflowError and killing the reconnect task.
        """
        base = self._reconnect_backoff_base
        try:
            delay = base * (2 ** (attempt - 1))
        except OverflowError:
            delay = base * float("inf") if base else 0.0
        return min(delay, self._reconnect_backoff_max)

    async def _reconnect_loop(self) -> None:
        """
        Attempt to reconnect with exponential backoff.

        Stops after ``reconnect_max_attempts`` failed attempts (0 = never), when
        shutdown is requested, or after the first successful reconnection, at
        which point topics are re-subscribed and ``on_reconnect`` is awaited.
        """
        self._is_reconnecting = True
        try:
            while not self._shutdown_requested:
                self._reconnect_attempt += 1

                # Check max attempts (0 = unlimited)
                if 0 < self._reconnect_max_attempts < self._reconnect_attempt:
                    self.logger.error(
                        "Max reconnection attempts (%d) reached. Giving up.",
                        self._reconnect_max_attempts,
                    )
                    return

                delay = self._compute_backoff_delay(self._reconnect_attempt)
                self.logger.info(
                    "Reconnection attempt %d/%s in %.1f seconds...",
                    self._reconnect_attempt,
                    self._reconnect_max_attempts
                    if self._reconnect_max_attempts > 0
                    else "∞",
                    delay,
                )

                await asyncio.sleep(delay)

                if self._shutdown_requested:
                    break

                try:
                    await self._connect()
                except Exception as e:
                    self.logger.warning(
                        "Reconnection attempt %d failed: %s", self._reconnect_attempt, e
                    )
                    continue

                self.logger.info("Reconnected successfully!")

                # A live connection exists now: never loop back into _connect()
                # from here, or a second connection would be opened.
                try:
                    await self._resubscribe_topics()
                except Exception as e:
                    self.logger.error(
                        "Failed to re-subscribe topics after reconnection: %s", e
                    )

                if self._on_reconnect:
                    try:
                        await self._on_reconnect()
                    except Exception as cb_error:
                        self.logger.error(
                            "Error in on_reconnect callback: %s", cb_error
                        )
                return
        finally:
            self._is_reconnecting = False

    async def _resubscribe_topics(self) -> None:
        """
        Re-subscribe to every stored topic after reconnection and restore its
        event handlers. Failures are logged per topic and do not stop the rest.
        """
        topics_to_resubscribe = list(self._subscription_callbacks)

        if not topics_to_resubscribe:
            return

        self.logger.info("Re-subscribing to %d topic(s)...", len(topics_to_resubscribe))

        # Snapshot first: subscribe/add_event_handler below write to the same dicts.
        subscriptions_snapshot = {
            topic: (
                self._subscription_callbacks.get(topic),
                self._subscription_event_handlers.get(topic, {}).copy(),
            )
            for topic in topics_to_resubscribe
        }

        for topic, (callback, event_handlers) in subscriptions_snapshot.items():
            try:
                if topic in self._topic_subscriptions:
                    self._unregister_topic(topic)

                await self.subscribe_to_topic(topic, callback)

                for event, handler in event_handlers.items():
                    self.add_event_handler(topic, event, handler)

                self.logger.info("Re-subscribed to topic: %s", topic)

            except Exception as e:
                self.logger.error("Failed to re-subscribe to topic %s: %s", topic, e)

    async def __aexit__(
        self,
        exc_type: Type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: TracebackType | None = None,
    ) -> None:
        """Shut the client down when leaving the ``async with`` block."""
        await self.shutdown("Client context exiting")

    async def shutdown(
        self,
        reason: str,
    ) -> None:
        """
        Gracefully shutdown the client connection.

        This method will:
        1. Stop reconnection attempts (if any)
        2. Unsubscribe from all topics (with 5 second timeout)
        3. Cancel the heartbeat and message routing tasks
        4. Close the WebSocket connection (best effort; errors are logged)

        Args:
            reason: Human-readable reason for shutdown (for logging)

        Note: This method is automatically called by __aexit__ when using
        the async context manager. You can also call it explicitly.
        """
        self.logger.info("Shutting down client: %s", reason)

        # Set first: every reconnection path checks this flag.
        self._shutdown_requested = True

        reconnect_task, self._reconnect_task = self._reconnect_task, None
        if reconnect_task is not None and not reconnect_task.done():
            reconnect_task.cancel()
            try:
                await reconnect_task
            except asyncio.CancelledError:
                pass

        # Stored subscription info is only needed for reconnection.
        self._subscription_callbacks.clear()
        self._subscription_event_handlers.clear()

        await self._leave_all_topics()
        await self._cancel_heartbeat()

        if self._message_routing_task and not self._message_routing_task.done():
            self._message_routing_task.cancel()
            try:
                await self._message_routing_task
            except asyncio.CancelledError:
                pass

        connection, self.connection = self.connection, None
        if connection is not None:
            try:
                await connection.close()
            except Exception as e:
                # Best effort: raising here would mask the caller's own exception
                # when shutdown runs from __aexit__.
                self.logger.warning("Error while closing connection: %s", e)
            else:
                self.logger.info("Connection closed")

    async def _leave_all_topics(self) -> None:
        """
        Unsubscribe from every registered topic, waiting at most 5 seconds.

        Every topic is unregistered locally afterwards, whether the server
        acknowledged the leave, rejected it, or the wait timed out. Without a
        connection no leave messages are sent and topics are just dropped.
        """
        topics_to_unsubscribe = list(self._topic_subscriptions)
        if not topics_to_unsubscribe:
            return

        if self.connection is None:
            self.logger.info(
                "Not connected; dropping %d topic subscription(s) without leaving",
                len(topics_to_unsubscribe),
            )
        else:
            self.logger.info(
                "Unsubscribing from %d topic(s)", len(topics_to_unsubscribe)
            )
            try:
                results = await asyncio.wait_for(
                    asyncio.gather(
                        *(self.unsubscribe_from_topic(t) for t in topics_to_unsubscribe),
                        return_exceptions=True,
                    ),
                    timeout=5.0,
                )
            except asyncio.TimeoutError:
                self.logger.warning("Unsubscribe timed out after 5s, forcing cleanup")
            else:
                for topic, result in zip(topics_to_unsubscribe, results):
                    if isinstance(result, Exception):
                        self.logger.warning(
                            "Failed to unsubscribe from topic %s: %s", topic, result
                        )

        for topic in topics_to_unsubscribe:
            self._unregister_topic(topic)

    async def _heartbeat_loop(self) -> None:
        """
        Send periodic heartbeat messages on the "phoenix" topic.

        Phoenix servers expect heartbeat messages on the "phoenix" topic at regular
        intervals (default 30 seconds) and may close connections that stop sending
        them. Each heartbeat's ref is kept in ``_pending_heartbeat_ref`` until
        _handle_heartbeat_response() sees the matching reply.

        The loop exits when there is no connection or a send fails, and re-raises
        cancellation. It is only started when heartbeat_interval_secs is not None.
        """
        # Assert for type checker - this task is only created when interval is not None
        assert self._heartbeat_interval_secs is not None
        interval = self._heartbeat_interval_secs

        while True:
            try:
                await asyncio.sleep(interval)

                if self.connection is None:
                    self.logger.debug("Heartbeat loop stopping: no connection")
                    break

                # A ref still pending here means the previous heartbeat got no
                # reply within one interval. It is logged and forgotten, and a
                # fresh heartbeat is sent below; the connection is left open.
                # A late reply to the forgotten ref will no longer match.
                if self._pending_heartbeat_ref is not None:
                    self.logger.warning(
                        "Heartbeat timeout: no response to heartbeat ref=%s",
                        self._pending_heartbeat_ref,
                    )
                    self._pending_heartbeat_ref = None

                # Local copy: the reply may clear the attribute during the send.
                ref = self._generate_ref()
                self._pending_heartbeat_ref = ref
                heartbeat_message = make_message(
                    event=Event("heartbeat"),
                    topic="phoenix",
                    ref=ref,
                    payload={},
                )

                await self._protocol_handler.send_message(
                    self.connection, heartbeat_message
                )
                self.logger.debug("Sent heartbeat ref=%s", ref)

            except asyncio.CancelledError:
                self.logger.debug("Heartbeat loop cancelled")
                raise
            except Exception as e:
                self.logger.warning("Heartbeat failed: %s", e)
                # Connection might be dead, exit the loop
                break

    def _handle_heartbeat_response(self, message: ChannelMessage) -> bool:
        """
        Handle a heartbeat response from the server.

        Synchronous on purpose: it only compares attributes and performs no I/O.

        Args:
            message: The received message

        Returns:
            True if this was the reply to the pending heartbeat (the pending ref
            is then cleared), False otherwise.
        """
        if (
            message.topic == "phoenix"
            and message.ref is not None
            and message.ref == self._pending_heartbeat_ref
        ):
            self._pending_heartbeat_ref = None
            self.logger.debug("Heartbeat acknowledged ref=%s", message.ref)
            return True
        return False

    def _set_subscription_ready(self, topic_subscription: TopicSubscription) -> None:
        """Resolve the topic's join future with success, unless already resolved."""
        if not topic_subscription.subscription_ready.done():
            topic_subscription.subscription_ready.set_result(None)

    def _set_subscription_error(
        self, topic_subscription: TopicSubscription, error: Exception
    ) -> None:
        """Fail the topic's join future with ``error``, unless already resolved."""
        if not topic_subscription.subscription_ready.done():
            topic_subscription.subscription_ready.set_exception(error)

    def _determine_processing_state(
        self, topic: TopicSubscription
    ) -> TopicProcessingState:
        """Pick how the next queued message for ``topic`` should be handled.

        Join not yet resolved -> WAITING_FOR_JOIN; leave requested ->
        PROCESSING_LEAVE; otherwise NORMAL_PROCESSING.
        """
        subscription_ready = topic.subscription_ready.done()
        leave_requested = topic.leave_requested.is_set()

        if not subscription_ready:
            return TopicProcessingState.WAITING_FOR_JOIN
        elif leave_requested:
            return TopicProcessingState.PROCESSING_LEAVE
        else:
            return TopicProcessingState.NORMAL_PROCESSING

    async def _process_topic_messages(self, topic_name: str) -> None:
        """
        Consume the topic's queue and dispatch each message by processing state.

        Ends when a leave is rejected (PHXTopicError in leave mode). Any other
        error fails a still-pending join (so subscribe_to_topic() does not wait
        forever) and unregisters the topic.
        """
        topic = self._topic_subscriptions[topic_name]

        try:
            while True:
                message = await topic.queue.get()

                current_state = self._determine_processing_state(topic)

                if current_state == TopicProcessingState.WAITING_FOR_JOIN:
                    await self._handle_join_response_mode(topic, message)

                elif current_state == TopicProcessingState.PROCESSING_LEAVE:
                    try:
                        await self._handle_leave_mode(topic, message)
                    except PHXTopicError:
                        break

                elif current_state == TopicProcessingState.NORMAL_PROCESSING:
                    await self._handle_normal_message_mode(topic, message)

        except Exception as e:
            self.logger.error("Error in topic processor for %s: %s", topic.name, e)
            self._set_subscription_error(topic, e)
            self._unregister_topic(topic.name)

    async def _handle_join_response_mode(
        self, topic: TopicSubscription, message: ChannelMessage
    ) -> None:
        """
        Resolve the topic's join from the first message received after joining.

        A reply with status "ok" marks the subscription ready. Anything else
        fails the join future and raises PHXTopicError.
        """
        if not isinstance(message, PHXEventMessage) or message.event != PHXEvent.reply:
            error = PHXTopicError(
                f"Unexpected message type in join response mode: {message}"
            )
            # Fail the join so the awaiting subscribe_to_topic() call wakes up.
            self._set_subscription_error(topic, error)
            raise error

        if message.payload.get("status") == "ok":
            self._set_subscription_ready(topic)
            self.logger.info("Subscribed to topic: %s", topic.name)
        else:
            response = message.payload.get("response", {})
            error_message = (
                response.get("reason", "invalid topic")
                if isinstance(response, dict)
                else "invalid topic"
            )

            error = PHXTopicError(error_message)
            self._set_subscription_error(topic, error)
            self.logger.error(
                "Failed to subscribe to topic %s: %s", topic.name, error_message
            )
            raise error

    def _capture_handlers_atomically(
        self, topic: TopicSubscription, message: ChannelMessage
    ) -> tuple:
        """Return (message handler, event handler) for ``message``, read together."""
        message_handler = topic.async_callback
        event_handler = topic.get_event_handler(message.event)
        return message_handler, event_handler

    async def _handle_normal_message_mode(
        self, topic: TopicSubscription, message: ChannelMessage
    ) -> None:
        """
        Deliver a message to the topic's message handler, then to the handler
        registered for its event, if any. Handler errors are logged, not raised.
        """
        message_handler, event_handler = self._capture_handlers_atomically(
            topic, message
        )

        try:
            has_message_handler = message_handler is not None
            has_specific_handler = event_handler is not None

            # Each callback runs as a task tracked in current_callback_task so
            # leave handling can wait for an in-flight callback to finish.
            if has_message_handler:
                topic.current_callback_task = asyncio.create_task(
                    message_handler(message)
                )
                await topic.current_callback_task
                topic.current_callback_task = None

            if has_specific_handler:
                topic.current_callback_task = asyncio.create_task(
                    event_handler(message.payload)
                )
                await topic.current_callback_task

            if not has_message_handler and not has_specific_handler:
                self.logger.warning(
                    "No handler for event %s on topic %s", message.event, topic.name
                )

        except Exception as e:
            # logger.exception keeps the user callback's traceback in the log.
            self.logger.exception("Error in topic callback for %s: %s", topic.name, e)
        finally:
            topic.current_callback_task = None

    async def _handle_leave_mode(
        self, topic: TopicSubscription, message: ChannelMessage
    ) -> None:
        """
        Resolve a pending unsubscribe from a reply message.

        Non-reply messages are dropped. A reply with status "ok" completes the
        unsubscribe future (after any in-flight callback finishes); any other
        status fails it and raises PHXTopicError.
        """
        if not isinstance(message, PHXEventMessage) or message.event != PHXEvent.reply:
            return

        is_leave_success = message.payload.get("status") == "ok"

        if is_leave_success:
            self.logger.info("Unsubscribed from topic: %s", topic.name)

            if topic.current_callback_task and not topic.current_callback_task.done():
                try:
                    await topic.current_callback_task
                except Exception as e:
                    self.logger.error(
                        "Error waiting for callback to finish for %s: %s", topic.name, e
                    )

            if topic.unsubscribe_completed and not topic.unsubscribe_completed.done():
                topic.unsubscribe_completed.set_result(None)

        else:
            self.logger.error(
                "Failed to unsubscribe from topic %s: %s", topic.name, message.payload
            )
            if topic.unsubscribe_completed and not topic.unsubscribe_completed.done():
                topic.unsubscribe_completed.set_exception(
                    PHXTopicError(f"Failed to unsubscribe: {message.payload}")
                )
            raise PHXTopicError(f"Failed to unsubscribe: {message.payload}")

    def _unregister_topic(self, topic_name: str) -> None:
        """Cancel the topic's processor task and forget the topic (no-op if unknown)."""
        if topic_name in self._topic_subscriptions:
            topic_subscription = self._topic_subscriptions[topic_name]
            if topic_subscription.process_topic_messages_task:
                topic_subscription.process_topic_messages_task.cancel()
            del self._topic_subscriptions[topic_name]

    def get_current_subscriptions(self) -> dict[str, TopicSubscription]:
        """Return a shallow copy of the registered topic subscriptions."""
        return self._topic_subscriptions.copy()

    def get_protocol_handler(self) -> PHXProtocolHandler:
        """Return the protocol handler used to encode and route messages."""
        return self._protocol_handler

    def _generate_ref(self) -> str:
        """Return the next message ref as a string ("1", "2", ...)."""
        self._ref_counter += 1
        return str(self._ref_counter)

    async def subscribe_to_topic(
        self,
        topic: str,
        async_callback: Callable[[ChannelMessage], Awaitable[None]] | None = None,
    ) -> None:
        """
        Join ``topic`` and wait for the server's join reply.

        On success the topic (and its current message handler) is remembered
        for re-subscription after reconnection. On any failure or cancellation
        the topic is unregistered again, so a later retry is possible.

        Raises:
            PHXTopicError: If the topic is already subscribed or the join is rejected.
            PHXConnectionError: If the client is not connected.
        """
        if topic in self._topic_subscriptions:
            raise PHXTopicError(f"Topic {topic} already subscribed")
        # Check before registering anything so this failure leaves no state behind.
        connection = self.connection
        if connection is None:
            raise PHXConnectionError("Not connected to server")

        loop = self._get_loop()
        subscription_ready_future = loop.create_future()
        join_ref = self._generate_ref()

        topic_subscription = TopicSubscription(
            name=topic,
            async_callback=async_callback,
            queue=Queue(),
            subscription_ready=subscription_ready_future,
            join_ref=join_ref,
            process_topic_messages_task=loop.create_task(
                self._process_topic_messages(topic)
            ),
        )

        # Registered before the join is sent so the routing task can deliver
        # the join reply to this topic's queue.
        self._topic_subscriptions[topic] = topic_subscription
        topic_join_message = make_message(
            event=PHXEvent.join, topic=topic, ref=join_ref, join_ref=join_ref
        )

        try:
            await self._protocol_handler.send_message(connection, topic_join_message)
            await subscription_ready_future
        except asyncio.CancelledError:
            self._unregister_topic(topic)
            raise
        except Exception as e:
            self.logger.error("Failed to subscribe to %s: %s", topic, e)
            self._unregister_topic(topic)
            raise

        # Store the handler currently in effect (it may have been changed via
        # set_message_handler while the join was pending) for reconnection.
        self._subscription_callbacks[topic] = topic_subscription.async_callback
        self._subscription_event_handlers.setdefault(topic, {})

    async def unsubscribe_from_topic(self, topic: str) -> None:
        """
        Leave ``topic`` and wait for the server's reply.

        The topic and its stored reconnection info are removed once the wait
        ends, whether the leave succeeded or failed.

        Raises:
            PHXTopicError: If the topic is not subscribed or the leave is rejected.
            PHXConnectionError: If the client is not connected.
        """
        if topic not in self._topic_subscriptions:
            raise PHXTopicError(f"Topic {topic} not subscribed")
        connection = self.connection
        if connection is None:
            raise PHXConnectionError("Not connected to server")

        topic_subscription = self._topic_subscriptions[topic]

        unsubscribe_completed_future = self._get_loop().create_future()
        topic_subscription.unsubscribe_completed = unsubscribe_completed_future

        leave_ref = self._generate_ref()
        topic_leave_message = make_message(
            event=PHXEvent.leave,
            topic=topic,
            ref=leave_ref,
            join_ref=topic_subscription.join_ref,
        )

        # Switch the topic processor to leave mode *before* sending: the
        # reply may be routed while the send is still being awaited, and in
        # normal mode it would go to the user callback and never resolve
        # unsubscribe_completed_future.
        topic_subscription.leave_requested.set()
        try:
            await self._protocol_handler.send_message(connection, topic_leave_message)
        except BaseException:
            topic_subscription.leave_requested.clear()
            raise

        try:
            await unsubscribe_completed_future
        except asyncio.CancelledError:
            raise
        except Exception as e:
            self.logger.error("Error unsubscribing from %s: %s", topic, e)
            raise
        finally:
            self._unregister_topic(topic)
            # Remove stored subscription info
            self._subscription_callbacks.pop(topic, None)
            self._subscription_event_handlers.pop(topic, None)

    def add_event_handler(
        self,
        topic: str,
        event: ChannelEvent,
        handler: Callable[[Dict[str, Any]], Awaitable[None]],
    ) -> None:
        """Add or update an event handler for a specific event type on a topic."""
        if topic not in self._topic_subscriptions:
            raise PHXTopicError(f"Topic {topic} not subscribed")

        topic_subscription = self._topic_subscriptions[topic]
        topic_subscription.add_event_handler(event, handler)

        # Store for reconnection
        self._subscription_event_handlers.setdefault(topic, {})[event] = handler

    def remove_event_handler(self, topic: str, event: ChannelEvent) -> None:
        """Remove an event handler for a specific event type on a topic."""
        if topic not in self._topic_subscriptions:
            raise PHXTopicError(f"Topic {topic} not subscribed")

        topic_subscription = self._topic_subscriptions[topic]
        topic_subscription.remove_event_handler(event)

        # Remove from stored handlers
        if topic in self._subscription_event_handlers:
            self._subscription_event_handlers[topic].pop(event, None)

    def get_event_handler(
        self, topic: str, event: ChannelEvent
    ) -> Callable[[Dict[str, Any]], Awaitable[None]] | None:
        """Get the handler for a specific event type on a topic."""
        if topic not in self._topic_subscriptions:
            raise PHXTopicError(f"Topic {topic} not subscribed")

        topic_subscription = self._topic_subscriptions[topic]
        return topic_subscription.get_event_handler(event)

    def has_event_handler(self, topic: str, event: ChannelEvent) -> bool:
        """Check if a handler exists for a specific event type on a topic."""
        if topic not in self._topic_subscriptions:
            return False

        topic_subscription = self._topic_subscriptions[topic]
        return topic_subscription.has_event_handler(event)

    def list_event_handlers(
        self, topic: str
    ) -> Dict[ChannelEvent, Callable[[Dict[str, Any]], Awaitable[None]]]:
        """List all event handlers for a topic."""
        if topic not in self._topic_subscriptions:
            raise PHXTopicError(f"Topic {topic} not subscribed")

        topic_subscription = self._topic_subscriptions[topic]
        return topic_subscription.event_handlers.copy()

    def set_message_handler(
        self, topic: str, handler: Callable[[ChannelMessage], Awaitable[None]]
    ) -> None:
        """Set or update the message handler for a topic. This handler receives all messages.

        The new handler is also the one restored after reconnection.
        """
        if topic not in self._topic_subscriptions:
            raise PHXTopicError(f"Topic {topic} not subscribed")

        topic_subscription = self._topic_subscriptions[topic]
        topic_subscription.async_callback = handler
        # Keep reconnection info in sync; only confirmed subscriptions are stored.
        if topic in self._subscription_callbacks:
            self._subscription_callbacks[topic] = handler

    def remove_message_handler(self, topic: str) -> None:
        """Remove the message handler for a topic (also after reconnection)."""
        if topic not in self._topic_subscriptions:
            raise PHXTopicError(f"Topic {topic} not subscribed")

        topic_subscription = self._topic_subscriptions[topic]
        topic_subscription.async_callback = None
        if topic in self._subscription_callbacks:
            self._subscription_callbacks[topic] = None

    def get_message_handler(
        self, topic: str
    ) -> Callable[[ChannelMessage], Awaitable[None]] | None:
        """Get the current message handler for a topic."""
        if topic not in self._topic_subscriptions:
            raise PHXTopicError(f"Topic {topic} not subscribed")

        topic_subscription = self._topic_subscriptions[topic]
        return topic_subscription.async_callback

    def has_message_handler(self, topic: str) -> bool:
        """Check if a message handler exists for a topic."""
        if topic not in self._topic_subscriptions:
            return False

        topic_subscription = self._topic_subscriptions[topic]
        return topic_subscription.async_callback is not None

    async def _start_processing(self) -> None:
        """Route messages from the current connection to topic queues until it ends.

        Raises:
            PHXConnectionError: If the client is not connected.
        """
        if self.connection is None:
            raise PHXConnectionError("Not connected to server")
        await self._protocol_handler.process_websocket_messages(
            self.connection,
            self._topic_subscriptions,
            on_unhandled_message=self._handle_heartbeat_response,
        )

    def _active_connection_task(self) -> asyncio.Task | None:
        """Return the task that currently represents the connection's lifetime.

        A running reconnection loop takes precedence over the message-routing
        task; None means the connection is gone and nothing will revive it.
        """
        reconnect_task = self._reconnect_task
        if reconnect_task is not None and not reconnect_task.done():
            return reconnect_task
        routing_task = self._message_routing_task
        if routing_task is not None and not routing_task.done():
            return routing_task
        return None

    async def run_forever(self) -> None:
        """
        Run until the connection is closed for good or SIGINT/SIGTERM arrives.

        With auto_reconnect enabled a dropped connection does not end this call:
        it keeps waiting while reconnection is in progress and then on the new
        connection. It returns when reconnection gives up (or is disabled), when
        shutdown() is requested, or when one of the signals is received.

        SIGINT (Ctrl+C) and SIGTERM handlers are installed for the duration of
        the call where the event loop supports them (not on Windows or outside
        the main thread; a warning is logged instead). Returning does not close
        the client: leaving the ``async with`` block (or calling shutdown())
        sends leave messages to all subscribed topics, waits up to 5 seconds for
        acknowledgments, and closes the connection.

        If you need custom signal handling, consider managing signals at the
        application level and calling shutdown() explicitly.

        Raises:
            PHXConnectionError: If client is not connected
        """
        if self._message_routing_task is None:
            raise PHXConnectionError(
                "Client is not connected. Use 'async with' context manager."
            )

        shutdown_event = asyncio.Event()
        loop = asyncio.get_running_loop()

        registered_signals = []
        for sig in (signal.SIGINT, signal.SIGTERM):
            try:
                loop.add_signal_handler(sig, shutdown_event.set)
            except (NotImplementedError, RuntimeError) as e:
                self.logger.warning("Cannot install handler for signal %s: %s", sig, e)
            else:
                registered_signals.append(sig)

        shutdown_waiter = loop.create_task(shutdown_event.wait())
        try:
            while not shutdown_event.is_set() and not self._shutdown_requested:
                active_task = self._active_connection_task()
                if active_task is None:
                    break
                await asyncio.wait(
                    {active_task, shutdown_waiter},
                    return_when=asyncio.FIRST_COMPLETED,
                )
        finally:
            shutdown_waiter.cancel()
            for sig in registered_signals:
                loop.remove_signal_handler(sig)