chuk-ai-session-manager 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,295 @@
1
+ # chuk_ai_session_manager/storage/providers/redis.py
2
+ """
3
+ Async Redis-based session storage implementation.
4
+ """
5
+ import json
6
+ import logging
7
+ import asyncio
8
+ from datetime import datetime
9
+ from typing import Any, Dict, List, Optional, Type, TypeVar, Generic, Union, cast
10
+
11
# Note: redis is an optional dependency, so we import it conditionally.
try:
    import redis.asyncio as aioredis
    from redis.asyncio import Redis
    from redis.exceptions import RedisError
    AIOREDIS_AVAILABLE = True
except ImportError:
    AIOREDIS_AVAILABLE = False

    # Define a dummy class for type checking
    class Redis:  # type: ignore
        pass

    # RedisError must always be bound: the store's ``except (RedisError, ...)``
    # clauses are evaluated whenever an exception propagates. Without this,
    # a sync-only (pre-asyncio) redis install would turn every error into a
    # NameError.
    try:
        from redis.exceptions import RedisError
    except ImportError:
        class RedisError(Exception):  # type: ignore
            """Placeholder used when the redis package is not installed."""

# Standard redis for sync fallback if needed
try:
    import redis
    REDIS_AVAILABLE = True
except ImportError:
    REDIS_AVAILABLE = False
28
+
29
+ from chuk_ai_session_manager.models.session import Session
30
+ from chuk_ai_session_manager.storage.base import SessionStoreInterface
31
+ from chuk_ai_session_manager.exceptions import SessionManagerError
32
+
33
# Logger for this module; records are emitted under the module's dotted name.
logger = logging.getLogger(name=__name__)

# Generic parameter for RedisSessionStore. The bound is the string forward
# reference "Session", i.e. Session or any subclass of it.
T = TypeVar("T", bound="Session")
38
+
39
+
40
class RedisStorageError(SessionManagerError):
    """Signals that a Redis write, delete, key listing or expire call failed.

    Raised by RedisSessionStore with the underlying error's message embedded
    in its own message.
    """
