phoenix-channels-python-client 0.1.4__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29):
  1. {phoenix_channels_python_client-0.1.4/phoenix_channels_python_client.egg-info → phoenix_channels_python_client-0.2.0}/PKG-INFO +1 -1
  2. {phoenix_channels_python_client-0.1.4 → phoenix_channels_python_client-0.2.0}/phoenix_channels_python_client/__init__.py +12 -2
  3. phoenix_channels_python_client-0.2.0/phoenix_channels_python_client/client.py +234 -0
  4. phoenix_channels_python_client-0.2.0/phoenix_channels_python_client/client_state_machine.py +38 -0
  5. phoenix_channels_python_client-0.2.0/phoenix_channels_python_client/client_types.py +102 -0
  6. {phoenix_channels_python_client-0.1.4 → phoenix_channels_python_client-0.2.0}/phoenix_channels_python_client/phx_messages.py +18 -15
  7. phoenix_channels_python_client-0.2.0/phoenix_channels_python_client/protocol_handler.py +207 -0
  8. phoenix_channels_python_client-0.2.0/phoenix_channels_python_client/reconnect_controller.py +281 -0
  9. phoenix_channels_python_client-0.2.0/phoenix_channels_python_client/supervisor.py +378 -0
  10. phoenix_channels_python_client-0.2.0/phoenix_channels_python_client/topic_runtime.py +581 -0
  11. phoenix_channels_python_client-0.2.0/phoenix_channels_python_client/topic_subscription.py +53 -0
  12. {phoenix_channels_python_client-0.1.4 → phoenix_channels_python_client-0.2.0/phoenix_channels_python_client.egg-info}/PKG-INFO +1 -1
  13. {phoenix_channels_python_client-0.1.4 → phoenix_channels_python_client-0.2.0}/phoenix_channels_python_client.egg-info/SOURCES.txt +9 -1
  14. {phoenix_channels_python_client-0.1.4 → phoenix_channels_python_client-0.2.0}/pyproject.toml +1 -1
  15. phoenix_channels_python_client-0.2.0/tests/test_internal_components.py +979 -0
  16. phoenix_channels_python_client-0.2.0/tests/test_reconnect_policy_invariants.py +124 -0
  17. phoenix_channels_python_client-0.2.0/tests/test_reconnect_stress.py +136 -0
  18. phoenix_channels_python_client-0.1.4/phoenix_channels_python_client/client.py +0 -900
  19. phoenix_channels_python_client-0.1.4/phoenix_channels_python_client/protocol_handler.py +0 -170
  20. phoenix_channels_python_client-0.1.4/phoenix_channels_python_client/topic_subscription.py +0 -46
  21. {phoenix_channels_python_client-0.1.4 → phoenix_channels_python_client-0.2.0}/LICENSE +0 -0
  22. {phoenix_channels_python_client-0.1.4 → phoenix_channels_python_client-0.2.0}/MANIFEST.in +0 -0
  23. {phoenix_channels_python_client-0.1.4 → phoenix_channels_python_client-0.2.0}/README.md +0 -0
  24. {phoenix_channels_python_client-0.1.4 → phoenix_channels_python_client-0.2.0}/phoenix_channels_python_client/exceptions.py +0 -0
  25. {phoenix_channels_python_client-0.1.4 → phoenix_channels_python_client-0.2.0}/phoenix_channels_python_client/utils.py +0 -0
  26. {phoenix_channels_python_client-0.1.4 → phoenix_channels_python_client-0.2.0}/phoenix_channels_python_client.egg-info/dependency_links.txt +0 -0
  27. {phoenix_channels_python_client-0.1.4 → phoenix_channels_python_client-0.2.0}/phoenix_channels_python_client.egg-info/requires.txt +0 -0
  28. {phoenix_channels_python_client-0.1.4 → phoenix_channels_python_client-0.2.0}/phoenix_channels_python_client.egg-info/top_level.txt +0 -0
  29. {phoenix_channels_python_client-0.1.4 → phoenix_channels_python_client-0.2.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: phoenix-channels-python-client
3
- Version: 0.1.4
3
+ Version: 0.2.0
4
4
  Summary: A Python client library for connecting to Phoenix Channels
5
5
  Author: Phoenix Channels Python Client
6
6
  License-Expression: MIT
@@ -4,19 +4,29 @@ Phoenix Channels Python Client
4
4
  A Python client library for connecting to Phoenix Channels.
5
5
  """
6
6
 
7
- from phoenix_channels_python_client.client import PHXChannelsClient
7
+ from __future__ import annotations
8
+
9
+ from phoenix_channels_python_client.client import (
10
+ DisconnectCallback,
11
+ PHXChannelsClient,
12
+ ReconnectCallback,
13
+ ReconnectPolicy,
14
+ )
8
15
  from phoenix_channels_python_client.protocol_handler import (
9
16
  PHXProtocolHandler,
10
17
  PhoenixChannelsProtocolVersion,
11
18
  )
12
19
  from phoenix_channels_python_client.utils import setup_logging
13
20
 
14
- __version__ = "0.1.4"
21
+ __version__ = "0.2.0"
15
22
  __author__ = "Phoenix Channels Python Client"
16
23
 
17
24
  __all__ = [
25
+ "DisconnectCallback",
18
26
  "PHXChannelsClient",
19
27
  "PHXProtocolHandler",
20
28
  "PhoenixChannelsProtocolVersion",
29
+ "ReconnectCallback",
30
+ "ReconnectPolicy",
21
31
  "setup_logging",
22
32
  ]
@@ -0,0 +1,234 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ from collections import deque
6
+ from collections.abc import Awaitable, Callable
7
+ from types import TracebackType
8
+ from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
9
+
10
+ from websockets import ClientConnection
11
+
12
+ from phoenix_channels_python_client.client_state_machine import transition_client_state
13
+ from phoenix_channels_python_client.client_types import (
14
+ ClientState,
15
+ ReconnectPolicy,
16
+ reconnect_policy_is_invalid,
17
+ validate_reconnect_policy,
18
+ )
19
+ from phoenix_channels_python_client.exceptions import PHXConnectionError
20
+ from phoenix_channels_python_client.protocol_handler import (
21
+ PHXProtocolHandler,
22
+ PhoenixChannelsProtocolVersion,
23
+ )
24
+ from phoenix_channels_python_client.reconnect_controller import ReconnectControllerMixin
25
+ from phoenix_channels_python_client.supervisor import SupervisorMixin
26
+ from phoenix_channels_python_client.topic_runtime import TopicRuntimeMixin
27
+ from phoenix_channels_python_client.topic_subscription import TopicSubscription
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ ReconnectCallback = Callable[[], Awaitable[None]]
32
+ DisconnectCallback = Callable[[Exception | None], Awaitable[None]]
33
+
34
+
35
+ def _build_channel_socket_urls(
36
+ websocket_url: str, api_key: str, vsn: str
37
+ ) -> tuple[str, str]:
38
+ split_url = urlsplit(websocket_url)
39
+ query_params = parse_qsl(split_url.query, keep_blank_values=True)
40
+ filtered = [
41
+ (key, value) for key, value in query_params if key not in {"api_key", "vsn"}
42
+ ]
43
+ with_auth = [*filtered, ("api_key", api_key), ("vsn", vsn)]
44
+
45
+ connect_url = urlunsplit(
46
+ (
47
+ split_url.scheme,
48
+ split_url.netloc,
49
+ split_url.path,
50
+ urlencode(with_auth),
51
+ split_url.fragment,
52
+ )
53
+ )
54
+ redacted_url = urlunsplit(
55
+ (
56
+ split_url.scheme,
57
+ split_url.netloc,
58
+ split_url.path,
59
+ urlencode(
60
+ [
61
+ (key, "***" if key == "api_key" else value)
62
+ for key, value in with_auth
63
+ ]
64
+ ),
65
+ split_url.fragment,
66
+ )
67
+ )
68
+ return connect_url, redacted_url
69
+
70
+
71
class PHXChannelsClient(SupervisorMixin, TopicRuntimeMixin, ReconnectControllerMixin):
    """Asynchronous Phoenix Channels client, used as an ``async with`` context.

    This class owns configuration, lifecycle state and the start/stop
    sequence. Several methods it calls (``_supervisor_loop``,
    ``unsubscribe_from_topic``, ``_cleanup_connection``) are not defined
    here; they are expected to be supplied by the mixin base classes, which
    live in other modules.
    """

    def __init__(
        self,
        websocket_url: str,
        api_key: str,
        *,
        protocol_version: PhoenixChannelsProtocolVersion = PhoenixChannelsProtocolVersion.V2,
        auto_reconnect: bool = True,
        reconnect_policy: ReconnectPolicy | None = None,
        join_timeout_s: float = 10.0,
        leave_timeout_s: float = 5.0,
        max_topic_queue_size: int = 1000,
        callback_drain_timeout_s: float = 2.0,
        heartbeat_interval_s: float | None = 30.0,
        on_reconnect: ReconnectCallback | None = None,
        on_disconnect: DisconnectCallback | None = None,
    ):
        """Validate configuration and initialise state; no I/O is performed.

        Args:
            websocket_url: Base socket URL. Existing ``api_key``/``vsn`` query
                parameters are replaced.
            api_key: Value sent as the ``api_key`` query parameter.
            protocol_version: ``V2`` selects ``vsn=2.0.0``; anything else
                selects ``vsn=1.0.0``.
            auto_reconnect: Stored as-is for the reconnect logic.
            reconnect_policy: Policy to use; a default ``ReconnectPolicy()``
                is used when omitted.
            join_timeout_s: Must be > 0.
            leave_timeout_s: Must be > 0.
            max_topic_queue_size: Must be > 0.
            callback_drain_timeout_s: Must be > 0.
            heartbeat_interval_s: Must be > 0, or ``None`` to disable.
            on_reconnect: Optional async callback, stored for later use.
            on_disconnect: Optional async callback, stored for later use.

        Raises:
            ValueError: If any numeric option is out of range or the
                reconnect policy fails validation.
        """
        self.logger = logger

        if heartbeat_interval_s is not None and heartbeat_interval_s <= 0:
            raise ValueError("heartbeat_interval_s must be > 0 or None to disable")

        if join_timeout_s <= 0:
            raise ValueError("join_timeout_s must be > 0")
        if leave_timeout_s <= 0:
            raise ValueError("leave_timeout_s must be > 0")
        if max_topic_queue_size <= 0:
            raise ValueError("max_topic_queue_size must be > 0")
        if callback_drain_timeout_s <= 0:
            raise ValueError("callback_drain_timeout_s must be > 0")

        # Resolve the policy once so the object that was validated is the
        # same object that gets stored.
        effective_policy = reconnect_policy or ReconnectPolicy()
        try:
            validate_reconnect_policy(effective_policy)
        except ValueError as exc:
            raise ValueError("Invalid reconnect policy configuration") from exc

        vsn = (
            "2.0.0"
            if protocol_version == PhoenixChannelsProtocolVersion.V2
            else "1.0.0"
        )
        connect_url, redacted_url = _build_channel_socket_urls(
            websocket_url=websocket_url,
            api_key=api_key,
            vsn=vsn,
        )
        self.channel_socket_url = connect_url
        # Same URL with the api_key value masked.
        self.channel_socket_url_redacted = redacted_url

        self.auto_reconnect = auto_reconnect
        self.reconnect_policy = effective_policy
        self.join_timeout_s = join_timeout_s
        self.leave_timeout_s = leave_timeout_s
        self.max_topic_queue_size = max_topic_queue_size
        self.callback_drain_timeout_s = callback_drain_timeout_s

        self.connection: ClientConnection | None = None
        self._topic_subscriptions: dict[str, TopicSubscription] = {}
        self._protocol_handler = PHXProtocolHandler(protocol_version)
        self._ref_counter = 0
        self._state = ClientState.CLOSED
        self._topics_lock = asyncio.Lock()
        self._shutdown_event = asyncio.Event()
        self._connected_event = asyncio.Event()
        self._conn_generation = 0
        self._supervisor_task: asyncio.Task[None] | None = None
        self._message_routing_task: asyncio.Task[None] | None = None
        self._initial_connection_future: asyncio.Future[None] | None = None
        self._rapid_disconnects: deque[float] = deque()
        self._terminal_error: Exception | None = None
        self._heartbeat_interval_s = heartbeat_interval_s
        self._heartbeat_task: asyncio.Task[None] | None = None
        self._pending_heartbeat_ref: str | None = None
        self._on_reconnect = on_reconnect
        self._on_disconnect = on_disconnect

    @staticmethod
    def reconnect_policy_is_invalid(policy: ReconnectPolicy) -> bool:
        """Return ``True`` when *policy* fails reconnect-policy validation.

        Thin wrapper around the module-level function of the same name.
        """
        return reconnect_policy_is_invalid(policy)

    async def __aenter__(self) -> PHXChannelsClient:
        """Start the supervisor task and wait for the initial connection.

        Raises:
            PHXConnectionError: If the client is not in the ``CLOSED`` state.
            Exception: Whatever the initial-connection future raises; the
                client is shut down before the exception propagates.
        """
        self.logger.debug("Entering PHXChannelsClient context")
        if self._state != ClientState.CLOSED:
            raise PHXConnectionError("Client is already running")

        self._shutdown_event.clear()
        self._connected_event.clear()
        self._rapid_disconnects.clear()
        self._terminal_error = None
        self._initial_connection_future = asyncio.get_running_loop().create_future()
        self._transition_state(ClientState.CONNECTING)

        self._supervisor_task = asyncio.create_task(self._supervisor_loop())

        try:
            await self._initial_connection_future
        except asyncio.CancelledError:
            # __aexit__ is never invoked when __aenter__ does not complete,
            # so without this the supervisor task would be left running
            # after the caller cancels (for example via a timeout).
            await self.shutdown("Initial connection cancelled")
            raise
        except Exception:
            await self.shutdown("Initial connection failed")
            raise

        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: TracebackType | None = None,
    ) -> None:
        """Shut the client down; exceptions from the body are not suppressed."""
        self.logger.debug("Leaving PHXChannelsClient context")
        await self.shutdown("Leaving PHXChannelsClient context")

    async def shutdown(
        self,
        reason: str,
    ) -> None:
        """Stop the client and release its resources.

        Does nothing if the client is already ``CLOSED`` with no topic
        subscriptions and no connection. Otherwise it marks the client as
        shutting down, unsubscribes from every topic, cancels the supervisor
        task, cleans up the connection and moves to ``CLOSED``.

        Args:
            reason: Free-form text included in the shutdown log line.
        """
        if (
            self._state == ClientState.CLOSED
            and not self._topic_subscriptions
            and self.connection is None
        ):
            return

        self.logger.info("Event loop shutting down! reason=%s", reason)

        if self._state not in (ClientState.SHUTTING_DOWN, ClientState.CLOSED):
            self._transition_state(ClientState.SHUTTING_DOWN)

        self._shutdown_event.set()

        # Snapshot the keys: unsubscribing presumably mutates the dict.
        topics_to_unsubscribe = list(self._topic_subscriptions.keys())
        if topics_to_unsubscribe:
            unsubscribe_tasks = [
                self.unsubscribe_from_topic(topic, _allow_disconnected=True)
                for topic in topics_to_unsubscribe
            ]
            results = await asyncio.gather(*unsubscribe_tasks, return_exceptions=True)

            for topic, result in zip(topics_to_unsubscribe, results, strict=False):
                # BaseException so that a cancelled unsubscribe (CancelledError
                # is not an Exception subclass) is reported as well.
                if isinstance(result, BaseException):
                    self.logger.warning(
                        "Failed to unsubscribe from topic %s during shutdown: %s",
                        topic,
                        result,
                    )

        if self._supervisor_task and not self._supervisor_task.done():
            self._supervisor_task.cancel()
            try:
                await self._supervisor_task
            except asyncio.CancelledError:
                self.logger.debug("Supervisor task cancelled during shutdown")

        await self._cleanup_connection()
        self._connected_event.clear()
        self._transition_state(ClientState.CLOSED)

    def _transition_state(self, new_state: ClientState) -> None:
        """Move to *new_state* after checking the transition is allowed.

        A transition to the current state is a silent no-op. The check is
        delegated to ``transition_client_state``, so whatever it raises for
        a disallowed transition propagates to the caller.
        """
        if self._state == new_state:
            return

        transitioned_state = transition_client_state(self._state, new_state)
        self.logger.debug(
            "Client state transition: %s -> %s", self._state.value, new_state.value
        )
        self._state = transitioned_state
@@ -0,0 +1,38 @@
1
+ from __future__ import annotations
2
+
3
+ from phoenix_channels_python_client.client_types import ClientState
4
+
5
+
6
def transition_client_state(
    current: ClientState, new_state: ClientState
) -> ClientState:
    """Validate a client state change and return the resulting state.

    A transition to the same state is always accepted and returns
    *current*. Any other transition must appear in the table below.

    Raises:
        RuntimeError: If moving from *current* to *new_state* is not allowed.
    """
    if current == new_state:
        return current

    # Every active state may begin shutting down or close directly.
    teardown = {ClientState.SHUTTING_DOWN, ClientState.CLOSED}
    permitted: dict[ClientState, set[ClientState]] = {
        ClientState.CLOSED: {ClientState.CONNECTING},
        ClientState.CONNECTING: {ClientState.CONNECTED, ClientState.RECONNECTING}
        | teardown,
        ClientState.CONNECTED: {ClientState.RECONNECTING} | teardown,
        ClientState.RECONNECTING: {ClientState.CONNECTED} | teardown,
        ClientState.SHUTTING_DOWN: {ClientState.CLOSED},
    }

    if new_state in permitted[current]:
        return new_state

    raise RuntimeError(
        f"Invalid state transition {current.value} -> {new_state.value}"
    )
@@ -0,0 +1,102 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+
6
+ from phoenix_channels_python_client.exceptions import PHXConnectionError
7
+
8
+
9
class ClientState(Enum):
    """Lifecycle states of the channels client.

    The string values are what appears in state-transition log and error
    messages.
    """

    CONNECTING = "connecting"
    CONNECTED = "connected"
    RECONNECTING = "reconnecting"
    SHUTTING_DOWN = "shutting_down"
    CLOSED = "closed"
15
+
16
+
17
@dataclass(frozen=True)
class ReconnectPolicy:
    """Immutable bundle of reconnect tuning values.

    The class performs no validation itself; range checks live in
    ``validate_reconnect_policy`` in this module. How each value is used is
    decided by reconnect logic defined elsewhere, so the group comments
    below are assumptions based on field names (TODO confirm). Fields ending
    in ``_s`` are presumably durations in seconds.
    """

    # Presumably the general backoff curve and the uptime after which it resets.
    base_delay_s: float = 0.5
    factor: float = 2.0
    max_delay_s: float = 30.0
    stable_reset_s: float = 60.0
    # Behaviour switches.
    reconnect_on_normal_close: bool = False
    policy_violation_is_terminal: bool = True
    # Min/max delay pairs; validation requires each max >= its min.
    service_restart_min_delay_s: float = 1.0
    service_restart_max_delay_s: float = 5.0
    try_again_later_min_delay_s: float = 30.0
    try_again_later_max_delay_s: float = 60.0
    # Presumably handling of connections that drop shortly after connecting.
    rapid_disconnect_uptime_s: float = 5.0
    rapid_window_s: float = 60.0
    rapid_first_min_delay_s: float = 2.0
    rapid_second_min_delay_s: float = 10.0
    rapid_cooldown_base_s: float = 60.0
    rapid_cooldown_step_s: float = 30.0
    rapid_cooldown_max_s: float = 300.0
    rapid_suppress_disconnect_count: int = 10
    # Validation requires the high ratio >= the low ratio.
    rapid_hold_down_jitter_low_ratio: float = 0.25
    rapid_hold_down_jitter_high_ratio: float = 1.0
39
+
40
+
41
@dataclass(frozen=True)
class ReconnectDecision:
    """Immutable record describing whether and how to reconnect.

    Only ``should_reconnect`` is required; the remaining fields default to
    ``None``. The code that produces and consumes this record is defined in
    another module.
    """

    should_reconnect: bool
    # Optional delay bounds; presumably seconds, given the ``_s`` suffix (TODO confirm).
    min_delay_s: float | None = None
    max_delay_s: float | None = None
    # Optional error; presumably set when reconnecting must stop (TODO confirm).
    terminal_error: PHXConnectionError | None = None
47
+
48
+
49
def validate_reconnect_policy(policy: ReconnectPolicy) -> None:
    """Raise ``ValueError`` for the first out-of-range field of *policy*.

    Rules are checked top to bottom and the first violation wins. A bound
    given as a string names another field of the same policy, so that rule
    reads "this field must not be smaller than that field".

    Raises:
        ValueError: With a message of the form ``"<field> must be <op> <bound>"``.
    """
    rules = (
        ("base_delay_s", ">=", 0),
        ("factor", ">", 0),
        ("max_delay_s", ">=", 0),
        ("stable_reset_s", ">", 0),
        ("service_restart_min_delay_s", ">=", 0),
        ("service_restart_max_delay_s", ">=", "service_restart_min_delay_s"),
        ("try_again_later_min_delay_s", ">=", 0),
        ("try_again_later_max_delay_s", ">=", "try_again_later_min_delay_s"),
        ("rapid_disconnect_uptime_s", ">=", 0),
        ("rapid_window_s", ">", 0),
        ("rapid_first_min_delay_s", ">=", 0),
        ("rapid_second_min_delay_s", ">=", 0),
        ("rapid_cooldown_base_s", ">=", 0),
        ("rapid_cooldown_step_s", ">=", 0),
        ("rapid_cooldown_max_s", ">=", "rapid_cooldown_base_s"),
        ("rapid_suppress_disconnect_count", ">=", 0),
        ("rapid_hold_down_jitter_low_ratio", ">=", 0),
        (
            "rapid_hold_down_jitter_high_ratio",
            ">=",
            "rapid_hold_down_jitter_low_ratio",
        ),
    )

    for field_name, relation, bound in rules:
        actual = getattr(policy, field_name)
        floor = getattr(policy, bound) if isinstance(bound, str) else bound
        # Test for the violation directly (``<`` / ``<=``) rather than
        # negating the requirement, so NaN compares the same way as a
        # plain ``if value < bound`` check would.
        violated = actual <= floor if relation == ">" else actual < floor
        if violated:
            raise ValueError(f"{field_name} must be {relation} {bound}")
95
+
96
+
97
def reconnect_policy_is_invalid(policy: ReconnectPolicy) -> bool:
    """Return ``True`` if ``validate_reconnect_policy`` rejects *policy*.

    Only ``ValueError`` counts as "invalid"; any other exception raised
    during validation propagates to the caller.
    """
    try:
        validate_reconnect_policy(policy)
    except ValueError:
        invalid = True
    else:
        invalid = False
    return invalid
@@ -1,41 +1,39 @@
1
+ from __future__ import annotations
2
+
1
3
  from dataclasses import dataclass
2
4
  from enum import Enum, unique
3
5
  from functools import cached_property
4
- from typing import Any, NewType, Optional, Union
5
-
6
-
7
- Event = NewType("Event", str)
8
- ChannelEvent = Union["PHXEvent", Event]
9
- ChannelMessage = Union["PHXMessage", "PHXEventMessage"]
6
+ from typing import Any, NewType
10
7
 
11
8
 
12
9
  @unique
13
10
  class PHXEvent(Enum):
14
- """Phoenix Channels admin events"""
15
-
16
11
  close = "phx_close"
17
12
  error = "phx_error"
18
13
  join = "phx_join"
19
14
  reply = "phx_reply"
20
15
  leave = "phx_leave"
21
16
 
22
- value: str
23
-
24
17
  def __str__(self) -> str:
25
18
  return self.value
26
19
 
27
20
 
21
+ UserEvent = NewType("UserEvent", str)
22
+ # Compatibility alias for existing imports and call sites.
23
+ Event = UserEvent
24
+ ChannelEvent = PHXEvent | UserEvent
25
+
26
+
28
27
  @dataclass(frozen=True)
29
28
  class BasePHXMessage:
30
29
  topic: str
31
- ref: Optional[str]
30
+ ref: str | None
32
31
  payload: dict[str, Any]
33
32
 
34
33
  @cached_property
35
- def subtopic(self) -> Optional[str]:
34
+ def subtopic(self) -> str | None:
36
35
  if ":" not in self.topic:
37
36
  return None
38
-
39
37
  _, subtopic = self.topic.split(":", 1)
40
38
  return subtopic
41
39
 
@@ -43,10 +41,15 @@ class BasePHXMessage:
43
41
  @dataclass(frozen=True)
44
42
  class PHXMessage(BasePHXMessage):
45
43
  event: Event
46
- join_ref: Optional[str] = None
44
+ join_ref: str | None = None
47
45
 
48
46
 
49
47
  @dataclass(frozen=True)
50
48
  class PHXEventMessage(BasePHXMessage):
51
49
  event: PHXEvent
52
- join_ref: Optional[str] = None
50
+ join_ref: str | None = None
51
+
52
+
53
+ ChannelMessage = PHXMessage | PHXEventMessage
54
+ # Compatibility alias for existing tests and call sites.
55
+ Message = ChannelMessage
@@ -0,0 +1,207 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ from collections.abc import Callable
7
+ from enum import Enum
8
+
9
+ from websockets import ClientConnection
10
+
11
+ from phoenix_channels_python_client.phx_messages import ChannelMessage, Event
12
+ from phoenix_channels_python_client.topic_subscription import TopicSubscription
13
+ from phoenix_channels_python_client.utils import make_message
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class PhoenixChannelsProtocolVersion(Enum):
    """Wire-format versions understood by ``PHXProtocolHandler``.

    In that handler, ``V2`` messages are 5-element JSON arrays and any other
    version is treated as the JSON-object format.
    """

    V1 = "1.0"
    V2 = "2.0"
21
+
22
+
23
class PHXProtocolHandler:
    """Encode, decode and route Phoenix Channels messages for one protocol version."""

    def __init__(
        self,
        protocol_version: PhoenixChannelsProtocolVersion = PhoenixChannelsProtocolVersion.V2,
    ):
        """Remember the protocol version and set up a child logger."""
        self.protocol_version = protocol_version
        self.logger = logger.getChild("ProtocolHandler")
        self.logger.debug(
            "Initialized PHXProtocolHandler for protocol version %s",
            self.protocol_version.value,
        )

    @staticmethod
    def _parse_v2(decoded: object) -> ChannelMessage:
        """Build a message from a decoded ``[join_ref, ref, topic, event, payload]`` array."""
        if not isinstance(decoded, list):
            raise TypeError(
                f"Protocol v2 expects array format, got {type(decoded).__name__}"
            )
        if len(decoded) != 5:
            raise ValueError(
                "Protocol v2 expects 5-element array "
                "[join_ref, ref, topic, event, payload]"
            )

        join_ref, ref, topic, event, payload = decoded
        if not (isinstance(topic, str) and topic):
            raise TypeError("Protocol v2 message topic must be a non-empty string")
        if not (isinstance(event, str) and event):
            raise TypeError("Protocol v2 message event must be a non-empty string")
        if not isinstance(payload, dict):
            # Missing or non-object payloads are normalised to an empty dict.
            payload = {}

        return make_message(
            topic=topic,
            event=Event(event),
            payload=payload,
            ref=None if ref is None else str(ref),
            join_ref=None if join_ref is None else str(join_ref),
        )

    @staticmethod
    def _parse_v1(decoded: object) -> ChannelMessage:
        """Build a message from a decoded JSON object (v1 format)."""
        if not isinstance(decoded, dict):
            raise TypeError(
                f"Protocol v1 expects object format, got {type(decoded).__name__}"
            )

        topic = decoded.get("topic")
        event = decoded.get("event")
        payload = decoded.get("payload", {})
        if not (isinstance(topic, str) and topic):
            raise TypeError("Protocol v1 message topic must be a non-empty string")
        if not (isinstance(event, str) and event):
            raise TypeError("Protocol v1 message event must be a non-empty string")
        if not isinstance(payload, dict):
            payload = {}

        return make_message(
            topic=topic,
            event=Event(event),
            payload=payload,
            ref=decoded.get("ref"),
            join_ref=decoded.get("join_ref"),
        )

    def parse_message(self, raw_message: str | bytes) -> ChannelMessage:
        """Decode a raw websocket frame into a channel message.

        Raises:
            TypeError: If the decoded JSON has the wrong container type, or
                the topic/event is not a non-empty string.
            ValueError: If the JSON is malformed, a v2 array does not have
                five elements, or any other unexpected error occurs (the
                original error is chained).
        """
        self.logger.debug("Parsing raw message: %s", raw_message)
        try:
            decoded = json.loads(raw_message)
            self.logger.debug("Decoded data: %s", decoded)
            if self.protocol_version == PhoenixChannelsProtocolVersion.V2:
                return self._parse_v2(decoded)
            return self._parse_v1(decoded)
        except (TypeError, ValueError):
            self.logger.exception("Failed to parse message")
            raise
        except Exception as exc:
            self.logger.exception("Unexpected error parsing message")
            raise ValueError(f"Invalid message format: {exc}") from exc

    def serialize_message(self, message: ChannelMessage) -> str:
        """Encode *message* as JSON text in the configured wire format.

        V2 produces a 5-element array; otherwise an object with ``topic``,
        ``event``, ``ref`` and ``payload`` keys (no ``join_ref``).

        Raises:
            TypeError: If anything goes wrong while serialising (the
                original error is chained).
        """
        self.logger.debug("Serializing message: %s", message)
        try:
            if self.protocol_version == PhoenixChannelsProtocolVersion.V2:
                wire_form = [
                    message.join_ref,
                    message.ref,
                    message.topic,
                    str(message.event),
                    message.payload,
                ]
            else:
                wire_form = {
                    "topic": message.topic,
                    "event": str(message.event),
                    "ref": message.ref,
                    "payload": message.payload,
                }
            serialized = json.dumps(wire_form)
            self.logger.debug("Serialized to: %s", serialized)
            return serialized
        except Exception as exc:
            self.logger.exception("Failed to serialize message")
            raise TypeError(f"Cannot serialize message: {exc}") from exc

    async def send_message(
        self, websocket: ClientConnection, message: ChannelMessage
    ) -> None:
        """Serialise *message* and send it over *websocket* as a text frame."""
        self.logger.debug(
            "Serializing %s to Phoenix Channels %s format",
            message,
            self.protocol_version.value,
        )
        frame = self.serialize_message(message)

        self.logger.debug("Sending as TEXT frame: %s", frame)
        await websocket.send(frame)

    def _drop_oldest(self, topic: str, subscription: TopicSubscription) -> None:
        """Discard the oldest queued message of a full queue and count the drop."""
        try:
            subscription.queue.get_nowait()
        except asyncio.QueueEmpty:
            self.logger.debug(
                "Queue became empty before drop on topic %s; skipping drop",
                topic,
            )
            return

        subscription.dropped_message_count += 1
        dropped = subscription.dropped_message_count
        # Warn on the first drop and then on every hundredth.
        if dropped == 1 or dropped % 100 == 0:
            self.logger.warning(
                "Dropped %s queued messages for topic %s due to full queue",
                dropped,
                topic,
            )

    async def process_websocket_messages(
        self,
        connection: ClientConnection,
        topic_subscriptions: dict[str, TopicSubscription],
        conn_generation: int,
        on_heartbeat_response: Callable[[ChannelMessage], None] | None = None,
    ) -> None:
        """Read frames from *connection* and queue them per topic.

        Messages on the ``phoenix`` topic are handed to
        *on_heartbeat_response* (when given) and never queued. Messages are
        skipped when the topic has no subscription, when the subscription
        belongs to a different connection generation, or (v2 only) when the
        ``join_ref`` does not match the subscription's. If the target queue
        is full, the oldest queued message is dropped to make room. A parse
        error is not caught here and therefore ends the loop.
        """
        self.logger.debug(
            "Starting websocket message loop for generation %s", conn_generation
        )
        async for raw_frame in connection:
            message = self.parse_message(raw_frame)
            self.logger.debug("Processing message - %s", message)
            topic = message.topic

            if topic == "phoenix":
                if on_heartbeat_response is not None:
                    on_heartbeat_response(message)
                continue

            if topic not in topic_subscriptions:
                continue
            subscription = topic_subscriptions[topic]

            if subscription.conn_generation != conn_generation:
                self.logger.debug(
                    "Dropping message for stale generation on topic %s. routing_gen=%s subscription_gen=%s",
                    topic,
                    conn_generation,
                    subscription.conn_generation,
                )
                continue

            uses_join_refs = self.protocol_version == PhoenixChannelsProtocolVersion.V2
            if uses_join_refs and subscription.join_ref != message.join_ref:
                self.logger.debug(
                    "Dropping message with stale join_ref on topic %s. got=%s expected=%s",
                    topic,
                    message.join_ref,
                    subscription.join_ref,
                )
                continue

            if subscription.queue.full():
                self._drop_oldest(topic, subscription)

            await subscription.queue.put(message)