rakam-systems-agent 0.1.1rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,668 @@
1
+ """PostgreSQL-Based Chat History Manager.
2
+
3
+ This module provides a ChatHistoryComponent implementation that stores
4
+ chat history in a PostgreSQL database. Suitable for production deployments
5
+ requiring persistent, structured storage with multi-instance support.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ import os
12
+ from typing import Any, Dict, List, Optional
13
+ from psycopg2.pool import SimpleConnectionPool
14
+ from rakam_systems_core.ai_core.interfaces.chat_history import \
15
+ ChatHistoryComponent
16
+
17
# pydantic-ai is an optional dependency. Only the message-history helpers
# (get_message_history / save_messages) need it; the rest of the module works
# without it.
try:
    from pydantic_ai.messages import ModelMessage, ModelMessagesTypeAdapter
    from pydantic_core import to_jsonable_python
except ImportError:
    # Placeholders keep the module importable; code must check
    # PYDANTIC_AI_AVAILABLE before using any of these names.
    ModelMessage = None  # type: ignore
    ModelMessagesTypeAdapter = None  # type: ignore
    to_jsonable_python = None  # type: ignore
    PYDANTIC_AI_AVAILABLE = False
else:
    PYDANTIC_AI_AVAILABLE = True
27
+
28
+
29
class PostgresChatHistory(ChatHistoryComponent):
    """Chat history manager using PostgreSQL database storage.

    This implementation stores all chat histories in a PostgreSQL database.
    It's suitable for:
    - Production deployments
    - Multi-instance applications with concurrent access
    - Applications requiring structured queries and scalability
    - Large scale applications

    Storage layout (created by ``setup()`` when missing):
        <schema>.chat_sessions -- one row per chat_id with created_at/updated_at
        <schema>.chat_messages -- one JSONB row per message, ordered by
                                  message_order; rows are removed with their
                                  session via ON DELETE CASCADE

    Config options:
        host: PostgreSQL host (default: "localhost")
        port: PostgreSQL port (default: 5432)
        database: Database name (default: "vectorstore_db")
        user: Database user (default: "postgres")
        password: Database password (default: "postgres")
        schema: Schema name for chat tables (default: "public"). Must be a
            plain unquoted SQL identifier (letters, digits, "_" or "$", not
            starting with a digit) because it is spliced into SQL text.
        min_connections: Minimum pool connections (default: 1)
        max_connections: Maximum pool connections (default: 10)

    Connection parameter precedence (highest first):
        1. keyword arguments passed to ``__init__``
        2. values in ``config``
        3. environment variables POSTGRES_HOST, POSTGRES_PORT, POSTGRES_DB,
           POSTGRES_USER, POSTGRES_PASSWORD
        4. the defaults listed above
    Falsy values (e.g. "" or 0) are treated as "not set" and fall through.

    Example:
        >>> history = PostgresChatHistory(
        ...     config={
        ...         "host": "localhost",
        ...         "database": "vectorstore_db",
        ...         "user": "postgres",
        ...         "password": "postgres"
        ...     }
        ... )
        >>> history.setup()
        >>> history.add_message("chat123", {"role": "user", "content": "Hello"})
        >>> history.add_message("chat123", {"role": "assistant", "content": "Hi there!"})
        >>> messages = history.get_chat_history("chat123")
    """

    def __init__(
        self,
        name: str = "postgres_chat_history",
        config: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the PostgreSQL chat history manager.

        No database connection is opened here. Call ``setup()`` (data methods
        also call it lazily) to create the connection pool and the tables.

        Args:
            name: Component name for identification.
            config: Configuration dictionary. Supports:
                - host: PostgreSQL host
                - port: PostgreSQL port
                - database: Database name
                - user: Database user
                - password: Database password
                - schema: Schema name (default: "public")
                - min_connections: Min pool size (default: 1)
                - max_connections: Max pool size (default: 10)
            **kwargs: Direct parameter overrides (host, port, database, user,
                password, schema); these take precedence over ``config``.

        Raises:
            ValueError: If the resolved schema name is not a plain SQL identifier.
        """
        super().__init__(name, config)

        # Precedence: kwargs, then config, then POSTGRES_* environment
        # variables, then hard-coded defaults.
        self.host = (
            kwargs.get('host')
            or self.config.get('host')
            or os.getenv('POSTGRES_HOST', 'localhost')
        )
        self.port = (
            kwargs.get('port')
            or self.config.get('port')
            or int(os.getenv('POSTGRES_PORT', '5432'))
        )
        self.database = (
            kwargs.get('database')
            or self.config.get('database')
            or os.getenv('POSTGRES_DB', 'vectorstore_db')
        )
        self.user = (
            kwargs.get('user')
            or self.config.get('user')
            or os.getenv('POSTGRES_USER', 'postgres')
        )
        self.password = (
            kwargs.get('password')
            or self.config.get('password')
            or os.getenv('POSTGRES_PASSWORD', 'postgres')
        )
        # The schema cannot be a bound query parameter. It is interpolated into
        # every statement, so it is validated once here.
        self.schema = self._validate_schema_name(
            kwargs.get('schema') or self.config.get('schema') or 'public'
        )

        # Connection pool settings
        self.min_connections = self.config.get('min_connections', 1)
        self.max_connections = self.config.get('max_connections', 10)

        self._pool: Optional[SimpleConnectionPool] = None

    @staticmethod
    def _validate_schema_name(schema: Any) -> str:
        """Return ``schema`` unchanged if it is safe to splice into SQL text.

        Only plain unquoted identifiers are accepted. This blocks SQL
        injection through configuration and keeps PostgreSQL's normal
        case-folding for unquoted names, matching how the name was always used.

        Raises:
            ValueError: If ``schema`` is not a plain SQL identifier.
        """
        import re

        if not isinstance(schema, str) or not re.fullmatch(
            r"[A-Za-z_][A-Za-z0-9_$]*", schema
        ):
            raise ValueError(
                f"Invalid schema name {schema!r}: expected a plain SQL identifier "
                "(letters, digits, '_' or '$', not starting with a digit)."
            )
        return schema

    def setup(self) -> None:
        """Initialize the connection pool, create tables, and mark the component ready.

        An already-existing pool is reused instead of being replaced (and
        leaked). If table creation fails, a pool created by this call is
        closed again before the error propagates.
        """
        created_pool = self._pool is None
        if created_pool:
            self._initialize_pool()
        try:
            self._initialize_database()
        except Exception:
            if created_pool:
                self._close_pool()
            raise
        super().setup()

    def shutdown(self) -> None:
        """Cleanup resources and close connection pool."""
        self._close_pool()
        super().shutdown()

    def _close_pool(self) -> None:
        """Close every pooled connection and drop the pool (no-op without one)."""
        if self._pool is not None:
            self._pool.closeall()
            self._pool = None

    def _initialize_pool(self) -> None:
        """Initialize PostgreSQL connection pool.

        Raises:
            Exception: If connection pool initialization fails (the original
                error is chained as ``__cause__``).
        """
        try:
            self._pool = SimpleConnectionPool(
                self.min_connections,
                self.max_connections,
                host=self.host,
                port=self.port,
                database=self.database,
                user=self.user,
                password=self.password,
            )
        except Exception as e:
            raise Exception(
                f"Failed to initialize PostgreSQL connection pool: {e}") from e

    def _get_connection(self):
        """Get a connection from the pool.

        Returns:
            psycopg2 connection object.

        Raises:
            Exception: If pool is not initialized or connection fails.
        """
        if not self._pool:
            raise Exception(
                "Connection pool not initialized. Call setup() first.")
        return self._pool.getconn()

    def _return_connection(self, conn) -> None:
        """Return a connection to the pool, or close it if the pool is gone.

        Args:
            conn: Connection to return.
        """
        if self._pool:
            self._pool.putconn(conn)
        else:
            # The pool was shut down while this connection was checked out.
            # Close it so it is not leaked.
            conn.close()

    def _run_in_transaction(
        self, operation: Any, error_prefix: Optional[str] = None
    ) -> Any:
        """Run ``operation(cursor)`` in a single transaction on a pooled connection.

        The transaction is committed when ``operation`` returns and rolled back
        if anything raises. The cursor is always closed and the connection is
        always returned.

        Args:
            operation: Callable receiving a psycopg2 cursor; its return value
                is passed through.
            error_prefix: When given, failures are re-raised as
                ``Exception(f"{error_prefix}: {e}")`` chained to the original
                error. When ``None``, the original exception propagates as-is.

        Returns:
            Whatever ``operation`` returned.
        """
        conn = None
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                result = operation(cursor)
            conn.commit()
            return result
        except Exception as e:
            # Rolling back a dead connection would raise again and hide the
            # real error, so only roll back live connections.
            if conn is not None and not conn.closed:
                conn.rollback()
            if error_prefix is None:
                raise
            raise Exception(f"{error_prefix}: {e}") from e
        finally:
            if conn is not None:
                self._return_connection(conn)

    def _lock_session(self, cursor, chat_id: str) -> None:
        """Ensure the chat_sessions row exists and row-lock it for this transaction.

        Holding the row lock serializes concurrent writers of the same chat,
        across threads and processes. Without it, two add_message calls could
        both read the same MAX(message_order) and store duplicate orders, or
        interleave with a set_messages rewrite.
        """
        cursor.execute(
            f'INSERT INTO {self.schema}.chat_sessions (chat_id) VALUES (%s) ON CONFLICT (chat_id) DO NOTHING',
            (chat_id,),
        )
        cursor.execute(
            f'SELECT 1 FROM {self.schema}.chat_sessions WHERE chat_id = %s FOR UPDATE',
            (chat_id,),
        )

    def _initialize_database(self) -> None:
        """Initialize PostgreSQL database and create necessary tables.

        Creates the following objects if they do not already exist:
        - the schema (unless it is "public")
        - the chat_sessions and chat_messages tables and their indexes
        - a trigger that bumps chat_sessions.updated_at whenever a message is
          inserted

        Raises:
            Exception: If database initialization fails.
        """
        schema = self.schema

        def _create_objects(cursor) -> None:
            # Create schema if not exists (if not using public)
            if schema != 'public':
                cursor.execute(f'CREATE SCHEMA IF NOT EXISTS {schema}')

            cursor.execute(f'''
                CREATE TABLE IF NOT EXISTS {schema}.chat_sessions (
                    chat_id TEXT PRIMARY KEY,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            cursor.execute(f'''
                CREATE TABLE IF NOT EXISTS {schema}.chat_messages (
                    id SERIAL PRIMARY KEY,
                    chat_id TEXT NOT NULL,
                    message_order INTEGER NOT NULL,
                    message_data JSONB NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (chat_id) REFERENCES {schema}.chat_sessions (chat_id) ON DELETE CASCADE
                )
            ''')

            # Ordered per-chat reads and JSONB containment queries.
            cursor.execute(f'''
                CREATE INDEX IF NOT EXISTS idx_chat_messages_chat_id
                ON {schema}.chat_messages (chat_id, message_order)
            ''')
            cursor.execute(f'''
                CREATE INDEX IF NOT EXISTS idx_chat_messages_data
                ON {schema}.chat_messages USING GIN (message_data)
            ''')

            # Keep chat_sessions.updated_at current on every message insert;
            # get_all_chat_ids orders by it.
            cursor.execute(f'''
                CREATE OR REPLACE FUNCTION {schema}.update_chat_session_timestamp()
                RETURNS TRIGGER AS $$
                BEGIN
                    UPDATE {schema}.chat_sessions
                    SET updated_at = CURRENT_TIMESTAMP
                    WHERE chat_id = NEW.chat_id;
                    RETURN NEW;
                END;
                $$ LANGUAGE plpgsql;
            ''')
            cursor.execute(f'''
                DROP TRIGGER IF EXISTS trigger_update_chat_session
                ON {schema}.chat_messages
            ''')
            cursor.execute(f'''
                CREATE TRIGGER trigger_update_chat_session
                AFTER INSERT ON {schema}.chat_messages
                FOR EACH ROW
                EXECUTE FUNCTION {schema}.update_chat_session_timestamp()
            ''')

        self._run_in_transaction(_create_objects, "Failed to initialize database")

    def _ensure_initialized(self) -> None:
        """Ensure the component is initialized before operations."""
        if not self.initialized:
            self.setup()

    def chat_exists(self, chat_id: str) -> bool:
        """Check if a chat session exists.

        Args:
            chat_id: Unique identifier for the chat session.

        Returns:
            True if chat exists, False otherwise.
        """
        self._ensure_initialized()

        def _query(cursor) -> bool:
            cursor.execute(
                f'SELECT 1 FROM {self.schema}.chat_sessions WHERE chat_id = %s',
                (chat_id,),
            )
            return cursor.fetchone() is not None

        return self._run_in_transaction(_query)

    def add_message(self, chat_id: str, message: Dict[str, Any]) -> None:
        """Append a single message to a chat session, creating the session if needed.

        Args:
            chat_id: Unique identifier for the chat session.
            message: JSON-serializable message object (dict with role,
                content, timestamp, etc.).

        Raises:
            Exception: If serialization or any database step fails; the
                transaction is rolled back.
        """
        self._ensure_initialized()

        def _append(cursor) -> None:
            # Lock first, so the MAX() read below cannot race another writer.
            self._lock_session(cursor, chat_id)

            cursor.execute(
                f'SELECT COALESCE(MAX(message_order), -1) + 1 FROM {self.schema}.chat_messages WHERE chat_id = %s',
                (chat_id,),
            )
            next_order = cursor.fetchone()[0]

            message_json = json.dumps(message, ensure_ascii=False)
            cursor.execute(
                f'''
                INSERT INTO {self.schema}.chat_messages (chat_id, message_order, message_data)
                VALUES (%s, %s, %s::jsonb)
                ''',
                (chat_id, next_order, message_json),
            )

        self._run_in_transaction(_append, "Failed to add message")

    def set_messages(self, chat_id: str, messages: List[Dict[str, Any]]) -> None:
        """Set/replace all messages for a chat session atomically.

        Args:
            chat_id: Unique identifier for the chat session.
            messages: List of JSON-serializable message objects to store, in
                order. An empty list leaves the session with no messages.

        Raises:
            Exception: If serialization or any database step fails; the
                previous messages are kept (the transaction is rolled back).
        """
        self._ensure_initialized()

        def _replace(cursor) -> None:
            self._lock_session(cursor, chat_id)

            cursor.execute(
                f'DELETE FROM {self.schema}.chat_messages WHERE chat_id = %s',
                (chat_id,),
            )

            for order, message in enumerate(messages):
                message_json = json.dumps(message, ensure_ascii=False)
                cursor.execute(
                    f'''
                    INSERT INTO {self.schema}.chat_messages (chat_id, message_order, message_data)
                    VALUES (%s, %s, %s::jsonb)
                    ''',
                    (chat_id, order, message_json),
                )

        self._run_in_transaction(_replace, "Failed to set messages")

    def get_chat_history(self, chat_id: str) -> List[Dict[str, Any]]:
        """Retrieve all messages for a chat session in insertion order.

        Args:
            chat_id: Unique identifier for the chat session.

        Returns:
            List of message objects (psycopg2 decodes the JSONB column), or an
            empty list if the chat doesn't exist.
        """
        self._ensure_initialized()

        def _query(cursor) -> List[Dict[str, Any]]:
            cursor.execute(
                f'''
                SELECT message_data
                FROM {self.schema}.chat_messages
                WHERE chat_id = %s
                ORDER BY message_order ASC
                ''',
                (chat_id,),
            )
            return [row[0] for row in cursor.fetchall()]

        return self._run_in_transaction(_query)

    def get_all_chat_ids(self) -> List[str]:
        """Get all chat IDs currently stored, most recently updated first.

        Returns:
            List of all chat session identifiers.
        """
        self._ensure_initialized()

        def _query(cursor) -> List[str]:
            cursor.execute(
                f'SELECT chat_id FROM {self.schema}.chat_sessions ORDER BY updated_at DESC')
            return [row[0] for row in cursor.fetchall()]

        return self._run_in_transaction(_query)

    def delete_chat_history(self, chat_id: str) -> bool:
        """Delete a chat session and all of its messages.

        Args:
            chat_id: Unique identifier for the chat session to delete.

        Returns:
            True if the session existed and was deleted, False otherwise.

        Raises:
            Exception: If the deletion fails; the transaction is rolled back.
        """
        self._ensure_initialized()

        def _delete(cursor) -> bool:
            # Messages are removed explicitly rather than relying only on the
            # ON DELETE CASCADE foreign key.
            cursor.execute(
                f'DELETE FROM {self.schema}.chat_messages WHERE chat_id = %s',
                (chat_id,),
            )
            cursor.execute(
                f'DELETE FROM {self.schema}.chat_sessions WHERE chat_id = %s',
                (chat_id,),
            )
            # A single statement both checks existence and deletes, so there is
            # no gap between the check and the delete.
            return cursor.rowcount > 0

        return self._run_in_transaction(_delete, "Failed to delete chat history")

    def clear_all(self) -> None:
        """Delete all chat histories in the configured schema."""
        self._ensure_initialized()

        def _clear(cursor) -> None:
            cursor.execute(f'DELETE FROM {self.schema}.chat_messages')
            cursor.execute(f'DELETE FROM {self.schema}.chat_sessions')

        self._run_in_transaction(_clear, "Failed to clear all chats")

    def get_readable_chat_history(
        self,
        chat_id: str,
        user_role: str = "user",
        assistant_role: str = "assistant",
    ) -> List[Dict[str, Any]]:
        """Get chat history in a human-readable format.

        This method transforms the raw message format into a display-friendly
        format with 'from', 'message', and optional 'timestamp' keys. Entries
        whose 'role' matches neither ``user_role`` nor ``assistant_role`` are
        skipped (e.g. system messages), as are non-object JSON entries.
        NOTE(review): messages stored via save_messages use pydantic-ai's own
        schema and presumably have no top-level 'role', so they would all be
        skipped here. Confirm before relying on this for pydantic-ai chats.

        Args:
            chat_id: Unique identifier for the chat session.
            user_role: The role name for user messages (default: "user").
            assistant_role: The role name for assistant messages (default: "assistant").

        Returns:
            List of formatted message dictionaries with:
            - 'from': "user" or "assistant"
            - 'message': The message content ("" if missing)
            - 'timestamp': Message timestamp (only if present and truthy)
        """
        self._ensure_initialized()

        readable_messages: List[Dict[str, Any]] = []
        for msg in self.get_chat_history(chat_id):
            if not isinstance(msg, dict):
                # A bare JSON string/list/number cannot carry a role.
                continue

            role = msg.get("role", "")
            if role == user_role:
                from_field = "user"
            elif role == assistant_role:
                from_field = "assistant"
            else:
                # Skip system messages or unknown roles
                continue

            formatted: Dict[str, Any] = {
                "from": from_field,
                "message": msg.get("content", ""),
            }
            timestamp = msg.get("timestamp")
            if timestamp:
                formatted["timestamp"] = timestamp

            readable_messages.append(formatted)

        return readable_messages

    # ==================== Pydantic-AI Integration ====================

    def get_message_history(self, chat_id: str) -> Optional[List[Any]]:
        """Get chat history in pydantic-ai compatible format.

        This method converts the stored database history to pydantic-ai's
        ModelMessage format, ready to be passed to agent.run() or
        agent.run_stream() as message_history.

        Args:
            chat_id: Unique identifier for the chat session.

        Returns:
            List of ModelMessage objects for pydantic-ai, or None if the chat
            doesn't exist or is empty.

        Raises:
            ImportError: If pydantic-ai is not installed.

        Example:
            >>> history = PostgresChatHistory()
            >>> message_history = history.get_message_history("chat123")
            >>> result = await agent.run("Hello", message_history=message_history)
        """
        if not PYDANTIC_AI_AVAILABLE:
            raise ImportError(
                "pydantic-ai is not installed. Install it with: pip install pydantic-ai"
            )

        self._ensure_initialized()
        raw_history = self.get_chat_history(chat_id)

        if not raw_history:
            return None

        return ModelMessagesTypeAdapter.validate_python(raw_history)

    def save_messages(self, chat_id: str, messages: List[Any]) -> None:
        """Save pydantic-ai messages to history, replacing what was stored.

        This method converts pydantic-ai's ModelMessage objects to JSON
        and stores them via set_messages. Typically called with
        result.all_messages() after an agent run.

        Args:
            chat_id: Unique identifier for the chat session.
            messages: List of pydantic-ai ModelMessage objects
                (e.g., from result.all_messages()).

        Raises:
            ImportError: If pydantic-ai is not installed.

        Example:
            >>> result = await agent.run("Hello", message_history=history.get_message_history("chat123"))
            >>> history.save_messages("chat123", result.all_messages())
        """
        if not PYDANTIC_AI_AVAILABLE:
            raise ImportError(
                "pydantic-ai is not installed. Install it with: pip install pydantic-ai"
            )

        self._ensure_initialized()

        # Convert pydantic-ai messages to JSON-serializable format
        json_messages = to_jsonable_python(messages)
        self.set_messages(chat_id, json_messages)
605
+
606
+
607
if __name__ == "__main__":
    # Manual smoke test; needs a reachable PostgreSQL server.
    # WARNING: the final clear_all() deletes EVERY chat in the configured
    # schema, so only run this against a disposable database.
    print("PostgresChatHistory Example")
    print("=" * 50)

    # The explicit config values below take precedence over the POSTGRES_*
    # environment variables; remove a key to fall back to the environment.
    history = PostgresChatHistory(
        config={
            "host": "localhost",
            "database": "vectorstore_db",
            "user": "postgres",
            "password": "postgres",
        }
    )

    try:
        history.setup()
        print("✓ Connected to PostgreSQL")

        # Add messages
        history.add_message("chat123", {
            "role": "user",
            "content": "Hello!",
            "timestamp": "2025-12-16 10:00:00",
        })
        history.add_message("chat123", {
            "role": "assistant",
            "content": "Hi there! How can I help?",
            "timestamp": "2025-12-16 10:00:05",
        })

        # Retrieve history
        print("\nChat history:", history.get_chat_history("chat123"))
        print("\nAll chat IDs:", history.get_all_chat_ids())
        print("\nReadable format:", history.get_readable_chat_history("chat123"))

        # Test set_messages
        history.set_messages("chat456", [
            {"role": "user", "content": "Test message",
             "timestamp": "2025-12-16 10:02:00"}
        ])
        print("\nChat history for chat456:",
              history.get_chat_history("chat456"))
        print("All chat IDs after adding chat456:", history.get_all_chat_ids())

        # Clean up
        deleted = history.delete_chat_history("chat123")
        print(f"\nDeleted chat123: {deleted}")
        print("After deletion:", history.get_all_chat_ids())

        # Clear all (destructive, see warning above)
        history.clear_all()
        print("After clear_all:", history.get_all_chat_ids())

        print("\n✓ All tests passed!")

    except Exception as e:
        print(f"\n❌ Error: {e}")
        print("\nMake sure PostgreSQL is running:")
        print(" docker-compose up -d vectordb")
        # Report failure to the shell/CI instead of exiting with status 0;
        # the finally clause still shuts the pool down.
        raise SystemExit(1)
    finally:
        history.shutdown()