43
+
44
+
45
class RedisSessionStore(SessionStoreInterface, Generic[T]):
    """
    An async session store that persists sessions to Redis.

    Each session is stored as a JSON document under the key
    ``f"{key_prefix}{session_id}"``, optionally with a TTL of
    ``expiration_seconds``. An in-memory cache sits in front of Redis:
    ``get`` returns cached sessions without touching Redis. ``save`` always
    updates the cache and writes through to Redis only when ``auto_save`` is
    enabled; otherwise call ``flush`` to persist.

    Clients recognised as ``redis.asyncio.Redis`` are awaited directly. Any
    other client is treated as synchronous, and its calls run in the event
    loop's default executor.
    """

    def __init__(self,
                 redis_client: Any,  # Can be async or sync Redis client
                 key_prefix: str = "session:",
                 expiration_seconds: Optional[int] = None,
                 session_class: Type[T] = Session,
                 auto_save: bool = True):
        """
        Initialize the async Redis session store.

        Args:
            redis_client: Pre-configured Redis client (``redis.asyncio`` or
                synchronous ``redis``).
            key_prefix: Prefix prepended to a session ID to form its key.
            expiration_seconds: Optional TTL for sessions. A falsy value
                (``None`` or ``0``) stores keys without expiry.
            session_class: The Session class used for deserialization
                (must provide ``model_validate``).
            auto_save: Whether ``save`` writes through to Redis immediately.

        Raises:
            ImportError: If no Redis client library is importable.
        """
        if not (AIOREDIS_AVAILABLE or REDIS_AVAILABLE):
            raise ImportError(
                "Redis package is not installed. "
                "Install it with 'pip install redis[asyncio]'."
            )

        self.redis = redis_client
        # True only for redis.asyncio clients. Everything else is dispatched
        # through the default executor by _execute.
        self.is_client = AIOREDIS_AVAILABLE and isinstance(redis_client, aioredis.Redis)
        self.key_prefix = key_prefix
        self.expiration_seconds = expiration_seconds
        self.session_class = session_class
        self.auto_save = auto_save
        # In-memory cache: session_id -> session object.
        self._cache: Dict[str, T] = {}

    def _get_key(self, session_id: str) -> str:
        """Get the Redis key for a session ID."""
        return f"{self.key_prefix}{session_id}"

    def _json_default(self, obj: Any) -> Any:
        """Serialize datetimes as ISO-8601 strings; reject anything else."""
        if isinstance(obj, datetime):
            return obj.isoformat()
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

    async def _execute(self, method_name: str, *args: Any) -> Any:
        """
        Call ``self.redis.<method_name>(*args)`` without blocking the event loop.

        Async clients are awaited directly. Other clients are called in the
        running loop's default executor. If such a client still returns an
        awaitable (an async client that does not subclass
        ``redis.asyncio.Redis``), it is awaited here rather than handed back
        to the caller as an un-awaited coroutine.
        """
        import inspect  # only needed for the awaitable fallback below

        method = getattr(self.redis, method_name)
        if self.is_client:
            return await method(*args)

        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, lambda: method(*args))
        if inspect.isawaitable(result):
            result = await result
        return result

    async def get(self, session_id: str) -> Optional[T]:
        """
        Async: Retrieve a session by its ID.

        Returns the cached session if present; otherwise loads it from Redis
        and caches it. Returns ``None`` if the key is missing or empty, or if
        a Redis or JSON-decoding error occurs (the error is logged).
        """
        if session_id in self._cache:
            return self._cache[session_id]

        key = self._get_key(session_id)
        try:
            data = await self._execute("get", key)
            if not data:
                return None

            # Clients without decode_responses return bytes.
            if isinstance(data, bytes):
                data = data.decode('utf-8')

            session_dict = json.loads(data)
            session = cast(T, self.session_class.model_validate(session_dict))

            self._cache[session_id] = session
            return session
        except (RedisError, json.JSONDecodeError) as e:
            logger.error(f"Failed to load session {session_id} from Redis: {e}")
            return None

    async def save(self, session: T) -> None:
        """Async: Cache a session and, if ``auto_save`` is set, persist it to Redis."""
        self._cache[session.id] = session

        if self.auto_save:
            await self._save_to_redis(session)

    async def _save_to_redis(self, session: T) -> None:
        """
        Async: Serialize a session to JSON and write it to Redis.

        Uses SETEX when ``expiration_seconds`` is truthy and SET otherwise.

        Raises:
            RedisStorageError: On a Redis error or a JSON serialization
                ``TypeError``. The original exception is chained as the cause.
        """
        session_id = session.id
        key = self._get_key(session_id)

        try:
            data = json.dumps(session.model_dump(), default=self._json_default)

            if self.expiration_seconds:
                await self._execute("setex", key, self.expiration_seconds, data)
            else:
                await self._execute("set", key, data)
        except (RedisError, TypeError) as e:
            logger.error(f"Failed to save session {session_id} to Redis: {e}")
            raise RedisStorageError(f"Failed to save session {session_id}: {str(e)}") from e

    async def delete(self, session_id: str) -> None:
        """
        Async: Remove a session from the cache and from Redis.

        Raises:
            RedisStorageError: If the Redis DELETE fails.
        """
        self._cache.pop(session_id, None)

        key = self._get_key(session_id)
        try:
            await self._execute("delete", key)
        except RedisError as e:
            logger.error(f"Failed to delete session {session_id} from Redis: {e}")
            raise RedisStorageError(f"Failed to delete session {session_id}: {str(e)}") from e

    async def list_sessions(self, prefix: str = "") -> List[str]:
        """
        Async: List session IDs stored in Redis, optionally filtered by ID prefix.

        Uses the Redis KEYS command with the glob pattern
        ``f"{key_prefix}{prefix}*"``. Only sessions present in Redis are
        listed; unsaved cache entries are not included.

        Raises:
            RedisStorageError: If the Redis KEYS call fails.
        """
        search_pattern = f"{self.key_prefix}{prefix}*"
        try:
            keys = await self._execute("keys", search_pattern)
        except RedisError as e:
            logger.error(f"Failed to list sessions from Redis: {e}")
            raise RedisStorageError(f"Failed to list sessions: {str(e)}") from e

        prefix_len = len(self.key_prefix)
        session_ids = []
        for key in keys:
            if isinstance(key, bytes):
                key = key.decode('utf-8')
            # Strip only the leading prefix. str.replace would also remove the
            # prefix text wherever it appears inside the session ID itself.
            # The startswith guard covers prefixes containing glob characters,
            # which can match keys that do not literally start with the prefix.
            session_ids.append(key[prefix_len:] if key.startswith(self.key_prefix) else key)
        return session_ids

    async def flush(self) -> None:
        """
        Async: Write every cached session to Redis (best effort).

        Individual failures are logged by ``_save_to_redis`` and skipped.
        """
        # Iterate over a snapshot: each await lets other tasks call save()/delete(),
        # which would otherwise change the dict size mid-iteration (RuntimeError).
        for session in list(self._cache.values()):
            try:
                await self._save_to_redis(session)
            except RedisStorageError:
                # Already logged in _save_to_redis
                pass

    async def clear_cache(self) -> None:
        """Async: Clear the in-memory cache (Redis is untouched)."""
        self._cache.clear()

    async def set_expiration(self, session_id: str, seconds: int) -> None:
        """
        Async: Set or update the TTL of a session's Redis key via EXPIRE.

        Raises:
            RedisStorageError: If the Redis EXPIRE call fails.
        """
        key = self._get_key(session_id)
        try:
            await self._execute("expire", key, seconds)
        except RedisError as e:
            logger.error(f"Failed to set expiration for session {session_id}: {e}")
            raise RedisStorageError(f"Failed to set expiration for session {session_id}: {str(e)}") from e
