hypern 0.3.0__cp311-cp311-win32.whl → 0.3.2__cp311-cp311-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,184 @@
+ import hashlib
+ import hmac
+ import secrets
+ import time
+ from base64 import b64decode, b64encode
+ from dataclasses import dataclass
+ from datetime import datetime, timedelta, timezone
+ from typing import Any, Dict, List, Optional
+
+ import jwt
+
+ from hypern.exceptions import Forbidden, Unauthorized
+ from hypern.hypern import Request, Response
+ from .base import Middleware, MiddlewareConfig
+
+
+ @dataclass
+ class CORSConfig:
+     allowed_origins: List[str]
+     allowed_methods: List[str]
+     max_age: int
+
+
+ @dataclass
+ class SecurityConfig:
+     rate_limiting: bool = False
+     jwt_auth: bool = False
+     cors_configuration: Optional[CORSConfig] = None
+     csrf_protection: bool = False
+     security_headers: Optional[Dict[str, str]] = None
+     jwt_secret: str = ""
+     jwt_algorithm: str = "HS256"
+     jwt_expires_in: int = 3600  # 1 hour in seconds
+
+     def __post_init__(self):
+         if self.cors_configuration:
+             self.cors_configuration = CORSConfig(**self.cors_configuration)
+
+         if self.security_headers is None:
+             self.security_headers = {
+                 "X-Frame-Options": "DENY",
+                 "X-Content-Type-Options": "nosniff",
+                 "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
+             }
+
+
+ class SecurityMiddleware(Middleware):
+     def __init__(self, secur_config: SecurityConfig, config: Optional[MiddlewareConfig] = None):
+         super().__init__(config)
+         self.secur_config = secur_config
+         self._secret_key = secrets.token_bytes(32)
+         self._token_lifetime = 3600
+         self._rate_limit_storage = {}
+
+     def _rate_limit_check(self, request: Request) -> Optional[Response]:
+         """Check if the request exceeds rate limits"""
+         if not self.secur_config.rate_limiting:
+             return None
+
+         client_ip = request.client.host
+         current_time = time.time()
+         window_start = int(current_time - 60)  # 1-minute window
+
+         # Clean up old entries
+         self._rate_limit_storage = {ip: hits for ip, hits in self._rate_limit_storage.items() if hits["timestamp"] > window_start}
+
+         if client_ip not in self._rate_limit_storage:
+             self._rate_limit_storage[client_ip] = {"count": 1, "timestamp": current_time}
+         else:
+             self._rate_limit_storage[client_ip]["count"] += 1
+
+         if self._rate_limit_storage[client_ip]["count"] > 60:  # 60 requests per minute
+             return Response(status_code=429, description=b"Too Many Requests", headers={"Retry-After": "60"})
+         return None
+
+     def _generate_jwt_token(self, user_data: Dict[str, Any]) -> str:
+         """Generate a JWT token"""
+         if not self.secur_config.jwt_secret:
+             raise ValueError("JWT secret key is not configured")
+
+         payload = {
+             "user": user_data,
+             "exp": datetime.now(tz=timezone.utc) + timedelta(seconds=self.secur_config.jwt_expires_in),
+             "iat": datetime.now(tz=timezone.utc),
+         }
+         return jwt.encode(payload, self.secur_config.jwt_secret, algorithm=self.secur_config.jwt_algorithm)
+
+     def _verify_jwt_token(self, token: str) -> Dict[str, Any]:
+         """Verify JWT token and return payload"""
+         try:
+             payload = jwt.decode(token, self.secur_config.jwt_secret, algorithms=[self.secur_config.jwt_algorithm])
+             return payload
+         except jwt.ExpiredSignatureError:
+             raise Unauthorized("Token has expired")
+         except jwt.InvalidTokenError:
+             raise Unauthorized("Invalid token")
+
+     def _generate_csrf_token(self, session_id: str) -> str:
+         """Generate a new CSRF token"""
+         timestamp = str(int(time.time()))
+         token_data = f"{session_id}:{timestamp}"
+         signature = hmac.new(self._secret_key, token_data.encode(), hashlib.sha256).digest()
+         return b64encode(f"{token_data}:{b64encode(signature).decode()}".encode()).decode()
+
+     def _validate_csrf_token(self, token: str) -> bool:
+         """Validate CSRF token"""
+         try:
+             decoded_token = b64decode(token.encode()).decode()
+             session_id, timestamp, signature = decoded_token.rsplit(":", 2)
+
+             # Verify timestamp
+             token_time = int(timestamp)
+             current_time = int(time.time())
+             if current_time - token_time > self._token_lifetime:
+                 return False
+
+             # Verify signature
+             expected_data = f"{session_id}:{timestamp}"
+             expected_signature = hmac.new(self._secret_key, expected_data.encode(), hashlib.sha256).digest()
+
+             actual_signature = b64decode(signature)
+             return hmac.compare_digest(expected_signature, actual_signature)
+
+         except (ValueError, AttributeError, TypeError):
+             return False
+
+     def _apply_cors_headers(self, response: Response) -> None:
+         """Apply CORS headers to response"""
+         if not self.secur_config.cors_configuration:
+             return
+
+         cors = self.secur_config.cors_configuration
+         response.headers.update(
+             {
+                 "Access-Control-Allow-Origin": ", ".join(cors.allowed_origins),
+                 "Access-Control-Allow-Methods": ", ".join(cors.allowed_methods),
+                 "Access-Control-Max-Age": str(cors.max_age),
+                 "Access-Control-Allow-Headers": "Content-Type, Authorization, X-CSRF-Token",
+                 "Access-Control-Allow-Credentials": "true",
+             }
+         )
+
+     def _apply_security_headers(self, response: Response) -> None:
+         """Apply security headers to response"""
+         if self.secur_config.security_headers:
+             response.headers.update(self.secur_config.security_headers)
+
+     async def before_request(self, request: Request) -> Request | Response:
+         """Process request before handling"""
+         # Rate limiting check
+         if rate_limit_response := self._rate_limit_check(request):
+             return rate_limit_response
+
+         # JWT authentication check
+         if self.secur_config.jwt_auth:
+             auth_header = request.headers.get("Authorization")
+             if not auth_header or not auth_header.startswith("Bearer "):
+                 raise Unauthorized("Missing or invalid authorization header")
+             token = auth_header.split(" ")[1]
+             try:
+                 request.user = self._verify_jwt_token(token)
+             except Unauthorized as e:
+                 return Response(status_code=401, description=str(e))
+
+         # CSRF protection check
+         if self.secur_config.csrf_protection and request.method in ["POST", "PUT", "DELETE", "PATCH"]:
+             csrf_token = request.headers.get("X-CSRF-Token")
+             if not csrf_token or not self._validate_csrf_token(csrf_token):
+                 raise Forbidden("CSRF token missing or invalid")
+
+         return request
+
+     async def after_request(self, response: Response) -> Response:
+         """Process response after handling"""
+         self._apply_security_headers(response)
+         self._apply_cors_headers(response)
+         return response
+
+     def generate_csrf_token(self, request: Request) -> str:
+         """Generate and set CSRF token for the request"""
+         if not hasattr(request, "session_id"):
+             request.session_id = secrets.token_urlsafe(32)
+         token = self._generate_csrf_token(request.session_id)
+         return token
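A hedged usage sketch of the new middleware (not part of the diff; this hunk carries no file header, so the import path is hypothetical). Field names mirror the dataclasses above; note that SecurityConfig.__post_init__ expands a plain mapping into CORSConfig, so cors_configuration can be passed as a dict:

# from hypern.middleware import SecurityConfig, SecurityMiddleware  # hypothetical import path

config = SecurityConfig(
    rate_limiting=True,      # fixed 1-minute window, 60 requests per client IP
    jwt_auth=True,
    jwt_secret="change-me",  # required: _generate_jwt_token raises ValueError if empty
    csrf_protection=True,    # enforced on POST/PUT/DELETE/PATCH via the X-CSRF-Token header
    cors_configuration={"allowed_origins": ["https://example.com"], "allowed_methods": ["GET", "POST"], "max_age": 600},
)
middleware = SecurityMiddleware(secur_config=config)
token = middleware._generate_jwt_token({"id": 1})  # value a client would send back as "Bearer <token>"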
hypern/processpool.py CHANGED
@@ -25,11 +25,26 @@ def run_processes(
      after_request: List[FunctionInfo],
      response_headers: Dict[str, str],
      reload: bool = True,
+     on_startup: FunctionInfo | None = None,
+     on_shutdown: FunctionInfo | None = None,
+     auto_compression: bool = False,
  ) -> List[Process]:
      socket = SocketHeld(host, port)

      process_pool = init_processpool(
-         router, websocket_router, socket, workers, processes, max_blocking_threads, injectables, before_request, after_request, response_headers
+         router,
+         websocket_router,
+         socket,
+         workers,
+         processes,
+         max_blocking_threads,
+         injectables,
+         before_request,
+         after_request,
+         response_headers,
+         on_startup,
+         on_shutdown,
+         auto_compression,
      )

      def terminating_signal_handler(_sig, _frame):
@@ -79,6 +94,9 @@ def init_processpool(
      before_request: List[FunctionInfo],
      after_request: List[FunctionInfo],
      response_headers: Dict[str, str],
+     on_startup: FunctionInfo | None = None,
+     on_shutdown: FunctionInfo | None = None,
+     auto_compression: bool = False,
  ) -> List[Process]:
      process_pool = []

@@ -86,7 +104,20 @@ def init_processpool(
          copied_socket = socket.try_clone()
          process = Process(
              target=spawn_process,
-             args=(router, websocket_router, copied_socket, workers, max_blocking_threads, injectables, before_request, after_request, response_headers),
+             args=(
+                 router,
+                 websocket_router,
+                 copied_socket,
+                 workers,
+                 max_blocking_threads,
+                 injectables,
+                 before_request,
+                 after_request,
+                 response_headers,
+                 on_startup,
+                 on_shutdown,
+                 auto_compression,
+             ),
          )
          process.start()
          process_pool.append(process)
@@ -118,6 +149,9 @@ def spawn_process(
      before_request: List[FunctionInfo],
      after_request: List[FunctionInfo],
      response_headers: Dict[str, str],
+     on_startup: FunctionInfo | None = None,
+     on_shutdown: FunctionInfo | None = None,
+     auto_compression: bool = False,
  ):
      loop = initialize_event_loop()

@@ -128,7 +162,12 @@ def spawn_process(
      server.set_before_hooks(hooks=before_request)
      server.set_after_hooks(hooks=after_request)
      server.set_response_headers(headers=response_headers)
+     server.set_auto_compression(enabled=auto_compression)

+     if on_startup:
+         server.set_startup_handler(on_startup)
+     if on_shutdown:
+         server.set_shutdown_handler(on_shutdown)
      try:
          server.start(socket, workers, max_blocking_threads)
          loop = asyncio.get_event_loop()
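The three new keyword arguments flow from run_processes through init_processpool into each spawned worker, where spawn_process applies them via the setters above. A hedged caller-side sketch; FunctionInfo's constructor arguments are an assumption (its signature is not shown in this diff), and the elided positional arguments are the existing router/socket/worker parameters:

from hypern.hypern import FunctionInfo  # constructor kwargs below are hypothetical

async def warm_up():
    ...  # e.g., open connection pools once per worker process

processes = run_processes(
    ...,  # existing arguments: router, websocket_router, host, port, workers, etc.
    on_startup=FunctionInfo(handler=warm_up, is_async=True),  # hypothetical kwargs
    on_shutdown=None,
    auto_compression=True,  # forwarded to server.set_auto_compression in each worker
)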
hypern/reload.py CHANGED
@@ -2,6 +2,8 @@ import sys
  import time
  import subprocess
  from watchdog.events import FileSystemEventHandler
+ import signal
+ import os

  from .logging import logger

@@ -10,51 +12,35 @@ class EventHandler(FileSystemEventHandler):
      def __init__(self, file_path: str, directory_path: str) -> None:
          self.file_path = file_path
          self.directory_path = directory_path
-         self.process = None  # Keep track of the subprocess
-         self.last_reload = time.time()  # Keep track of the last reload. EventHandler is initialized with the process.
-
-     def stop_server(self):
-         if self.process:
-             try:
-                 # Check if the process is still alive
-                 if self.process.poll() is None:  # None means the process is still running
-                     self.process.terminate()  # Gracefully terminate the process
-                     self.process.wait(timeout=5)  # Wait for the process to exit
-                 else:
-                     logger.error("Process is not running.")
-             except subprocess.TimeoutExpired:
-                 logger.error("Process did not terminate in time. Forcing termination.")
-                 self.process.kill()  # Forcefully kill the process if it doesn't stop
-             except ProcessLookupError:
-                 logger.error("Process does not exist.")
-             except Exception as e:
-                 logger.error(f"An error occurred while stopping the server: {e}")
-         else:
-             logger.debug("No process to stop.")
+         self.process = None
+         self.last_reload = time.time()

      def reload(self):
-         self.stop_server()
-         logger.debug("Reloading the server")
-         prev_process = self.process
-         if prev_process:
-             prev_process.kill()
-
-         self.process = subprocess.Popen(
-             [sys.executable, *sys.argv],
-         )
-
-         self.last_reload = time.time()
+         # Kill all existing processes with the same command
+         current_cmd = [sys.executable, *sys.argv]
+
+         try:
+             # Find and kill existing processes
+             for proc in subprocess.Popen(["ps", "aux"], stdout=subprocess.PIPE).communicate()[0].decode().splitlines():
+                 if all(str(arg) in proc for arg in current_cmd):
+                     pid = int(proc.split()[1])
+                     try:
+                         os.kill(pid, signal.SIGKILL)  # NOSONAR
+                         logger.debug(f"Killed process with PID {pid}")
+                     except ProcessLookupError:
+                         pass
+
+             # Start new process
+             self.process = subprocess.Popen(current_cmd)
+             self.last_reload = time.time()
+             logger.debug("Server reloaded successfully")
+
+         except Exception as e:
+             logger.error(f"Reload failed: {e}")

      def on_modified(self, event) -> None:
-         """
-         This function is a callback that will start a new server on every even change
-
-         :param event FSEvent: a data structure with info about the events
-         """
-
-         # Avoid reloading multiple times when watchdog detects multiple events
          if time.time() - self.last_reload < 0.5:
              return

-         time.sleep(0.2)  # Wait for the file to be fully written
+         time.sleep(0.2)  # Ensure file is written
          self.reload()
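Since EventHandler subclasses watchdog's FileSystemEventHandler, it plugs into the standard observer loop; a minimal sketch using only public watchdog API (this wiring is not part of the diff). Note that the rewritten reload() finds stale processes by scanning `ps aux`, which assumes a POSIX userland:

import time
from watchdog.observers import Observer

from hypern.reload import EventHandler

handler = EventHandler(file_path="main.py", directory_path=".")
observer = Observer()
observer.schedule(handler, path=".", recursive=True)
observer.start()  # each file save triggers on_modified(), debounced to one reload per 0.5 s
try:
    while True:
        time.sleep(1)
finally:
    observer.stop()
    observer.join()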
hypern/routing/__init__.py CHANGED (path inferred from the relative imports)
@@ -1,4 +1,5 @@
  from .route import Route
  from .endpoint import HTTPEndpoint
+ from .queue import QueuedHTTPEndpoint

- __all__ = ["Route", "HTTPEndpoint"]
+ __all__ = ["Route", "HTTPEndpoint", "QueuedHTTPEndpoint"]
hypern/routing/queue.py ADDED (path inferred from the __init__.py import above)
@@ -0,0 +1,175 @@
+ import asyncio
+ import time
+ from contextlib import asynccontextmanager
+ from dataclasses import dataclass, field
+ from queue import PriorityQueue
+ from typing import Any, Dict
+
+ from hypern.hypern import Request, Response
+ from hypern.response import JSONResponse
+ from hypern.routing import HTTPEndpoint
+ from hypern.logging import logger
+
+
+ @dataclass(order=True)
+ class PrioritizedRequest:
+     priority: int
+     timestamp: float = field(default_factory=time.time)
+     request: Request | None = field(default=None, compare=False)
+     future: asyncio.Future = field(compare=False, default_factory=asyncio.Future)
+
+
+ class QueuedHTTPEndpoint(HTTPEndpoint):
+     """
+     HTTPEndpoint with request queuing capabilities for high-load scenarios.
+     """
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         # Queue configuration
+         self._max_concurrent = kwargs.get("max_concurrent", 100)
+         self._queue_size = kwargs.get("queue_size", 1000)
+         self._request_timeout = kwargs.get("request_timeout", 30)
+
+         # Initialize queuing system
+         self._request_queue: PriorityQueue = PriorityQueue(maxsize=self._queue_size)
+         self._active_requests = 0
+         self._lock = None  # Will be initialized when needed
+         self._request_semaphore = None  # Will be initialized when needed
+         self._shutdown = False
+         self._queue_task = None
+         self._initialized = False
+
+         # Metrics
+         self._metrics = {"processed_requests": 0, "queued_requests": 0, "rejected_requests": 0, "avg_wait_time": 0.0}
+
+         self._fully_message = "Request queue is full"
+
+     async def _initialize(self):
+         """Initialize async components when first request arrives"""
+         if not self._initialized:
+             self._lock = asyncio.Lock()
+             self._request_semaphore = asyncio.Semaphore(self._max_concurrent)
+             self._queue_task = asyncio.create_task(self._process_queue())
+             self._initialized = True
+
+     @asynccontextmanager
+     async def _queue_context(self, request: Request, priority: int = 10):
+         """Context manager for handling request queuing."""
+         if self._shutdown:
+             raise RuntimeError("Endpoint is shutting down")
+
+         await self._initialize()  # Ensure async components are initialized
+
+         request_future = asyncio.Future()
+         prioritized_request = PrioritizedRequest(priority=priority, timestamp=time.time(), request=request, future=request_future)
+
+         try:
+             if self._request_queue.qsize() >= self._queue_size:
+                 self._metrics["rejected_requests"] += 1
+                 raise asyncio.QueueFull(self._fully_message)
+
+             await self._enqueue_request(prioritized_request)
+             yield await asyncio.wait_for(request_future, timeout=self._request_timeout)
+
+         except asyncio.TimeoutError:
+             self._metrics["rejected_requests"] += 1
+             raise asyncio.TimeoutError("Request timed out while waiting in queue")
+         finally:
+             if not request_future.done():
+                 request_future.cancel()
+
+     async def _enqueue_request(self, request: PrioritizedRequest):
+         """Add request to the queue."""
+         try:
+             self._request_queue.put_nowait(request)
+             self._metrics["queued_requests"] += 1
+         except asyncio.QueueFull:
+             self._metrics["rejected_requests"] += 1
+             raise asyncio.QueueFull(self._fully_message)
+
+     async def _process_queue(self):
+         """Background task to process queued requests."""
+         while not self._shutdown:
+             try:
+                 if not self._request_queue.empty():
+                     async with self._lock:
+                         if self._active_requests >= self._max_concurrent:
+                             await asyncio.sleep(0.1)
+                             continue
+
+                         request = self._request_queue.get_nowait()
+                         wait_time = time.time() - request.timestamp
+                         self._metrics["avg_wait_time"] = (self._metrics["avg_wait_time"] * self._metrics["processed_requests"] + wait_time) / (
+                             self._metrics["processed_requests"] + 1
+                         )
+
+                         if not request.future.cancelled():
+                             self._active_requests += 1
+                             asyncio.create_task(self._handle_request(request))
+
+                 await asyncio.sleep(0.01)
+             except Exception as e:
+                 logger.error(f"Error processing queue: {e}")
+                 await asyncio.sleep(1)
+
+     async def _handle_request(self, request: PrioritizedRequest):
+         """Handle individual request."""
+         try:
+             async with self._request_semaphore:
+                 response = await super().dispatch(request.request, {})
+                 if not request.future.done():
+                     request.future.set_result(response)
+         except Exception as e:
+             if not request.future.done():
+                 request.future.set_exception(e)
+         finally:
+             self._active_requests -= 1
+             self._metrics["processed_requests"] += 1
+             self._metrics["queued_requests"] -= 1
+
+     async def dispatch(self, request: Request, inject: Dict[str, Any]) -> Response:
+         """
+         Enhanced dispatch method with request queuing.
+         """
+         try:
+             priority = self._get_request_priority(request)
+
+             async with self._queue_context(request, priority) as response:
+                 return response
+
+         except asyncio.QueueFull:
+             return JSONResponse(description={"error": "Server too busy", "message": self._fully_message, "retry_after": 5}, status_code=503)
+         except asyncio.TimeoutError:
+             return JSONResponse(
+                 description={
+                     "error": "Request timeout",
+                     "message": "Request timed out while waiting in queue",
+                 },
+                 status_code=504,
+             )
+         except Exception as e:
+             return JSONResponse(description={"error": "Internal server error", "message": str(e)}, status_code=500)
+
+     def _get_request_priority(self, request: Request) -> int:
+         """
+         Determine request priority. Override this method to implement
+         custom priority logic.
+         """
+         if request.method == "GET":
+             return 5
+         return 10
+
+     async def shutdown(self):
+         """Gracefully shutdown the endpoint."""
+         self._shutdown = True
+         if self._queue_task and not self._queue_task.done():
+             self._queue_task.cancel()
+             try:
+                 await self._queue_task
+             except asyncio.CancelledError:
+                 pass
+
+     def get_metrics(self) -> Dict[str, Any]:
+         """Get current queue metrics."""
+         return {**self._metrics, "current_queue_size": self._request_queue.qsize(), "active_requests": self._active_requests}
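The _get_request_priority docstring invites subclasses to supply their own priority logic; a hedged sketch (how the endpoint is then registered on a Route is not shown in this diff). Because the queue is a PriorityQueue over PrioritizedRequest, lower values dequeue first:

from hypern.hypern import Request
from hypern.routing import QueuedHTTPEndpoint


class CheckoutEndpoint(QueuedHTTPEndpoint):
    def _get_request_priority(self, request: Request) -> int:
        # Lower value = served sooner; X-Priority is a hypothetical client header.
        if request.headers.get("X-Priority") == "high":
            return 1
        return 5 if request.method == "GET" else 10


endpoint = CheckoutEndpoint(max_concurrent=50, queue_size=500, request_timeout=15)
stats = endpoint.get_metrics()  # processed/queued/rejected counts plus avg_wait_time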
hypern/ws/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from hypern.hypern import WebSocketSession
+ from .route import WebsocketRoute
+
+ __all__ = ["WebsocketRoute", "WebSocketSession"]
hypern/ws/channel.py ADDED
@@ -0,0 +1,80 @@
+ from dataclasses import dataclass, field
+ from typing import Any, Awaitable, Callable, Dict, Set
+
+ from hypern.hypern import WebSocketSession
+
+
+ @dataclass
+ class Channel:
+     name: str
+     subscribers: Set[WebSocketSession] = field(default_factory=set)
+     handlers: Dict[str, Callable[[WebSocketSession, Any], Awaitable[None]]] = field(default_factory=dict)
+
+     def publish(self, event: str, data: Any, publisher: WebSocketSession = None):
+         """Publish an event to all subscribers except the publisher"""
+         for subscriber in self.subscribers:
+             if subscriber != publisher:
+                 subscriber.send({"channel": self.name, "event": event, "data": data})
+
+     def handle_event(self, event: str, session: WebSocketSession, data: Any):
+         """Handle an event on this channel"""
+         if event in self.handlers:
+             self.handlers[event](session, data)
+
+     def add_subscriber(self, subscriber: WebSocketSession):
+         """Add a subscriber to the channel"""
+         self.subscribers.add(subscriber)
+
+     def remove_subscriber(self, subscriber: WebSocketSession):
+         """Remove a subscriber from the channel"""
+         self.subscribers.discard(subscriber)
+
+     def on(self, event: str):
+         """Decorator for registering event handlers"""
+
+         def decorator(handler: Callable[[WebSocketSession, Any], Awaitable[None]]):
+             self.handlers[event] = handler
+             return handler
+
+         return decorator
+
+
+ class ChannelManager:
+     def __init__(self):
+         self.channels: Dict[str, Channel] = {}
+         self.client_channels: Dict[WebSocketSession, Set[str]] = {}
+
+     def create_channel(self, channel_name: str) -> Channel:
+         """Create a new channel if it doesn't exist"""
+         if channel_name not in self.channels:
+             self.channels[channel_name] = Channel(channel_name)
+         return self.channels[channel_name]
+
+     def get_channel(self, channel_name: str) -> Channel:
+         """Get a channel by name"""
+         return self.channels.get(channel_name)
+
+     def subscribe(self, client: WebSocketSession, channel_name: str):
+         """Subscribe a client to a channel"""
+         channel = self.create_channel(channel_name)
+         channel.add_subscriber(client)
+
+         if client not in self.client_channels:
+             self.client_channels[client] = set()
+         self.client_channels[client].add(channel_name)
+
+     def unsubscribe(self, client: WebSocketSession, channel_name: str):
+         """Unsubscribe a client from a channel"""
+         channel = self.get_channel(channel_name)
+         if channel:
+             channel.remove_subscriber(client)
+         if client in self.client_channels:
+             self.client_channels[client].discard(channel_name)
+
+     def unsubscribe_all(self, client: WebSocketSession):
+         """Unsubscribe a client from all channels"""
+         if client in self.client_channels:
+             channels = self.client_channels[client].copy()
+             for channel_name in channels:
+                 self.unsubscribe(client, channel_name)
+             del self.client_channels[client]
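A hedged sketch of the pub/sub pieces above; the session objects and the hookup to WebsocketRoute are assumptions, only the Channel/ChannelManager calls come from this file. handle_event invokes handlers synchronously (without awaiting), so a plain function is used here:

from hypern.ws.channel import ChannelManager

manager = ChannelManager()
chat = manager.create_channel("chat")


@chat.on("message")
def on_message(session, data):
    # Relay to every other subscriber on the channel.
    chat.publish("message", data, publisher=session)


# Typical lifecycle (wiring to the websocket layer is hypothetical):
# manager.subscribe(session, "chat")                     # on connect
# chat.handle_event("message", session, {"text": "hi"})  # on incoming frame
# manager.unsubscribe_all(session)                       # on disconnect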