google-adk-extras 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,271 @@
1
+ """Redis-based session service implementation."""
2
+
3
+ import json
4
+ import time
5
+ import uuid
6
+ from typing import Any, Optional, Dict
7
+
8
+ try:
9
+ import redis
10
+ from redis.exceptions import RedisError
11
+ except ImportError:
12
+ raise ImportError(
13
+ "Redis is required for RedisSessionService. "
14
+ "Install it with: pip install redis"
15
+ )
16
+
17
+ from google.adk.sessions.session import Session
18
+ from google.adk.events.event import Event
19
+ from google.adk.sessions.base_session_service import GetSessionConfig, ListSessionsResponse
20
+
21
+ from .base_custom_session_service import BaseCustomSessionService
22
+
23
+
24
class RedisSessionService(BaseCustomSessionService):
    """Session service that persists ADK sessions in Redis.

    Storage layout:

    * Each session is a Redis hash stored under
      ``{key_prefix}{app_name}:{user_id}:{session_id}`` with the fields
      ``id``, ``app_name``, ``user_id``, ``state`` (JSON), ``events`` (JSON)
      and ``last_update_time`` (epoch seconds).
    * A Redis set stored under ``{key_prefix}{app_name}:{user_id}:sessions``
      indexes the session IDs that belong to one user of one app.

    NOTE(review): the synchronous ``redis.Redis`` client is called directly
    from the ``async`` hooks below without ``await``, so every Redis round
    trip runs inline in the calling coroutine.

    NOTE(review): a session whose id is literally ``"sessions"`` maps to the
    same key as the per-user index set. The key scheme is kept unchanged for
    compatibility with data that is already stored.
    """

    def __init__(self, host: str = "localhost", port: int = 6379, db: int = 0, password: Optional[str] = None):
        """Initialize the Redis session service.

        No connection is opened here; that happens in ``_initialize_impl``.

        Args:
            host: Redis host
            port: Redis port
            db: Redis database number
            password: Redis password (if required)
        """
        super().__init__()
        self.host = host
        self.port = port
        self.db = db
        self.password = password
        self.client: Optional[redis.Redis] = None
        self.key_prefix = "adk_session:"

    async def _initialize_impl(self) -> None:
        """Create the Redis client and verify the connection with a PING.

        Raises:
            RuntimeError: If the client cannot be created or the ping fails.
        """
        try:
            client = redis.Redis(
                host=self.host,
                port=self.port,
                db=self.db,
                password=self.password,
                decode_responses=True,
            )
            # Test connection
            client.ping()
        except RedisError as e:
            raise RuntimeError(f"Failed to initialize Redis session service: {e}") from e
        # Only publish the client once the connection is known to work, so a
        # failed initialization does not leave a half-initialized service.
        self.client = client

    async def _cleanup_impl(self) -> None:
        """Close the Redis client, if one was created, and forget it."""
        if self.client is not None:
            self.client.close()
            self.client = None

    def _require_client(self) -> "redis.Redis":
        """Return the Redis client or fail clearly when it is missing.

        Raises:
            RuntimeError: If ``_initialize_impl`` has not set a client yet.
        """
        if self.client is None:
            raise RuntimeError("Service not initialized")
        return self.client

    def _get_session_key(self, app_name: str, user_id: str, session_id: str) -> str:
        """Generate the Redis key of the hash that stores one session."""
        return f"{self.key_prefix}{app_name}:{user_id}:{session_id}"

    def _get_user_sessions_key(self, app_name: str, user_id: str) -> str:
        """Generate the Redis key of the set that indexes a user's session IDs."""
        return f"{self.key_prefix}{app_name}:{user_id}:sessions"

    def _serialize_state(self, state: dict[str, Any]) -> str:
        """Serialize session state to a JSON string.

        Raises:
            ValueError: If the state contains values ``json.dumps`` rejects.
        """
        try:
            return json.dumps(state)
        except (TypeError, ValueError) as e:
            raise ValueError(f"Failed to serialize state: {e}") from e

    def _deserialize_state(self, state_str: str) -> dict[str, Any]:
        """Deserialize session state from a JSON string.

        An empty or missing string yields an empty dict.

        Raises:
            ValueError: If the string is not valid JSON.
        """
        try:
            return json.loads(state_str) if state_str else {}
        except (TypeError, ValueError) as e:
            raise ValueError(f"Failed to deserialize state: {e}") from e

    def _serialize_events(self, events: list[Event]) -> str:
        """Serialize events to a JSON array string.

        Raises:
            ValueError: If the events cannot be converted to JSON.
        """
        try:
            # mode="json" makes pydantic emit only JSON-compatible values.
            # The default python mode can return sets, bytes, etc., which
            # json.dumps rejects with TypeError.
            event_dicts = [event.model_dump(mode="json") for event in events]
            return json.dumps(event_dicts)
        except (TypeError, ValueError) as e:
            raise ValueError(f"Failed to serialize events: {e}") from e

    def _deserialize_events(self, events_str: str) -> list[Event]:
        """Deserialize events from a JSON array string.

        An empty or missing string yields an empty list.

        Raises:
            ValueError: If the string is not valid JSON or an entry cannot be
                turned into an ``Event``.
        """
        try:
            event_dicts = json.loads(events_str) if events_str else []
            return [Event(**event_dict) for event_dict in event_dicts]
        except (TypeError, ValueError) as e:
            raise ValueError(f"Failed to deserialize events: {e}") from e

    def _session_from_hash(self, session_data: Dict[str, str], *, include_events: bool) -> Session:
        """Build a ``Session`` from the fields of a stored Redis hash.

        Args:
            session_data: Result of ``HGETALL`` on a session key.
            include_events: When False the ``events`` field is not parsed and
                the session is returned with an empty event list.

        Raises:
            KeyError: If a required hash field is missing.
            ValueError: If a field cannot be parsed.
        """
        state = self._deserialize_state(session_data["state"])
        events = self._deserialize_events(session_data["events"]) if include_events else []
        return Session(
            id=session_data["id"],
            app_name=session_data["app_name"],
            user_id=session_data["user_id"],
            state=state,
            events=events,
            last_update_time=float(session_data["last_update_time"]),
        )

    async def _create_session_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        state: Optional[dict[str, Any]] = None,
        session_id: Optional[str] = None,
    ) -> Session:
        """Create a session, store it in Redis and index it for the user.

        Args:
            app_name: Application the session belongs to.
            user_id: User the session belongs to.
            state: Initial state; an empty dict is used when omitted.
            session_id: Explicit session id; a UUID4 is generated when omitted.

        Returns:
            The newly created session (with no events).

        Raises:
            RuntimeError: If the service is not initialized or Redis fails.
            ValueError: If the state cannot be serialized.
        """
        client = self._require_client()
        try:
            # Generate session ID if not provided
            session_id = session_id or str(uuid.uuid4())

            session = Session(
                id=session_id,
                app_name=app_name,
                user_id=user_id,
                state=state or {},
                events=[],
                last_update_time=time.time(),
            )

            # Serialize data for storage
            session_data = {
                "id": session_id,
                "app_name": app_name,
                "user_id": user_id,
                "state": self._serialize_state(session.state),
                "events": self._serialize_events(session.events),
                "last_update_time": session.last_update_time,
            }

            # Store the session hash, then register its id in the user's index set.
            session_key = self._get_session_key(app_name, user_id, session_id)
            client.hset(session_key, mapping=session_data)

            user_sessions_key = self._get_user_sessions_key(app_name, user_id)
            client.sadd(user_sessions_key, session_id)

            # No expiration is set on the session key; sessions live until deleted.

            return session
        except RedisError as e:
            raise RuntimeError(f"Failed to create session: {e}") from e

    async def _get_session_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        config: Optional[GetSessionConfig] = None,
    ) -> Optional[Session]:
        """Load one session from Redis.

        Args:
            app_name: Application the session belongs to.
            user_id: User the session belongs to.
            session_id: Id of the session to load.
            config: Optional event filters. ``num_recent_events`` keeps only
                the last N events; ``after_timestamp`` then keeps events whose
                timestamp is greater than or equal to the given value.

        Returns:
            The session, or None when no hash exists under the session key.

        Raises:
            RuntimeError: If the service is not initialized, Redis fails, or
                the stored data cannot be deserialized.
        """
        client = self._require_client()
        try:
            session_key = self._get_session_key(app_name, user_id, session_id)
            session_data = client.hgetall(session_key)

            # HGETALL returns an empty mapping for a missing key.
            if not session_data:
                return None

            session = self._session_from_hash(session_data, include_events=True)

            # Apply config filters if provided
            if config:
                if config.num_recent_events:
                    session.events = session.events[-config.num_recent_events:]
                if config.after_timestamp:
                    session.events = [
                        event for event in session.events
                        if event.timestamp >= config.after_timestamp
                    ]

            return session
        except RedisError as e:
            raise RuntimeError(f"Failed to get session: {e}") from e
        except (KeyError, ValueError) as e:
            raise RuntimeError(f"Failed to deserialize session data: {e}") from e

    async def _list_sessions_impl(
        self,
        *,
        app_name: str,
        user_id: str
    ) -> ListSessionsResponse:
        """List a user's sessions without their events.

        Session IDs found in the index set whose hash no longer exists are
        skipped.

        Raises:
            RuntimeError: If the service is not initialized or Redis fails.
        """
        client = self._require_client()
        try:
            user_sessions_key = self._get_user_sessions_key(app_name, user_id)
            session_ids = client.smembers(user_sessions_key)

            sessions = []
            for session_id in session_ids:
                session_key = self._get_session_key(app_name, user_id, session_id)
                session_data = client.hgetall(session_key)
                if session_data:
                    # Events are left out of listings for performance.
                    sessions.append(self._session_from_hash(session_data, include_events=False))

            return ListSessionsResponse(sessions=sessions)
        except RedisError as e:
            raise RuntimeError(f"Failed to list sessions: {e}") from e

    async def _delete_session_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str
    ) -> None:
        """Delete a session hash and remove its id from the user's index set.

        Raises:
            RuntimeError: If the service is not initialized or Redis fails.
        """
        client = self._require_client()
        try:
            session_key = self._get_session_key(app_name, user_id, session_id)
            client.delete(session_key)

            user_sessions_key = self._get_user_sessions_key(app_name, user_id)
            client.srem(user_sessions_key, session_id)
        except RedisError as e:
            raise RuntimeError(f"Failed to delete session: {e}") from e

    async def _append_event_impl(self, session: Session, event: Event) -> None:
        """Persist the session's current events and apply the event's state delta.

        The full ``session.events`` list is re-serialized and written, along
        with ``session.last_update_time``. When the event carries a
        ``state_delta`` it is merged into the state currently stored in Redis.

        Raises:
            RuntimeError: If the service is not initialized, Redis fails, the
                session no longer exists in Redis, or the data cannot be
                (de)serialized.
        """
        client = self._require_client()
        try:
            session_key = self._get_session_key(session.app_name, session.user_id, session.id)

            # Without this check HSET would silently re-create a partial hash
            # (no id/app_name/user_id fields) for a deleted session, which
            # later fails to load.
            if not client.exists(session_key):
                raise ValueError(f"Session {session.id} not found")

            update_data = {
                "events": self._serialize_events(session.events),
                "last_update_time": session.last_update_time,
            }

            # Apply state changes from event if present
            if event.actions and event.actions.state_delta:
                current_state = self._deserialize_state(client.hget(session_key, "state"))
                current_state.update(event.actions.state_delta)
                update_data["state"] = self._serialize_state(current_state)

            client.hset(session_key, mapping=update_data)
        except RedisError as e:
            raise RuntimeError(f"Failed to append event: {e}") from e
        except (KeyError, ValueError) as e:
            raise RuntimeError(f"Failed to update session data: {e}") from e
