workspace-mcp 1.1.7__py3-none-any.whl → 1.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,519 @@
1
+ """
2
+ Session Store
3
+
4
+ Manages user sessions with proper isolation for multi-user OAuth 2.1 environments.
5
+ Provides session persistence, cleanup, and security features.
6
+ """
7
+
8
+ import asyncio
9
+ import logging
10
+ import secrets
11
+ from dataclasses import dataclass, field
12
+ from datetime import datetime, timezone, timedelta
13
+ from typing import Dict, Any, Optional, List, Set
14
+ from threading import RLock
15
+
16
# One logger per module, named after the module itself so its level and
# handlers can be configured through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
17
+
18
+
19
@dataclass
class Session:
    """Represents a user session with OAuth 2.1 context.

    Timestamps are timezone-aware ``datetime`` objects (UTC when created
    here). ``expires_at`` of ``None`` means the session never expires on its
    own.
    """

    session_id: str
    user_id: str
    token_info: Dict[str, Any]
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    last_accessed: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    expires_at: Optional[datetime] = None
    scopes: List[str] = field(default_factory=list)
    authorization_server: Optional[str] = None
    client_id: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def is_expired(self) -> bool:
        """Return True once the current UTC time reaches ``expires_at``."""
        return self.expires_at is not None and datetime.now(timezone.utc) >= self.expires_at

    def update_access_time(self):
        """Set ``last_accessed`` to the current UTC time."""
        self.last_accessed = datetime.now(timezone.utc)

    def to_dict(self) -> Dict[str, Any]:
        """Convert the session to a JSON-friendly dictionary.

        Timestamps become ISO-8601 strings. Containers are shallow-copied so
        that mutating the returned dict cannot silently alter the live session.
        """
        return {
            "session_id": self.session_id,
            "user_id": self.user_id,
            "token_info": dict(self.token_info or {}),
            "created_at": self.created_at.isoformat(),
            "last_accessed": self.last_accessed.isoformat(),
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
            "scopes": list(self.scopes or []),
            "authorization_server": self.authorization_server,
            "client_id": self.client_id,
            "metadata": dict(self.metadata or {}),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Session":
        """Create a session from a dictionary produced by ``to_dict``.

        Timestamps without a UTC offset (e.g. hand-edited or legacy files) are
        interpreted as UTC, and a trailing ``Z`` is accepted. The resulting
        datetimes are therefore always aware and safe to compare in
        ``is_expired``.

        Raises:
            KeyError: if ``session_id``, ``user_id``, ``token_info``,
                ``created_at`` or ``last_accessed`` is missing.
            ValueError: if a timestamp is not valid ISO-8601.
        """
        expires_raw = data.get("expires_at")
        return cls(
            session_id=data["session_id"],
            user_id=data["user_id"],
            token_info=data["token_info"],
            created_at=cls._parse_timestamp(data["created_at"]),
            last_accessed=cls._parse_timestamp(data["last_accessed"]),
            expires_at=cls._parse_timestamp(expires_raw) if expires_raw else None,
            scopes=list(data.get("scopes") or []),
            authorization_server=data.get("authorization_server"),
            client_id=data.get("client_id"),
            metadata=dict(data.get("metadata") or {}),
        )

    @staticmethod
    def _parse_timestamp(value: str) -> datetime:
        """Parse an ISO-8601 string into an aware datetime (naive => UTC)."""
        # datetime.fromisoformat only understands a trailing "Z" from 3.11 on.
        if value.endswith("Z"):
            value = value[:-1] + "+00:00"
        parsed = datetime.fromisoformat(value)
        if parsed.tzinfo is None:
            # Comparing naive with aware datetimes raises TypeError.
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed
79
+
80
+
81
class SessionStore:
    """Manages user sessions with proper isolation.

    Sessions are held in memory in two indexes: ``_sessions``
    (session_id -> Session) and ``_user_sessions`` (user_id -> set of
    session_ids). Every public method takes the internal re-entrant lock;
    underscore-prefixed helpers assume the caller already holds it.

    When ``enable_persistence`` is set, the whole store is rewritten to
    ``persistence_file`` as JSON once per mutating operation. It is reloaded
    on construction, dropping expired sessions. The file contains each
    session's ``token_info`` verbatim (presumably OAuth tokens). It is
    therefore written atomically, with owner-only permissions where the OS
    supports them.
    """

    def __init__(
        self,
        default_session_timeout: int = 3600,  # 1 hour
        max_sessions_per_user: int = 10,
        cleanup_interval: int = 300,  # 5 minutes
        enable_persistence: bool = False,
        persistence_file: Optional[str] = None,
    ):
        """
        Initialize the session store.

        Args:
            default_session_timeout: Default session timeout in seconds
            max_sessions_per_user: Maximum sessions per user
            cleanup_interval: Session cleanup interval in seconds
            enable_persistence: Enable session persistence to disk
            persistence_file: File path for session persistence
                (defaults to ``.oauth21_sessions.json`` in the CWD)
        """
        self.default_session_timeout = default_session_timeout
        self.max_sessions_per_user = max_sessions_per_user
        self.cleanup_interval = cleanup_interval
        self.enable_persistence = enable_persistence
        self.persistence_file = persistence_file or ".oauth21_sessions.json"

        # Both indexes are guarded by a re-entrant lock because public
        # methods call one another (e.g. get_user_sessions -> helpers).
        self._sessions: Dict[str, "Session"] = {}
        self._user_sessions: Dict[str, Set[str]] = {}  # user_id -> set of session_ids
        self._lock = RLock()

        # Background cleanup task state
        self._cleanup_task: Optional[asyncio.Task] = None
        self._shutdown = False

        if self.enable_persistence:
            self._load_sessions()

    async def start_cleanup_task(self):
        """Start the background cleanup task if it is not already running."""
        if not self._cleanup_task or self._cleanup_task.done():
            self._shutdown = False
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            logger.info("Started session cleanup task")

    async def stop_cleanup_task(self):
        """Stop the background cleanup task and wait for it to finish."""
        self._shutdown = True
        if self._cleanup_task and not self._cleanup_task.done():
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            logger.info("Stopped session cleanup task")

    @staticmethod
    def _short_id(session_id: str) -> str:
        """Return a log-safe prefix of a session ID.

        Session IDs come from ``secrets`` and unlock the stored token data, so
        only the first 8 characters are logged. That is enough to correlate
        log lines without exposing the full value.
        """
        return f"{session_id[:8]}..."

    def create_session(
        self,
        user_id: str,
        token_info: Dict[str, Any],
        session_timeout: Optional[int] = None,
        scopes: Optional[List[str]] = None,
        authorization_server: Optional[str] = None,
        client_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Create new session and return session ID.

        If the user already holds ``max_sessions_per_user`` sessions, their
        expired sessions are purged first. The oldest remaining sessions are
        then evicted until there is room. The limit never causes an error.

        Args:
            user_id: User identifier (email)
            token_info: OAuth token information (shallow-copied, so later
                ``update_session`` calls never mutate the caller's dict)
            session_timeout: Session timeout in seconds; ``None`` selects
                ``default_session_timeout`` (an explicit 0 is honored)
            scopes: OAuth scopes
            authorization_server: Authorization server URL
            client_id: OAuth client ID
            metadata: Additional session metadata (shallow-copied)

        Returns:
            Session ID string
        """
        with self._lock:
            self._enforce_user_session_limit(user_id)

            session_id = self._generate_session_id()
            timeout = (
                self.default_session_timeout if session_timeout is None else session_timeout
            )
            session = Session(
                session_id=session_id,
                user_id=user_id,
                token_info=dict(token_info or {}),
                expires_at=datetime.now(timezone.utc) + timedelta(seconds=timeout),
                scopes=list(scopes) if scopes else [],
                authorization_server=authorization_server,
                client_id=client_id,
                metadata=dict(metadata) if metadata else {},
            )

            self._sessions[session_id] = session
            self._user_sessions.setdefault(user_id, set()).add(session_id)

            logger.info("Created session %s for user %s", self._short_id(session_id), user_id)

            # One write covers both the new session and any evictions above.
            if self.enable_persistence:
                self._save_sessions()

            return session_id

    def get_session(self, session_id: str) -> Optional["Session"]:
        """
        Retrieve session by ID, refreshing its last-access time.

        An expired session is removed from the store as a side effect.

        Args:
            session_id: Session identifier

        Returns:
            Session object or None if not found/expired
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if session is None:
                return None

            if session.is_expired():
                logger.debug("Session %s has expired", self._short_id(session_id))
                self._remove_session(session_id)
                return None

            session.update_access_time()
            logger.debug(
                "Retrieved session %s for user %s", self._short_id(session_id), session.user_id
            )
            return session

    def update_session(
        self,
        session_id: str,
        token_info: Optional[Dict[str, Any]] = None,
        extend_expiration: bool = True,
        metadata_updates: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """
        Update session with new token information.

        Args:
            session_id: Session identifier
            token_info: Token fields merged into the session's token_info
            extend_expiration: Reset expiry to now + ``default_session_timeout``
                (note: not the per-session timeout given at creation)
            metadata_updates: Metadata updates to apply

        Returns:
            True if session was updated, False if not found or expired
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if session is None or session.is_expired():
                return False

            if token_info:
                session.token_info.update(token_info)

            if metadata_updates:
                session.metadata.update(metadata_updates)

            if extend_expiration:
                session.expires_at = datetime.now(timezone.utc) + timedelta(
                    seconds=self.default_session_timeout
                )

            session.update_access_time()
            logger.debug("Updated session %s", self._short_id(session_id))

            if self.enable_persistence:
                self._save_sessions()

            return True

    def remove_session(self, session_id: str) -> bool:
        """
        Remove session by ID.

        Args:
            session_id: Session identifier

        Returns:
            True if session was removed, False if not found
        """
        with self._lock:
            return self._remove_session(session_id)

    def get_user_sessions(self, user_id: str) -> List["Session"]:
        """
        Get all active sessions for a user.

        Each returned session has its last-access time refreshed. Expired
        sessions found along the way are removed, with a single save.

        Args:
            user_id: User identifier

        Returns:
            List of active sessions for the user
        """
        with self._lock:
            active = []
            removed_any = False
            # Iterate over a snapshot because removals mutate the set.
            for session_id in list(self._user_sessions.get(user_id, ())):
                session = self._sessions.get(session_id)
                if session is None:
                    continue
                if session.is_expired():
                    logger.debug("Session %s has expired", self._short_id(session_id))
                    removed_any = self._remove_session(session_id, persist=False) or removed_any
                    continue
                session.update_access_time()
                active.append(session)

            if removed_any and self.enable_persistence:
                self._save_sessions()

            return active

    def remove_user_sessions(self, user_id: str) -> int:
        """
        Remove all sessions for a user.

        Args:
            user_id: User identifier

        Returns:
            Number of sessions removed
        """
        with self._lock:
            removed_count = 0
            for session_id in list(self._user_sessions.get(user_id, ())):
                if self._remove_session(session_id, persist=False):
                    removed_count += 1

            if removed_count and self.enable_persistence:
                self._save_sessions()

            logger.info("Removed %d sessions for user %s", removed_count, user_id)
            return removed_count

    def cleanup_expired_sessions(self) -> int:
        """
        Remove expired sessions.

        Returns:
            Number of sessions removed
        """
        with self._lock:
            expired_ids = [sid for sid, session in self._sessions.items() if session.is_expired()]

            removed_count = 0
            for session_id in expired_ids:
                if self._remove_session(session_id, persist=False):
                    removed_count += 1

            if removed_count:
                # Persist once for the whole batch instead of once per removal.
                if self.enable_persistence:
                    self._save_sessions()
                logger.info("Cleaned up %d expired sessions", removed_count)

            return removed_count

    def get_session_stats(self) -> Dict[str, Any]:
        """
        Get session store statistics.

        ``total_sessions`` counts every stored session, including expired
        ones that have not been cleaned up yet.

        Returns:
            Dictionary with session statistics
        """
        with self._lock:
            return {
                "total_sessions": len(self._sessions),
                "active_users": len(self._user_sessions),
                "sessions_per_user": {
                    user_id: len(session_ids)
                    for user_id, session_ids in self._user_sessions.items()
                },
                "max_sessions_per_user": self.max_sessions_per_user,
                "default_timeout": self.default_session_timeout,
            }

    def _generate_session_id(self) -> str:
        """Generate cryptographically secure session ID."""
        return secrets.token_urlsafe(32)

    def _remove_session(self, session_id: str, persist: bool = True) -> bool:
        """Remove one session from both indexes (caller must hold the lock).

        Args:
            session_id: Session identifier
            persist: Save to disk afterwards (when persistence is enabled).
                Bulk callers pass False and save once themselves.

        Returns:
            True if the session existed and was removed
        """
        session = self._sessions.pop(session_id, None)
        if session is None:
            return False

        user_sessions = self._user_sessions.get(session.user_id)
        if user_sessions is not None:
            user_sessions.discard(session_id)
            if not user_sessions:
                del self._user_sessions[session.user_id]

        logger.debug(
            "Removed session %s for user %s", self._short_id(session_id), session.user_id
        )

        if persist and self.enable_persistence:
            self._save_sessions()

        return True

    def _cleanup_oldest_user_session(self, user_id: str, persist: bool = True) -> bool:
        """Remove the user's session with the earliest ``created_at``.

        Caller must hold the lock.

        Returns:
            True if a session was removed
        """
        session_ids = self._user_sessions.get(user_id)
        if not session_ids:
            return False

        # min() instead of comparing against "now": timestamps loaded from disk
        # or written under clock skew may not be strictly in the past.
        oldest_session_id = min(
            (sid for sid in session_ids if sid in self._sessions),
            key=lambda sid: self._sessions[sid].created_at,
            default=None,
        )
        if oldest_session_id is None:
            return False

        self._remove_session(oldest_session_id, persist=persist)
        logger.info(
            "Removed oldest session %s for user %s", self._short_id(oldest_session_id), user_id
        )
        return True

    def _enforce_user_session_limit(self, user_id: str) -> None:
        """Make room for one more session for ``user_id`` (lock must be held).

        Expired sessions are dropped first so they do not count toward the
        limit. The oldest remaining sessions are then evicted until the user
        is below ``max_sessions_per_user``. Nothing is persisted here; the
        caller saves once afterwards.
        """
        for session_id in list(self._user_sessions.get(user_id, ())):
            session = self._sessions.get(session_id)
            if session is not None and session.is_expired():
                self._remove_session(session_id, persist=False)

        while len(self._user_sessions.get(user_id, ())) >= self.max_sessions_per_user:
            if not self._cleanup_oldest_user_session(user_id, persist=False):
                break  # nothing evictable (e.g. dangling IDs); avoid spinning

    async def _cleanup_loop(self):
        """Purge expired sessions every ``cleanup_interval`` seconds until stopped.

        The purge runs synchronously on the event loop, and with persistence
        enabled it performs blocking file I/O. Unexpected errors are logged
        and the loop continues.
        """
        while not self._shutdown:
            try:
                self.cleanup_expired_sessions()
                await asyncio.sleep(self.cleanup_interval)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in session cleanup loop: {e}", exc_info=True)
                await asyncio.sleep(self.cleanup_interval)

    def _save_sessions(self):
        """Save sessions to disk (if persistence enabled). Lock must be held.

        The JSON is fully serialized first and written to a temporary file in
        the target's directory. That file is fsynced and moved over the
        target with ``os.replace``, so a crash or serialization error cannot
        truncate the existing file. ``tempfile.mkstemp`` creates the file
        readable/writable by the owner only. Failures are logged, not raised,
        to keep persistence best-effort.
        """
        if not self.enable_persistence:
            return

        import json
        import os
        import tempfile

        tmp_path = None
        try:
            data = {
                "sessions": {
                    session_id: session.to_dict()
                    for session_id, session in self._sessions.items()
                },
                "user_sessions": {
                    user_id: list(session_ids)
                    for user_id, session_ids in self._user_sessions.items()
                },
            }
            payload = json.dumps(data, indent=2)

            target = os.path.abspath(self.persistence_file)
            fd, tmp_path = tempfile.mkstemp(
                dir=os.path.dirname(target),
                prefix=os.path.basename(target) + ".",
                suffix=".tmp",
            )
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                f.write(payload)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, target)
            tmp_path = None  # now owned by the target path

            logger.debug("Saved %d sessions to %s", len(self._sessions), self.persistence_file)

        except Exception as e:
            logger.error(f"Failed to save sessions: {e}")
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

    def _load_sessions(self):
        """Load non-expired sessions from disk (if persistence enabled).

        Individual malformed sessions are skipped with a warning. Any other
        failure (missing permissions, invalid JSON) is logged and leaves the
        store empty.
        """
        if not self.enable_persistence:
            return

        try:
            import json
            import os

            if not os.path.exists(self.persistence_file):
                return

            with open(self.persistence_file, "r", encoding="utf-8") as f:
                data = json.load(f)

            for session_id, session_data in data.get("sessions", {}).items():
                try:
                    session = Session.from_dict(session_data)
                    if not session.is_expired():
                        self._sessions[session_id] = session
                except Exception as e:
                    logger.warning(
                        "Failed to load session %s: %s", self._short_id(session_id), e
                    )

            # Rebuild the reverse index from the sessions actually loaded,
            # ignoring the stored "user_sessions" mapping.
            self._user_sessions.clear()
            for session_id, session in self._sessions.items():
                self._user_sessions.setdefault(session.user_id, set()).add(session_id)

            logger.info("Loaded %d sessions from %s", len(self._sessions), self.persistence_file)

        except Exception as e:
            logger.error(f"Failed to load sessions: {e}")

    async def __aenter__(self):
        """Async context manager entry: start background cleanup."""
        await self.start_cleanup_task()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: stop background cleanup."""
        await self.stop_cleanup_task()