tinyagent-py 0.0.1__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
storage/base.py ADDED
@@ -0,0 +1,49 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, Any, TYPE_CHECKING, Optional
3
+
4
+ if TYPE_CHECKING:
5
+ from tinyagent.tiny_agent import TinyAgent
6
+
7
class Storage(ABC):
    """
    Abstract base class for TinyAgent session storage.

    Concrete backends persist a session's state dict keyed by ``session_id``
    (optionally scoped by ``user_id``) and return it again from
    ``load_session``. All I/O methods are coroutines.
    """

    @abstractmethod
    async def save_session(self, session_id: str, data: Dict[str, Any], user_id: Optional[str] = None) -> None:
        """
        Persist the given agent state under `session_id` (scoped by `user_id` when given).
        """
        ...

    @abstractmethod
    async def load_session(self, session_id: str, user_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Retrieve the agent state for `session_id`, or return {} if not found.
        """
        ...

    @abstractmethod
    async def close(self) -> None:
        """
        Clean up any resources (DB connections, file handles, etc.).
        """
        ...

    def attach(self, agent: "TinyAgent") -> None:
        """
        Hook this storage to a TinyAgent so that on every `llm_end`
        it will auto-persist the agent's state.

        The callback appended to ``agent.callbacks`` is invoked as
        ``callback(event_name, agent, **kwargs)``. It ignores every event except
        ``"llm_end"`` and then saves ``agent.to_dict()`` under
        ``agent.session_id``. If the agent exposes a ``user_id`` attribute, that
        value is forwarded as well, so a later ``load_session(session_id, user_id)``
        finds the auto-saved state.

        Usage:
            storage.attach(agent)
        or in TinyAgent.__init__:
            if storage: storage.attach(self)
        """
        async def _auto_save(event_name: str, agent: "TinyAgent", **kwargs) -> None:
            if event_name != "llm_end":
                return
            state = agent.to_dict()
            # Backends key/filter sessions by user_id (file name, Redis key, SQL
            # column). Without forwarding it, auto-saved state would always be
            # stored under user_id=None and be invisible to per-user loads.
            await self.save_session(
                agent.session_id,
                state,
                user_id=getattr(agent, "user_id", None),
            )

        agent.callbacks.append(_auto_save)
@@ -0,0 +1,30 @@
1
+ import json
2
+ import asyncio
3
+ from pathlib import Path
4
+ from typing import Dict, Any, Union, Optional
5
+ from tinyagent.storage import Storage
6
+
7
class JsonFileStorage(Storage):
    """
    Persist TinyAgent sessions as individual JSON files.

    Each session lives in ``<folder>/<session_id>_<user_id>.json``. When
    ``user_id`` is ``None``, the literal text ``None`` appears in the file
    name, so a session saved without a user id must also be loaded without one.

    NOTE(review): ``session_id``/``user_id`` are used verbatim in the file name.
    If they can come from untrusted input, validate them upstream, because
    path separators or ``..`` would escape ``folder``.
    """

    def __init__(self, folder: Union[str, Path]):
        self.folder = Path(folder)
        self.folder.mkdir(parents=True, exist_ok=True)

    def _path(self, session_id: str, user_id: Optional[str]) -> Path:
        """Return the JSON file path for a (session_id, user_id) pair."""
        return self.folder / f"{session_id}_{user_id}.json"

    @staticmethod
    def _atomic_write(path: Path, text: str) -> None:
        """Write *text* to *path* via a sibling temp file + rename.

        Readers then see either the old or the new content, never a truncated
        file, even if the process dies mid-write.
        """
        import uuid  # local import: only needed for the unique temp-file name

        tmp = path.with_name(f".{path.name}.{uuid.uuid4().hex}.tmp")
        try:
            tmp.write_text(text, "utf-8")
            tmp.replace(path)  # same directory, so this is an atomic rename
        except BaseException:
            tmp.unlink(missing_ok=True)
            raise

    async def save_session(self, session_id: str, data: Dict[str, Any], user_id: Optional[str] = None) -> None:
        """Serialize *data* as indented JSON and store it for the session/user pair."""
        path = self._path(session_id, user_id)
        # Serialize on the loop thread so a TypeError surfaces before any disk I/O.
        payload = json.dumps(data, indent=2)
        # Write in a thread pool to avoid blocking the event loop
        await asyncio.to_thread(self._atomic_write, path, payload)

    async def load_session(self, session_id: str, user_id: Optional[str] = None) -> Dict[str, Any]:
        """Return the stored state for the session/user pair, or {} if no file exists."""
        path = self._path(session_id, user_id)
        try:
            # EAFP: avoids an exists()/read race and a blocking stat on the loop thread.
            text = await asyncio.to_thread(path.read_text, "utf-8")
        except FileNotFoundError:
            return {}
        return json.loads(text)

    async def close(self) -> None:
        # Nothing to clean up for file storage
        return
@@ -0,0 +1,201 @@
1
+ import asyncpg
2
+ import json
3
+ import logging
4
+ from typing import Optional, Dict, Any
5
+ from tinyagent.storage import Storage
6
+
7
class PostgresStorage(Storage):
    """
    Persist TinyAgent sessions in a Postgres table with JSONB state.

    One row per ``agent_id`` (taken from ``data["metadata"]["agent_id"]``,
    falling back to ``session_id``). On save the session dict is split into the
    ``memories``, ``metadata``, ``session_data`` and ``model_meta`` JSONB
    columns. On load it is reassembled into the same
    ``{"session_id", "metadata", "session_state"}`` shape. The connection pool
    and the table/indexes are created lazily on first use.
    """

    def __init__(self, db_url: str, table_name: str = "tny_agent_sessions"):
        self._dsn = db_url
        # NOTE: interpolated verbatim into the SQL below; must come from trusted configuration.
        self._table = table_name
        self._pool: Optional[asyncpg.pool.Pool] = None
        self.logger = logging.getLogger(__name__)

    async def _ensure_table(self):
        """Create the sessions table and its lookup indexes if they don't exist."""
        self.logger.debug("Ensuring table %s exists", self._table)
        try:
            async with self._pool.acquire() as conn:
                # No query arguments, so asyncpg may run this multi-statement script in one call.
                await conn.execute(f"""
                    CREATE TABLE IF NOT EXISTS {self._table} (
                        agent_id TEXT PRIMARY KEY,
                        session_id TEXT NOT NULL,
                        user_id TEXT,
                        memories JSONB,
                        metadata JSONB,
                        session_data JSONB,
                        model_meta JSONB,
                        created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                        updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
                    );
                    CREATE INDEX IF NOT EXISTS idx_{self._table}_session_id ON {self._table} (session_id);
                    CREATE INDEX IF NOT EXISTS idx_{self._table}_user_id ON {self._table} (user_id);
                """)
            self.logger.info(f"Table {self._table} and indexes created/verified")
        except Exception as e:
            self.logger.error(f"Error creating table {self._table}: {str(e)}")
            raise

    async def _connect(self):
        """Create the pool and ensure the table exists; a no-op once connected."""
        if self._pool is not None:
            return
        self.logger.debug("Connecting to PostgreSQL with DSN: %s...", self._dsn[:10])
        try:
            # Ensure statement_cache_size=0 to disable prepared statements for pgbouncer compatibility
            self._pool = await asyncpg.create_pool(
                dsn=self._dsn,
                statement_cache_size=0,
                min_size=1,
                max_size=10,
            )
            self.logger.info("PostgreSQL connection pool created")
            await self._ensure_table()
        except Exception as e:
            self.logger.error(f"Failed to connect to PostgreSQL: {str(e)}")
            # Do not keep a half-initialised pool. Otherwise the next call would
            # see a pool, skip _ensure_table and never create the table.
            if self._pool is not None:
                pool, self._pool = self._pool, None
                pool.terminate()
            raise

    @staticmethod
    def _decode_json(raw: Any) -> Any:
        """Decode a JSONB column value. SQL NULL and JSON ``null`` both become {}.

        asyncpg returns jsonb as ``str`` unless a type codec is registered, so
        strings are parsed. Already-decoded values pass through unchanged.
        """
        if raw is None:
            return {}
        if isinstance(raw, (str, bytes)):
            raw = json.loads(raw)
        return {} if raw is None else raw

    async def save_session(self, session_id: str, data: Dict[str, Any], user_id: Optional[str] = None):
        """
        Upsert the row for this agent's session.

        Args:
            session_id: Session identifier; also used as ``agent_id`` when the
                metadata carries none.
            data: Agent state with optional ``metadata`` and ``session_state`` dicts.
            user_id: Owner stored in the ``user_id`` column (may be ``None``).

        Raises:
            Any error from ``json.dumps`` or asyncpg; it is logged and re-raised.
        """
        self.logger.info(f"Saving session {session_id} for user {user_id}")
        if self.logger.isEnabledFor(logging.DEBUG):
            # Only pay for a full serialization when debug logging is actually on.
            self.logger.debug("Save data: %s...", json.dumps(data, default=str)[:200])

        try:
            await self._connect()

            # Extract data following the TinyAgent schema
            metadata = data.get("metadata") or {}
            session_state = data.get("session_state") or {}

            # `or` rather than a .get default: an explicit None/"" agent_id would
            # otherwise be written as a NULL/empty primary key.
            agent_id = metadata.get("agent_id") or session_id
            self.logger.debug("Using agent_id: %s", agent_id)

            memories_json = json.dumps(session_state.get("memory", {}))
            metadata_json = json.dumps(metadata)
            session_data_json = json.dumps({"messages": session_state.get("messages", [])})
            model_meta_json = json.dumps(metadata.get("model_meta", {}))

            async with self._pool.acquire() as conn:
                await conn.execute(f"""
                    INSERT INTO {self._table}
                    (agent_id, session_id, user_id, memories, metadata, session_data, model_meta, updated_at)
                    VALUES ($1, $2, $3, $4::jsonb, $5::jsonb, $6::jsonb, $7::jsonb, NOW())
                    ON CONFLICT (agent_id) DO UPDATE
                    SET session_id = EXCLUDED.session_id,
                        user_id = EXCLUDED.user_id,
                        memories = EXCLUDED.memories,
                        metadata = EXCLUDED.metadata,
                        session_data = EXCLUDED.session_data,
                        model_meta = EXCLUDED.model_meta,
                        updated_at = NOW();
                """, agent_id, session_id, user_id, memories_json, metadata_json, session_data_json, model_meta_json)
            self.logger.info(f"Session {session_id} saved successfully")
        except Exception as e:
            self.logger.error(f"Failed to save session {session_id}: {str(e)}")
            raise

    async def load_session(self, session_id: str, user_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Load the most recently updated row for `session_id` (and `user_id`, if given).

        Returns:
            ``{"session_id", "metadata", "session_state"}``, or {} when no row matches.
        """
        self.logger.info(f"Loading session {session_id} for user {user_id}")

        try:
            await self._connect()

            query = f"""
                SELECT agent_id, session_id, user_id, memories, metadata, session_data, model_meta
                FROM {self._table}
                WHERE session_id = $1
            """
            params = [session_id]

            # Add user_id filter if provided
            if user_id:
                query += " AND user_id = $2"
                params.append(user_id)

            # Several agents can share a session_id (agent_id is the key), so pick
            # the newest row deterministically instead of an arbitrary one.
            query += " ORDER BY updated_at DESC LIMIT 1"

            self.logger.debug("Executing query: %s with params: %s", query, params)
            async with self._pool.acquire() as conn:
                row = await conn.fetchrow(query, *params)

            if not row:
                self.logger.warning(f"No session found for session_id={session_id}, user_id={user_id}")
                return {}

            if self.logger.isEnabledFor(logging.DEBUG):
                self.logger.debug("Session found: %s", dict(row))

            memories = self._decode_json(row["memories"])
            metadata = self._decode_json(row["metadata"])
            session_data = self._decode_json(row["session_data"])
            model_meta = self._decode_json(row["model_meta"])

            # Update metadata with additional fields
            metadata.update({
                "agent_id": row["agent_id"],
                "user_id": row["user_id"],
                "model_meta": model_meta,
            })

            result = {
                "session_id": row["session_id"],
                "metadata": metadata,
                "session_state": {
                    "messages": session_data.get("messages", []),
                    "memory": memories,
                },
            }

            self.logger.info(f"Session {session_id} loaded successfully")
            if self.logger.isEnabledFor(logging.DEBUG):
                self.logger.debug("Loaded data: %s...", json.dumps(result, default=str)[:200])
            return result
        except Exception as e:
            self.logger.error(f"Failed to load session {session_id}: {str(e)}")
            raise

    async def close(self):
        """Close the connection pool if one was created; errors are logged, not raised."""
        if self._pool is None:
            return
        self.logger.info("Closing PostgreSQL connection pool")
        # Detach first so a failed close doesn't leave a dead pool for later calls.
        pool, self._pool = self._pool, None
        try:
            await pool.close()
            self.logger.debug("PostgreSQL connection pool closed")
        except Exception as e:
            self.logger.error(f"Error closing PostgreSQL connection: {str(e)}")
@@ -0,0 +1,48 @@
1
+ import json
2
+ from typing import Dict, Any, Optional
3
+ import redis.asyncio as aioredis
4
+ from tinyagent.storage import Storage
5
+
6
class RedisStorage(Storage):
    """
    Persist TinyAgent sessions in Redis. Optionally expire them after `ttl` seconds.

    Each session is stored as a JSON string under the key
    ``"<session_id>_<user_id>"``. The literal text ``None`` is used when
    ``user_id`` is ``None``.
    """

    def __init__(self, url: str = "redis://localhost", ttl: Optional[int] = None):
        """
        :param url: Redis connection URL, e.g. "redis://localhost:6379/0"
        :param ttl: time-to-live in seconds (None => no expiry)
        """
        self.url = url
        self.ttl = ttl
        self._client: Optional[aioredis.Redis] = None

    @staticmethod
    def _key(session_id: str, user_id: Optional[str]) -> str:
        """Build the Redis key for a (session_id, user_id) pair."""
        return f"{session_id}_{user_id}"

    async def _connect(self):
        """Lazily create the asyncio Redis client on first use."""
        if self._client is None:
            # from_url returns an asyncio-enabled Redis client
            self._client = aioredis.from_url(self.url)

    async def save_session(self, session_id: str, data: Dict[str, Any], user_id: Optional[str] = None) -> None:
        """Store *data* as JSON, expiring after ``self.ttl`` seconds when set."""
        await self._connect()
        payload = json.dumps(data)
        # redis-py only sends EX when ex is not None, so one call covers both
        # the expiring and the persistent case.
        await self._client.set(self._key(session_id, user_id), payload, ex=self.ttl)

    async def load_session(self, session_id: str, user_id: Optional[str] = None) -> Dict[str, Any]:
        """Return the stored state for the session/user pair, or {} if absent."""
        await self._connect()
        raw = await self._client.get(self._key(session_id, user_id))
        if not raw:
            return {}
        # raw may be bytes or str depending on the client's decode_responses setting
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8")
        return json.loads(raw)

    async def close(self) -> None:
        """Close the Redis client if one was created."""
        if self._client is None:
            return
        client, self._client = self._client, None
        # redis-py >= 5.0.1 deprecates close() in favour of aclose(); fall back
        # to close() on older versions that lack aclose().
        closer = getattr(client, "aclose", None) or client.close
        await closer()
@@ -0,0 +1,156 @@
1
+ import aiosqlite
2
+ import json
3
+ import os
4
+ from typing import Optional, Dict, Any
5
+ from tinyagent.storage import Storage
6
+
7
class SqliteStorage(Storage):
    """
    Persist TinyAgent sessions in a SQLite database with JSON state.

    One row per ``agent_id`` (taken from ``data["metadata"]["agent_id"]``,
    falling back to ``session_id``), with JSON-encoded text in the
    ``memories``/``metadata``/``session_data``/``model_meta`` columns. The
    connection and table are created lazily. The instance can also be used as
    an ``async with`` context manager.
    """

    def __init__(self, db_path: str, table_name: str = "tny_agent_sessions"):
        self._db_path = db_path
        # NOTE: interpolated verbatim into the SQL below; must come from trusted configuration.
        self._table = table_name
        self._conn: Optional[aiosqlite.Connection] = None

        # Ensure the directory exists
        os.makedirs(os.path.dirname(os.path.abspath(db_path)), exist_ok=True)

    async def _ensure_table(self):
        """Create the sessions table and its lookup indexes if they don't exist."""
        await self._conn.execute(f"""
            CREATE TABLE IF NOT EXISTS {self._table} (
                agent_id TEXT PRIMARY KEY,
                session_id TEXT NOT NULL,
                user_id TEXT,
                memories TEXT,
                metadata TEXT,
                session_data TEXT,
                model_meta TEXT,
                created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
            );
        """)

        # Create indexes
        await self._conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{self._table}_session_id ON {self._table} (session_id);")
        await self._conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{self._table}_user_id ON {self._table} (user_id);")
        await self._conn.commit()

    async def _connect(self):
        """Open the connection and ensure the table exists; a no-op once connected."""
        if self._conn is not None:
            return
        conn = await aiosqlite.connect(self._db_path)
        conn.row_factory = aiosqlite.Row
        self._conn = conn
        try:
            await self._ensure_table()
        except Exception:
            # Do not keep a connection whose table was never created. Otherwise
            # later calls would skip _ensure_table and fail on a missing table.
            self._conn = None
            await conn.close()
            raise

    @staticmethod
    def _decode_json(raw: Optional[str]) -> Any:
        """Decode a JSON text column. NULL, empty text and JSON ``null`` all become {}."""
        if not raw:
            return {}
        value = json.loads(raw)
        return {} if value is None else value

    async def save_session(self, session_id: str, data: Dict[str, Any], user_id: Optional[str] = None):
        """
        Upsert the row for this agent's session and commit.

        Args:
            session_id: Session identifier; also used as ``agent_id`` when the
                metadata carries none.
            data: Agent state with optional ``metadata`` and ``session_state`` dicts.
            user_id: Owner stored in the ``user_id`` column (may be ``None``).
        """
        await self._connect()

        # Extract data following the TinyAgent schema
        metadata = data.get("metadata") or {}
        session_state = data.get("session_state") or {}

        # `or` rather than a .get default: an explicit None/"" agent_id falls back too.
        agent_id = metadata.get("agent_id") or session_id

        params = (
            agent_id,
            session_id,
            user_id,
            json.dumps(session_state.get("memory", {})),
            json.dumps(metadata),
            json.dumps({"messages": session_state.get("messages", [])}),
            json.dumps(metadata.get("model_meta", {})),
        )

        # A single UPSERT statement (SQLite >= 3.24) replaces the earlier
        # SELECT-then-INSERT/UPDATE. That pattern could interleave across awaits
        # when two saves for the same agent run concurrently, and then fail with
        # an IntegrityError on the second INSERT.
        await self._conn.execute(f"""
            INSERT INTO {self._table}
            (agent_id, session_id, user_id, memories, metadata, session_data, model_meta)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(agent_id) DO UPDATE SET
                session_id = excluded.session_id,
                user_id = excluded.user_id,
                memories = excluded.memories,
                metadata = excluded.metadata,
                session_data = excluded.session_data,
                model_meta = excluded.model_meta,
                updated_at = CURRENT_TIMESTAMP
        """, params)

        await self._conn.commit()

    async def load_session(self, session_id: str, user_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Load the most recently updated row for `session_id` (and `user_id`, if given).

        Returns:
            ``{"session_id", "metadata", "session_state"}``, or {} when no row matches.
        """
        await self._connect()

        query = f"""
            SELECT agent_id, session_id, user_id, memories, metadata, session_data, model_meta
            FROM {self._table}
            WHERE session_id = ?
        """
        params = [session_id]

        # Add user_id filter if provided
        if user_id:
            query += " AND user_id = ?"
            params.append(user_id)

        # Several agents can share a session_id (agent_id is the key), so pick
        # the newest row deterministically instead of an arbitrary one.
        query += " ORDER BY updated_at DESC LIMIT 1"

        async with self._conn.execute(query, params) as cursor:
            row = await cursor.fetchone()

        if not row:
            return {}

        memories = self._decode_json(row["memories"])
        metadata = self._decode_json(row["metadata"])
        session_data = self._decode_json(row["session_data"])
        model_meta = self._decode_json(row["model_meta"])

        # Update metadata with additional fields
        metadata.update({
            "agent_id": row["agent_id"],
            "user_id": row["user_id"],
            "model_meta": model_meta,
        })

        return {
            "session_id": row["session_id"],
            "metadata": metadata,
            "session_state": {
                "messages": session_data.get("messages", []),
                "memory": memories,
            },
        }

    async def close(self):
        """Close the database connection if one is open."""
        if self._conn is not None:
            await self._conn.close()
            self._conn = None

    async def __aenter__(self):
        await self._connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()