@@ -0,0 +1,308 @@
1
+ """SQL-based session service implementation using SQLAlchemy."""
2
+
3
+ import json
4
+ import logging
5
+ from typing import Any, Optional
6
+ from datetime import datetime, timezone
7
+
8
+ try:
9
+ from sqlalchemy import create_engine, Column, String, Text, DateTime
10
+ from sqlalchemy.orm import declarative_base, sessionmaker
11
+ from sqlalchemy.exc import SQLAlchemyError
12
+ except ImportError:
13
+ raise ImportError(
14
+ "SQLAlchemy is required for SQLSessionService. "
15
+ "Install it with: pip install sqlalchemy"
16
+ )
17
+
18
+ from google.adk.events.event import Event
19
+ from .base_custom_session_service import BaseCustomSessionService
20
+
21
+
22
# Module-level logger.
# NOTE(review): __name__ is normally already the fully qualified module name,
# so the 'google_adk_extras.' prefix may repeat the package name in the
# resulting logger name — confirm this is intended.
logger = logging.getLogger('google_adk_extras.' + __name__)

# Declarative base shared by the ORM models in this module
# (declarative_base is imported from sqlalchemy.orm, the modern location).
Base = declarative_base()
26
+
27
+
28
class SQLSessionModel(Base):
    """SQLAlchemy model for storing sessions.

    Each row of the ``adk_sessions`` table holds one session, with its state
    and events stored as JSON text.
    """
    __tablename__ = 'adk_sessions'

    # Primary key: the session id on its own.
    # NOTE(review): because app_name and user_id are not part of the key, the
    # same session id cannot be stored for two different apps or users —
    # confirm this is acceptable.
    id = Column(String, primary_key=True)

    # Session identifiers (both indexed)
    app_name = Column(String, nullable=False, index=True)
    user_id = Column(String, nullable=False, index=True)

    # Session data
    state = Column(Text, nullable=False)  # JSON string
    events = Column(Text, nullable=False)  # JSON string
    # Defaults to the current UTC time when no value is supplied on insert.
    last_update_time = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc))