237
+
238
+
239
async def create_redis_session_store(
    host: str = "localhost",
    port: int = 6379,
    db: int = 0,
    password: Optional[str] = None,
    key_prefix: str = "session:",
    expiration_seconds: Optional[int] = None,
    session_class: Type[T] = Session,
    auto_save: bool = True,
    **redis_kwargs: Any
) -> RedisSessionStore[T]:
    """
    Build a Redis client from connection settings and wrap it in a store.

    A ``redis.asyncio.Redis`` client is created when ``redis.asyncio`` is
    importable. Otherwise a synchronous ``redis.Redis`` client is created and
    a warning is logged.

    Args:
        host: Redis host
        port: Redis port
        db: Redis database number
        password: Optional Redis password
        key_prefix: Prefix for Redis keys
        expiration_seconds: Optional TTL for sessions
        session_class: The Session class to use
        auto_save: Whether to automatically save on each update
        **redis_kwargs: Extra keyword arguments forwarded to the client constructor

    Returns:
        A RedisSessionStore wrapping the newly created client.

    Raises:
        ImportError: If neither ``redis.asyncio`` nor ``redis`` is importable.
    """
    if not (AIOREDIS_AVAILABLE or REDIS_AVAILABLE):
        raise ImportError("Redis package is not installed. Install with 'pip install redis[asyncio]'")

    # Prefer the asyncio client class; fall back to the blocking one.
    client_cls = aioredis.Redis if AIOREDIS_AVAILABLE else redis.Redis
    connection = client_cls(
        host=host,
        port=port,
        db=db,
        password=password,
        **redis_kwargs
    )
    if not AIOREDIS_AVAILABLE:
        logger.warning("Using synchronous Redis client. Install 'redis[asyncio]' for better performance.")

    return RedisSessionStore(
        redis_client=connection,
        key_prefix=key_prefix,
        expiration_seconds=expiration_seconds,
        session_class=session_class,
        auto_save=auto_save
    )