43
+
44
+
45
class SQLSessionService(BaseCustomSessionService):
    """Session service that persists ADK sessions through SQLAlchemy.

    Sessions are stored as rows of ``SQLSessionModel``; state and events are
    kept as JSON text columns. A new database session is opened for every
    operation and closed when the operation finishes.
    """

    def __init__(self, database_url: str):
        """Initialize the SQL session service.

        No connection is opened here; that happens in ``_initialize_impl``.

        Args:
            database_url: Database connection string (e.g., 'sqlite:///sessions.db')
        """
        super().__init__()
        self.database_url = database_url
        self.engine: Optional[object] = None
        self.session_local: Optional[object] = None

    async def _initialize_impl(self) -> None:
        """Create the engine, create the tables and build the session factory.

        Raises:
            RuntimeError: If SQLAlchemy reports an error.
        """
        try:
            self.engine = create_engine(self.database_url)
            Base.metadata.create_all(self.engine)
            self.session_local = sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=self.engine,
            )
        except SQLAlchemyError as e:
            raise RuntimeError(f"Failed to initialize SQL session service: {e}") from e

    async def _cleanup_impl(self) -> None:
        """Dispose of the engine, if any, and drop the session factory."""
        if self.engine:
            self.engine.dispose()
            self.engine = None
        self.session_local = None

    def _get_db_session(self):
        """Open a new database session.

        Raises:
            RuntimeError: If the service has not been initialized.
        """
        if not self.session_local:
            raise RuntimeError("Service not initialized")
        return self.session_local()

    @staticmethod
    def _to_timestamp(value: datetime) -> float:
        """Convert a stored ``last_update_time`` to epoch seconds.

        Values are always written as UTC, but some backends (e.g. SQLite)
        hand them back without tzinfo. Calling ``.timestamp()`` on a naive
        datetime would interpret it as local time and skew the result by the
        host's UTC offset, so naive values are tagged as UTC first.
        """
        if value.tzinfo is None:
            value = value.replace(tzinfo=timezone.utc)
        return value.timestamp()

    def _serialize_state(self, state: dict[str, Any]) -> str:
        """Serialize session state to a JSON string.

        Raises:
            ValueError: If the state contains values ``json.dumps`` rejects.
        """
        try:
            return json.dumps(state)
        except (TypeError, ValueError) as e:
            raise ValueError(f"Failed to serialize state: {e}") from e

    def _deserialize_state(self, state_str: str) -> dict[str, Any]:
        """Deserialize session state from a JSON string.

        An empty or missing string yields an empty dict.

        Raises:
            ValueError: If the string is not valid JSON.
        """
        try:
            return json.loads(state_str) if state_str else {}
        except (TypeError, ValueError) as e:
            raise ValueError(f"Failed to deserialize state: {e}") from e

    def _serialize_events(self, events: list[Event]) -> str:
        """Serialize events to a JSON array string.

        Raises:
            ValueError: If the events cannot be converted to JSON.
        """
        try:
            # Convert events to dictionaries
            event_dicts = [event.model_dump() for event in events]
            return json.dumps(event_dicts)
        except (TypeError, ValueError) as e:
            raise ValueError(f"Failed to serialize events: {e}") from e

    def _deserialize_events(self, events_str: str) -> list[Event]:
        """Deserialize events from a JSON array string.

        An empty or missing string yields an empty list.

        Raises:
            ValueError: If the string is not valid JSON or an entry cannot be
                turned into an ``Event``.
        """
        try:
            event_dicts = json.loads(events_str) if events_str else []
            return [Event(**event_dict) for event_dict in event_dicts]
        except (TypeError, ValueError) as e:
            raise ValueError(f"Failed to deserialize events: {e}") from e

    def _session_from_model(self, db_model: "SQLSessionModel", *, include_events: bool) -> "Session":
        """Build a ``Session`` from a database row.

        Args:
            db_model: Row loaded from the ``adk_sessions`` table.
            include_events: When False the stored events are not parsed and
                the session is returned with an empty event list.
        """
        # Import Session inside the function to avoid circular import
        from google.adk.sessions.session import Session

        state = self._deserialize_state(db_model.state)
        events = self._deserialize_events(db_model.events) if include_events else []
        return Session(
            id=db_model.id,
            app_name=db_model.app_name,
            user_id=db_model.user_id,
            state=state,
            events=events,
            last_update_time=self._to_timestamp(db_model.last_update_time),
        )

    async def _create_session_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        state: Optional[dict[str, Any]] = None,
        session_id: Optional[str] = None,
    ) -> "Session":
        """Create a session and insert it into the database.

        Args:
            app_name: Application the session belongs to.
            user_id: User the session belongs to.
            state: Initial state; an empty dict is used when omitted.
            session_id: Explicit session id; a UUID4 is generated when omitted.

        Returns:
            The newly created session (with no events).

        Raises:
            RuntimeError: If the service is not initialized or the insert fails.
            ValueError: If the state cannot be serialized.
        """
        # Import Session inside the function to avoid circular import
        from google.adk.sessions.session import Session
        import time
        import uuid

        # Generate session ID if not provided
        session_id = session_id or str(uuid.uuid4())

        session = Session(
            id=session_id,
            app_name=app_name,
            user_id=user_id,
            state=state or {},
            events=[],
            last_update_time=time.time(),
        )

        db_session = self._get_db_session()
        try:
            db_session_model = SQLSessionModel(
                id=session_id,
                app_name=app_name,
                user_id=user_id,
                state=self._serialize_state(session.state),
                events=self._serialize_events(session.events),
                last_update_time=datetime.fromtimestamp(session.last_update_time, tz=timezone.utc),
            )

            db_session.add(db_session_model)
            db_session.commit()
        except SQLAlchemyError as e:
            db_session.rollback()
            raise RuntimeError(f"Failed to create session: {e}") from e
        finally:
            db_session.close()

        return session

    async def _get_session_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        config: Optional["GetSessionConfig"] = None,
    ) -> Optional["Session"]:
        """Load one session from the database.

        Args:
            app_name: Application the session belongs to.
            user_id: User the session belongs to.
            session_id: Id of the session to load.
            config: Optional event filters. ``num_recent_events`` keeps only
                the last N events; ``after_timestamp`` then keeps events whose
                timestamp is greater than or equal to the given value.

        Returns:
            The session, or None when no matching row exists.

        Raises:
            RuntimeError: If the service is not initialized or the query fails.
            ValueError: If the stored state or events cannot be deserialized.
        """
        db_session = self._get_db_session()
        try:
            db_session_model = db_session.query(SQLSessionModel).filter(
                SQLSessionModel.id == session_id,
                SQLSessionModel.app_name == app_name,
                SQLSessionModel.user_id == user_id,
            ).first()

            if not db_session_model:
                return None

            session = self._session_from_model(db_session_model, include_events=True)

            # Apply config filters if provided
            if config:
                if config.num_recent_events:
                    session.events = session.events[-config.num_recent_events:]
                if config.after_timestamp:
                    session.events = [
                        event for event in session.events
                        if event.timestamp >= config.after_timestamp
                    ]

            return session
        except SQLAlchemyError as e:
            raise RuntimeError(f"Failed to get session: {e}") from e
        finally:
            db_session.close()

    async def _list_sessions_impl(
        self,
        *,
        app_name: str,
        user_id: str
    ) -> "ListSessionsResponse":
        """List a user's sessions without their events.

        Raises:
            RuntimeError: If the service is not initialized or the query fails.
            ValueError: If a stored state cannot be deserialized.
        """
        from google.adk.sessions.base_session_service import ListSessionsResponse

        db_session = self._get_db_session()
        try:
            db_session_models = db_session.query(SQLSessionModel).filter(
                SQLSessionModel.app_name == app_name,
                SQLSessionModel.user_id == user_id,
            ).all()

            # Events are left out of listings.
            sessions = [
                self._session_from_model(db_model, include_events=False)
                for db_model in db_session_models
            ]

            return ListSessionsResponse(sessions=sessions)
        except SQLAlchemyError as e:
            raise RuntimeError(f"Failed to list sessions: {e}") from e
        finally:
            db_session.close()

    async def _delete_session_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str
    ) -> None:
        """Delete the matching session row, if any.

        Raises:
            RuntimeError: If the service is not initialized or the delete fails.
        """
        db_session = self._get_db_session()
        try:
            db_session.query(SQLSessionModel).filter(
                SQLSessionModel.id == session_id,
                SQLSessionModel.app_name == app_name,
                SQLSessionModel.user_id == user_id,
            ).delete()

            db_session.commit()
        except SQLAlchemyError as e:
            db_session.rollback()
            raise RuntimeError(f"Failed to delete session: {e}") from e
        finally:
            db_session.close()

    async def _append_event_impl(self, session: "Session", event: Event) -> None:
        """Persist the session's current events and apply the event's state delta.

        The full ``session.events`` list is re-serialized and written, along
        with ``session.last_update_time``. When the event carries a
        ``state_delta`` it is merged into the state stored in the row.

        Raises:
            ValueError: If the session row does not exist or the data cannot
                be (de)serialized.
            RuntimeError: If the service is not initialized or the update fails.
        """
        db_session = self._get_db_session()
        try:
            db_session_model = db_session.query(SQLSessionModel).filter(
                SQLSessionModel.id == session.id,
                SQLSessionModel.app_name == session.app_name,
                SQLSessionModel.user_id == session.user_id,
            ).first()

            if not db_session_model:
                raise ValueError(f"Session {session.id} not found")

            # Update the session model
            db_session_model.events = self._serialize_events(session.events)
            db_session_model.last_update_time = datetime.fromtimestamp(session.last_update_time, tz=timezone.utc)

            # Apply state changes from event if present
            if event.actions and event.actions.state_delta:
                current_state = self._deserialize_state(db_session_model.state)
                current_state.update(event.actions.state_delta)
                db_session_model.state = self._serialize_state(current_state)

            db_session.commit()
        except SQLAlchemyError as e:
            db_session.rollback()
            raise RuntimeError(f"Failed to append event: {e}") from e
        finally:
            db_session.close()