agno 2.3.1__py3-none-any.whl → 2.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75) hide show
  1. agno/agent/agent.py +514 -186
  2. agno/compression/__init__.py +3 -0
  3. agno/compression/manager.py +176 -0
  4. agno/db/dynamo/dynamo.py +11 -0
  5. agno/db/firestore/firestore.py +5 -1
  6. agno/db/gcs_json/gcs_json_db.py +5 -2
  7. agno/db/in_memory/in_memory_db.py +5 -2
  8. agno/db/json/json_db.py +5 -1
  9. agno/db/migrations/manager.py +4 -4
  10. agno/db/mongo/async_mongo.py +158 -34
  11. agno/db/mongo/mongo.py +6 -2
  12. agno/db/mysql/mysql.py +48 -54
  13. agno/db/postgres/async_postgres.py +61 -51
  14. agno/db/postgres/postgres.py +42 -50
  15. agno/db/redis/redis.py +5 -0
  16. agno/db/redis/utils.py +5 -5
  17. agno/db/schemas/memory.py +7 -5
  18. agno/db/singlestore/singlestore.py +99 -108
  19. agno/db/sqlite/async_sqlite.py +32 -30
  20. agno/db/sqlite/sqlite.py +34 -30
  21. agno/knowledge/reader/pdf_reader.py +2 -2
  22. agno/knowledge/reader/tavily_reader.py +0 -1
  23. agno/memory/__init__.py +14 -1
  24. agno/memory/manager.py +223 -8
  25. agno/memory/strategies/__init__.py +15 -0
  26. agno/memory/strategies/base.py +67 -0
  27. agno/memory/strategies/summarize.py +196 -0
  28. agno/memory/strategies/types.py +37 -0
  29. agno/models/anthropic/claude.py +84 -80
  30. agno/models/aws/bedrock.py +38 -16
  31. agno/models/aws/claude.py +97 -277
  32. agno/models/azure/ai_foundry.py +8 -4
  33. agno/models/base.py +101 -14
  34. agno/models/cerebras/cerebras.py +18 -7
  35. agno/models/cerebras/cerebras_openai.py +4 -2
  36. agno/models/cohere/chat.py +8 -4
  37. agno/models/google/gemini.py +578 -20
  38. agno/models/groq/groq.py +18 -5
  39. agno/models/huggingface/huggingface.py +17 -6
  40. agno/models/ibm/watsonx.py +16 -6
  41. agno/models/litellm/chat.py +17 -7
  42. agno/models/message.py +19 -5
  43. agno/models/meta/llama.py +20 -4
  44. agno/models/mistral/mistral.py +8 -4
  45. agno/models/ollama/chat.py +17 -6
  46. agno/models/openai/chat.py +17 -6
  47. agno/models/openai/responses.py +23 -9
  48. agno/models/vertexai/claude.py +99 -5
  49. agno/os/interfaces/agui/router.py +1 -0
  50. agno/os/interfaces/agui/utils.py +97 -57
  51. agno/os/router.py +16 -1
  52. agno/os/routers/memory/memory.py +146 -0
  53. agno/os/routers/memory/schemas.py +26 -0
  54. agno/os/schema.py +21 -6
  55. agno/os/utils.py +134 -10
  56. agno/run/base.py +2 -1
  57. agno/run/workflow.py +1 -1
  58. agno/team/team.py +571 -225
  59. agno/tools/mcp/mcp.py +1 -1
  60. agno/utils/agent.py +119 -1
  61. agno/utils/dttm.py +33 -0
  62. agno/utils/models/ai_foundry.py +9 -2
  63. agno/utils/models/claude.py +12 -5
  64. agno/utils/models/cohere.py +9 -2
  65. agno/utils/models/llama.py +9 -2
  66. agno/utils/models/mistral.py +4 -2
  67. agno/utils/print_response/agent.py +37 -2
  68. agno/utils/print_response/team.py +52 -0
  69. agno/utils/tokens.py +41 -0
  70. agno/workflow/types.py +2 -2
  71. {agno-2.3.1.dist-info → agno-2.3.3.dist-info}/METADATA +45 -40
  72. {agno-2.3.1.dist-info → agno-2.3.3.dist-info}/RECORD +75 -68
  73. {agno-2.3.1.dist-info → agno-2.3.3.dist-info}/WHEEL +0 -0
  74. {agno-2.3.1.dist-info → agno-2.3.3.dist-info}/licenses/LICENSE +0 -0
  75. {agno-2.3.1.dist-info → agno-2.3.3.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,7 @@
1
+ import asyncio
1
2
  import time
2
3
  from datetime import date, datetime, timedelta, timezone
3
- from typing import Any, Dict, List, Optional, Tuple, Union
4
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
4
5
  from uuid import uuid4
5
6
 
6
7
  from agno.db.base import AsyncBaseDb, SessionType
@@ -25,11 +26,26 @@ from agno.utils.log import log_debug, log_error, log_info
25
26
  from agno.utils.string import generate_id
26
27
 
27
28
  try:
28
- import asyncio
29
-
30
29
  from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorCollection, AsyncIOMotorDatabase # type: ignore
30
+
31
+ MOTOR_AVAILABLE = True
32
+ except ImportError:
33
+ MOTOR_AVAILABLE = False
34
+ AsyncIOMotorClient = None # type: ignore
35
+ AsyncIOMotorCollection = None # type: ignore
36
+ AsyncIOMotorDatabase = None # type: ignore
37
+
38
+ try:
39
+ from pymongo import AsyncMongoClient # type: ignore
40
+ from pymongo.collection import AsyncCollection # type: ignore
41
+ from pymongo.database import AsyncDatabase # type: ignore
42
+
43
+ PYMONGO_ASYNC_AVAILABLE = True
31
44
  except ImportError:
32
- raise ImportError("`motor` not installed. Please install it using `pip install -U motor`")
45
+ PYMONGO_ASYNC_AVAILABLE = False
46
+ AsyncMongoClient = None # type: ignore
47
+ AsyncDatabase = None # type: ignore
48
+ AsyncCollection = None # type: ignore
33
49
 
34
50
  try:
35
51
  from pymongo import ReturnDocument
@@ -37,11 +53,89 @@ try:
37
53
  except ImportError:
38
54
  raise ImportError("`pymongo` not installed. Please install it using `pip install -U pymongo`")
39
55
 
56
+ # Ensure at least one async library is available
57
+ if not MOTOR_AVAILABLE and not PYMONGO_ASYNC_AVAILABLE:
58
+ raise ImportError(
59
+ "Neither `motor` nor PyMongo async is installed. "
60
+ "Please install one of them using:\n"
61
+ " - `pip install -U 'pymongo>=4.9'` (recommended)"
62
+ " - `pip install -U motor` (legacy, deprecated)\n"
63
+ )
64
+
65
+ # Create union types for client, database, and collection
66
+ if TYPE_CHECKING:
67
+ if MOTOR_AVAILABLE and PYMONGO_ASYNC_AVAILABLE:
68
+ AsyncMongoClientType = Union[AsyncIOMotorClient, AsyncMongoClient] # type: ignore
69
+ AsyncMongoDatabaseType = Union[AsyncIOMotorDatabase, AsyncDatabase] # type: ignore
70
+ AsyncMongoCollectionType = Union[AsyncIOMotorCollection, AsyncCollection] # type: ignore
71
+ elif MOTOR_AVAILABLE:
72
+ AsyncMongoClientType = AsyncIOMotorClient # type: ignore
73
+ AsyncMongoDatabaseType = AsyncIOMotorDatabase # type: ignore
74
+ AsyncMongoCollectionType = AsyncIOMotorCollection # type: ignore
75
+ else:
76
+ AsyncMongoClientType = AsyncMongoClient # type: ignore
77
+ AsyncMongoDatabaseType = AsyncDatabase # type: ignore
78
+ AsyncMongoCollectionType = AsyncCollection # type: ignore
79
+ else:
80
+ # Runtime type - use Any to avoid import issues
81
+ AsyncMongoClientType = Any
82
+ AsyncMongoDatabaseType = Any
83
+ AsyncMongoCollectionType = Any
84
+
85
+
86
+ # Client type constants (defined before class to allow use in _detect_client_type)
87
+ _CLIENT_TYPE_MOTOR = "motor"
88
+ _CLIENT_TYPE_PYMONGO_ASYNC = "pymongo_async"
89
+ _CLIENT_TYPE_UNKNOWN = "unknown"
90
+
91
+
92
+ def _detect_client_type(client: Any) -> str:
93
+ """Detect whether a client is Motor or PyMongo async."""
94
+ if client is None:
95
+ return _CLIENT_TYPE_UNKNOWN
96
+
97
+ # Check PyMongo async
98
+ if PYMONGO_ASYNC_AVAILABLE and AsyncMongoClient is not None:
99
+ try:
100
+ if isinstance(client, AsyncMongoClient):
101
+ return _CLIENT_TYPE_PYMONGO_ASYNC
102
+ except (TypeError, AttributeError):
103
+ pass # Fall through to next check
104
+
105
+ if MOTOR_AVAILABLE and AsyncIOMotorClient is not None:
106
+ try:
107
+ if isinstance(client, AsyncIOMotorClient):
108
+ return _CLIENT_TYPE_MOTOR
109
+ except (TypeError, AttributeError):
110
+ pass # Fall through to fallback
111
+
112
+ # Fallback to string matching only if isinstance fails
113
+ # (should rarely happen, but useful for edge cases)
114
+ client_type_name = type(client).__name__
115
+ if "Motor" in client_type_name or "AsyncIOMotor" in client_type_name:
116
+ return _CLIENT_TYPE_MOTOR
117
+ elif "AsyncMongo" in client_type_name:
118
+ return _CLIENT_TYPE_PYMONGO_ASYNC
119
+
120
+ # Last resort: check module name
121
+ module_name = type(client).__module__
122
+ if "motor" in module_name:
123
+ return _CLIENT_TYPE_MOTOR
124
+ elif "pymongo" in module_name:
125
+ return _CLIENT_TYPE_PYMONGO_ASYNC
126
+
127
+ return _CLIENT_TYPE_UNKNOWN
128
+
40
129
 
41
130
  class AsyncMongoDb(AsyncBaseDb):
131
+ # Client type constants (class-level access to module constants)
132
+ CLIENT_TYPE_MOTOR = _CLIENT_TYPE_MOTOR
133
+ CLIENT_TYPE_PYMONGO_ASYNC = _CLIENT_TYPE_PYMONGO_ASYNC
134
+ CLIENT_TYPE_UNKNOWN = _CLIENT_TYPE_UNKNOWN
135
+
42
136
  def __init__(
43
137
  self,
44
- db_client: Optional[AsyncIOMotorClient] = None,
138
+ db_client: Optional[Union["AsyncIOMotorClient", "AsyncMongoClient"]] = None,
45
139
  db_name: Optional[str] = None,
46
140
  db_url: Optional[str] = None,
47
141
  session_collection: Optional[str] = None,
@@ -53,10 +147,16 @@ class AsyncMongoDb(AsyncBaseDb):
53
147
  id: Optional[str] = None,
54
148
  ):
55
149
  """
56
- Async interface for interacting with a MongoDB database using Motor.
150
+ Async interface for interacting with a MongoDB database.
151
+
152
+ Supports both Motor (legacy) and PyMongo async (recommended) clients.
153
+ When both libraries are available, PyMongo async is preferred.
57
154
 
58
155
  Args:
59
- db_client (Optional[AsyncIOMotorClient]): The MongoDB async client to use.
156
+ db_client (Optional[Union[AsyncIOMotorClient, AsyncMongoClient]]):
157
+ The MongoDB async client to use. Can be either Motor's AsyncIOMotorClient
158
+ or PyMongo's AsyncMongoClient. If not provided, a client will be created
159
+ from db_url using the preferred available library.
60
160
  db_name (Optional[str]): The name of the database to use.
61
161
  db_url (Optional[str]): The database URL to connect to.
62
162
  session_collection (Optional[str]): Name of the collection to store sessions.
@@ -68,7 +168,8 @@ class AsyncMongoDb(AsyncBaseDb):
68
168
  id (Optional[str]): ID of the database.
69
169
 
70
170
  Raises:
71
- ValueError: If neither db_url nor db_client is provided.
171
+ ValueError: If neither db_url nor db_client is provided, or if db_client type is unsupported.
172
+ ImportError: If neither motor nor pymongo async is installed.
72
173
  """
73
174
  if id is None:
74
175
  base_seed = db_url or str(db_client)
@@ -86,8 +187,21 @@ class AsyncMongoDb(AsyncBaseDb):
86
187
  culture_table=culture_collection,
87
188
  )
88
189
 
190
+ # Detect client type if provided
191
+ if db_client is not None:
192
+ self._client_type = _detect_client_type(db_client)
193
+ if self._client_type == self.CLIENT_TYPE_UNKNOWN:
194
+ raise ValueError(
195
+ f"Unsupported MongoDB client type: {type(db_client).__name__}. "
196
+ "Only Motor (AsyncIOMotorClient) or PyMongo async (AsyncMongoClient) are supported."
197
+ )
198
+ else:
199
+ # Auto-select preferred library when creating from URL
200
+ # Prefer PyMongo async if available, fallback to Motor
201
+ self._client_type = self.CLIENT_TYPE_PYMONGO_ASYNC if PYMONGO_ASYNC_AVAILABLE else self.CLIENT_TYPE_MOTOR
202
+
89
203
  # Store configuration for lazy initialization
90
- self._provided_client: Optional[AsyncIOMotorClient] = db_client
204
+ self._provided_client: Optional[AsyncMongoClientType] = db_client
91
205
  self.db_url: Optional[str] = db_url
92
206
  self.db_name: str = db_name if db_name is not None else "agno"
93
207
 
@@ -95,8 +209,8 @@ class AsyncMongoDb(AsyncBaseDb):
95
209
  raise ValueError("One of db_url or db_client must be provided")
96
210
 
97
211
  # Client and database will be lazily initialized per event loop
98
- self._client: Optional[AsyncIOMotorClient] = None
99
- self._database: Optional[AsyncIOMotorDatabase] = None
212
+ self._client: Optional[AsyncMongoClientType] = None
213
+ self._database: Optional[AsyncMongoDatabaseType] = None
100
214
  self._event_loop: Optional[asyncio.AbstractEventLoop] = None
101
215
 
102
216
  async def table_exists(self, table_name: str) -> bool:
@@ -126,15 +240,16 @@ class AsyncMongoDb(AsyncBaseDb):
126
240
  if collection_name and not await self.table_exists(collection_name):
127
241
  await self._get_collection(collection_type, create_collection_if_not_found=True)
128
242
 
129
- def _ensure_client(self) -> AsyncIOMotorClient:
243
+ def _ensure_client(self) -> AsyncMongoClientType:
130
244
  """
131
- Ensure the Motor client is valid for the current event loop.
245
+ Ensure the MongoDB async client is valid for the current event loop.
132
246
 
133
- Motor's AsyncIOMotorClient is tied to the event loop it was created in.
134
- If we detect a new event loop, we need to refresh the client.
247
+ Both Motor's AsyncIOMotorClient and PyMongo's AsyncMongoClient are tied to
248
+ the event loop they were created in. If we detect a new event loop, we need
249
+ to refresh the client.
135
250
 
136
251
  Returns:
137
- AsyncIOMotorClient: A valid client for the current event loop.
252
+ Union[AsyncIOMotorClient, AsyncMongoClient]: A valid client for the current event loop.
138
253
  """
139
254
  try:
140
255
  current_loop = asyncio.get_running_loop()
@@ -144,8 +259,13 @@ class AsyncMongoDb(AsyncBaseDb):
144
259
  if self._provided_client is not None:
145
260
  self._client = self._provided_client
146
261
  elif self.db_url is not None:
147
- self._client = AsyncIOMotorClient(self.db_url)
148
- log_debug("Created AsyncIOMotorClient outside event loop")
262
+ # Create client based on detected type
263
+ if self._client_type == self.CLIENT_TYPE_PYMONGO_ASYNC and PYMONGO_ASYNC_AVAILABLE:
264
+ self._client = AsyncMongoClient(self.db_url) # type: ignore
265
+ elif self._client_type == self.CLIENT_TYPE_MOTOR and MOTOR_AVAILABLE:
266
+ self._client = AsyncIOMotorClient(self.db_url) # type: ignore
267
+ else:
268
+ raise RuntimeError(f"Client type '{self._client_type}' not available")
149
269
  return self._client # type: ignore
150
270
 
151
271
  # Check if we're in a different event loop
@@ -153,17 +273,21 @@ class AsyncMongoDb(AsyncBaseDb):
153
273
  # New event loop detected, create new client
154
274
  if self._provided_client is not None:
155
275
  # User provided a client, use it but warn them
276
+ client_type_name = (
277
+ "AsyncMongoClient" if self._client_type == self.CLIENT_TYPE_PYMONGO_ASYNC else "AsyncIOMotorClient"
278
+ )
156
279
  log_debug(
157
- "New event loop detected. Using provided AsyncIOMotorClient, "
280
+ f"New event loop detected. Using provided {client_type_name}, "
158
281
  "which may cause issues if it was created in a different event loop."
159
282
  )
160
283
  self._client = self._provided_client
161
284
  elif self.db_url is not None:
162
- # Create a new client for this event loop
163
- old_loop_id = id(self._event_loop) if self._event_loop else "None"
164
- new_loop_id = id(current_loop)
165
- log_debug(f"Event loop changed from {old_loop_id} to {new_loop_id}, creating new AsyncIOMotorClient")
166
- self._client = AsyncIOMotorClient(self.db_url)
285
+ if self._client_type == self.CLIENT_TYPE_PYMONGO_ASYNC and PYMONGO_ASYNC_AVAILABLE:
286
+ self._client = AsyncMongoClient(self.db_url) # type: ignore
287
+ elif self._client_type == self.CLIENT_TYPE_MOTOR and MOTOR_AVAILABLE:
288
+ self._client = AsyncIOMotorClient(self.db_url) # type: ignore
289
+ else:
290
+ raise RuntimeError(f"Client type '{self._client_type}' not available")
167
291
 
168
292
  self._event_loop = current_loop
169
293
  self._database = None # Reset database reference
@@ -175,21 +299,21 @@ class AsyncMongoDb(AsyncBaseDb):
175
299
  return self._client # type: ignore
176
300
 
177
301
  @property
178
- def db_client(self) -> AsyncIOMotorClient:
302
+ def db_client(self) -> AsyncMongoClientType:
179
303
  """Get the MongoDB client, ensuring it's valid for the current event loop."""
180
304
  return self._ensure_client()
181
305
 
182
306
  @property
183
- def database(self) -> AsyncIOMotorDatabase:
307
+ def database(self) -> AsyncMongoDatabaseType:
184
308
  """Get the MongoDB database, ensuring it's valid for the current event loop."""
185
309
  try:
186
310
  current_loop = asyncio.get_running_loop()
187
311
  if self._database is None or self._event_loop != current_loop:
188
- self._database = self.db_client[self.db_name]
312
+ self._database = self.db_client[self.db_name] # type: ignore
189
313
  except RuntimeError:
190
314
  # No running loop - fallback to existing database or create new one
191
315
  if self._database is None:
192
- self._database = self.db_client[self.db_name]
316
+ self._database = self.db_client[self.db_name] # type: ignore
193
317
  return self._database
194
318
 
195
319
  # -- DB methods --
@@ -204,7 +328,7 @@ class AsyncMongoDb(AsyncBaseDb):
204
328
 
205
329
  async def _get_collection(
206
330
  self, table_type: str, create_collection_if_not_found: Optional[bool] = True
207
- ) -> Optional[AsyncIOMotorCollection]:
331
+ ) -> Optional[AsyncMongoCollectionType]:
208
332
  """Get or create a collection based on table type.
209
333
 
210
334
  Args:
@@ -212,7 +336,7 @@ class AsyncMongoDb(AsyncBaseDb):
212
336
  create_collection_if_not_found (Optional[bool]): Whether to create the collection if it doesn't exist.
213
337
 
214
338
  Returns:
215
- AsyncIOMotorCollection: The collection object.
339
+ Union[AsyncIOMotorCollection, AsyncCollection]: The collection object.
216
340
  """
217
341
  # Ensure client is valid for current event loop before accessing collections
218
342
  _ = self.db_client # This triggers _ensure_client()
@@ -290,7 +414,7 @@ class AsyncMongoDb(AsyncBaseDb):
290
414
 
291
415
  async def _get_or_create_collection(
292
416
  self, collection_name: str, collection_type: str, create_collection_if_not_found: Optional[bool] = True
293
- ) -> Optional[AsyncIOMotorCollection]:
417
+ ) -> Optional[AsyncMongoCollectionType]:
294
418
  """Get or create a collection with proper indexes.
295
419
 
296
420
  Args:
@@ -299,7 +423,7 @@ class AsyncMongoDb(AsyncBaseDb):
299
423
  create_collection_if_not_found (Optional[bool]): Whether to create the collection if it doesn't exist.
300
424
 
301
425
  Returns:
302
- Optional[AsyncIOMotorCollection]: The collection object.
426
+ Union[AsyncIOMotorCollection, AsyncCollection]: The collection object.
303
427
  """
304
428
  try:
305
429
  collection = self.database[collection_name]
@@ -307,7 +431,7 @@ class AsyncMongoDb(AsyncBaseDb):
307
431
  if not hasattr(self, f"_{collection_name}_initialized"):
308
432
  if not create_collection_if_not_found:
309
433
  return None
310
- # Create indexes asynchronously for Motor collections
434
+ # Create indexes asynchronously for async MongoDB collections
311
435
  await create_collection_indexes_async(collection, collection_type)
312
436
  setattr(self, f"_{collection_name}_initialized", True)
313
437
  log_debug(f"Initialized collection '{collection_name}'")
@@ -1543,7 +1667,7 @@ class AsyncMongoDb(AsyncBaseDb):
1543
1667
  log_error(f"Exception reading from sessions collection: {e}")
1544
1668
  return []
1545
1669
 
1546
- async def _get_metrics_calculation_starting_date(self, collection: AsyncIOMotorCollection) -> Optional[date]:
1670
+ async def _get_metrics_calculation_starting_date(self, collection: AsyncMongoCollectionType) -> Optional[date]:
1547
1671
  """Get the first date for which metrics calculation is needed."""
1548
1672
  try:
1549
1673
  result = await collection.find_one({}, sort=[("date", -1)], limit=1)
agno/db/mongo/mongo.py CHANGED
@@ -984,12 +984,14 @@ class MongoDb(BaseDb):
984
984
  self,
985
985
  limit: Optional[int] = None,
986
986
  page: Optional[int] = None,
987
+ user_id: Optional[str] = None,
987
988
  ) -> Tuple[List[Dict[str, Any]], int]:
988
989
  """Get user memories stats.
989
990
 
990
991
  Args:
991
992
  limit (Optional[int]): The limit of the memories to get.
992
993
  page (Optional[int]): The page number to get.
994
+ user_id (Optional[str]): User ID for filtering.
993
995
 
994
996
  Returns:
995
997
  Tuple[List[Dict[str, Any]], int]: A tuple containing the memories stats and the total count.
@@ -1002,9 +1004,11 @@ class MongoDb(BaseDb):
1002
1004
  if collection is None:
1003
1005
  return [], 0
1004
1006
 
1005
- match_stage = {"user_id": {"$ne": None}}
1007
+ match_stage: Dict[str, Any] = {"user_id": {"$ne": None}}
1008
+ if user_id is not None:
1009
+ match_stage["user_id"] = user_id
1006
1010
 
1007
- pipeline = [
1011
+ pipeline: List[Dict[str, Any]] = [
1008
1012
  {"$match": match_stage},
1009
1013
  {
1010
1014
  "$group": {
agno/db/mysql/mysql.py CHANGED
@@ -105,7 +105,7 @@ class MySQLDb(BaseDb):
105
105
  self.db_url: Optional[str] = db_url
106
106
  self.db_engine: Engine = _engine
107
107
  self.db_schema: str = db_schema if db_schema is not None else "ai"
108
- self.metadata: MetaData = MetaData()
108
+ self.metadata: MetaData = MetaData(schema=self.db_schema)
109
109
 
110
110
  # Initialize database session
111
111
  self.Session: scoped_session = scoped_session(sessionmaker(bind=self.db_engine))
@@ -123,14 +123,13 @@ class MySQLDb(BaseDb):
123
123
  with self.Session() as sess:
124
124
  return is_table_available(session=sess, table_name=table_name, db_schema=self.db_schema)
125
125
 
126
- def _create_table(self, table_name: str, table_type: str, db_schema: str) -> Table:
126
+ def _create_table(self, table_name: str, table_type: str) -> Table:
127
127
  """
128
128
  Create a table with the appropriate schema based on the table type.
129
129
 
130
130
  Args:
131
131
  table_name (str): Name of the table to create
132
132
  table_type (str): Type of table (used to get schema definition)
133
- db_schema (str): Database schema name
134
133
 
135
134
  Returns:
136
135
  Table: SQLAlchemy Table object
@@ -138,8 +137,6 @@ class MySQLDb(BaseDb):
138
137
  try:
139
138
  table_schema = get_table_schema_definition(table_type)
140
139
 
141
- log_debug(f"Creating table {table_name}")
142
-
143
140
  columns: List[Column] = []
144
141
  indexes: List[str] = []
145
142
  unique_constraints: List[str] = []
@@ -161,8 +158,7 @@ class MySQLDb(BaseDb):
161
158
  columns.append(Column(*column_args, **column_kwargs)) # type: ignore
162
159
 
163
160
  # Create the table object
164
- table_metadata = MetaData(schema=db_schema)
165
- table = Table(table_name, table_metadata, *columns, schema=db_schema)
161
+ table = Table(table_name, self.metadata, *columns, schema=self.db_schema)
166
162
 
167
163
  # Add multi-column unique constraints with table-specific names
168
164
  for constraint in schema_unique_constraints:
@@ -176,16 +172,20 @@ class MySQLDb(BaseDb):
176
172
  table.append_constraint(Index(idx_name, idx_col))
177
173
 
178
174
  with self.Session() as sess, sess.begin():
179
- create_schema(session=sess, db_schema=db_schema)
175
+ create_schema(session=sess, db_schema=self.db_schema)
180
176
 
181
177
  # Create table
182
- table.create(self.db_engine, checkfirst=True)
178
+ table_created = False
179
+ if not self.table_exists(table_name):
180
+ table.create(self.db_engine, checkfirst=True)
181
+ log_debug(f"Successfully created table '{table_name}'")
182
+ table_created = True
183
+ else:
184
+ log_debug(f"Table {self.db_schema}.{table_name} already exists, skipping creation")
183
185
 
184
186
  # Create indexes
185
187
  for idx in table.indexes:
186
188
  try:
187
- log_debug(f"Creating index: {idx.name}")
188
-
189
189
  # Check if index already exists
190
190
  with self.Session() as sess:
191
191
  exists_query = text(
@@ -194,24 +194,35 @@ class MySQLDb(BaseDb):
194
194
  )
195
195
  exists = (
196
196
  sess.execute(
197
- exists_query, {"schema": db_schema, "table_name": table_name, "index_name": idx.name}
197
+ exists_query,
198
+ {"schema": self.db_schema, "table_name": table_name, "index_name": idx.name},
198
199
  ).scalar()
199
200
  is not None
200
201
  )
201
202
  if exists:
202
- log_debug(f"Index {idx.name} already exists in {db_schema}.{table_name}, skipping creation")
203
+ log_debug(
204
+ f"Index {idx.name} already exists in {self.db_schema}.{table_name}, skipping creation"
205
+ )
203
206
  continue
204
207
 
205
208
  idx.create(self.db_engine)
206
209
 
210
+ log_debug(f"Created index: {idx.name} for table {self.db_schema}.{table_name}")
207
211
  except Exception as e:
208
212
  log_error(f"Error creating index {idx.name}: {e}")
209
213
 
210
- log_debug(f"Successfully created table {db_schema}.{table_name}")
214
+ # Store the schema version for the created table
215
+ if table_name != self.versions_table_name and table_created:
216
+ latest_schema_version = MigrationManager(self).latest_schema_version
217
+ self.upsert_schema_version(table_name=table_name, version=latest_schema_version.public)
218
+ log_info(
219
+ f"Successfully stored version {latest_schema_version.public} in database for table {table_name}"
220
+ )
221
+
211
222
  return table
212
223
 
213
224
  except Exception as e:
214
- log_error(f"Could not create table {db_schema}.{table_name}: {e}")
225
+ log_error(f"Could not create table {self.db_schema}.{table_name}: {e}")
215
226
  raise
216
227
 
217
228
  def _create_all_tables(self):
@@ -226,19 +237,13 @@ class MySQLDb(BaseDb):
226
237
  ]
227
238
 
228
239
  for table_name, table_type in tables_to_create:
229
- if table_name != self.versions_table_name:
230
- # Also store the schema version for the created table
231
- latest_schema_version = MigrationManager(self).latest_schema_version
232
- self.upsert_schema_version(table_name=table_name, version=latest_schema_version.public)
233
-
234
- self._create_table(table_name=table_name, table_type=table_type, db_schema=self.db_schema)
240
+ self._get_or_create_table(table_name=table_name, table_type=table_type, create_table_if_not_found=True)
235
241
 
236
242
  def _get_table(self, table_type: str, create_table_if_not_found: Optional[bool] = False) -> Optional[Table]:
237
243
  if table_type == "sessions":
238
244
  self.session_table = self._get_or_create_table(
239
245
  table_name=self.session_table_name,
240
246
  table_type="sessions",
241
- db_schema=self.db_schema,
242
247
  create_table_if_not_found=create_table_if_not_found,
243
248
  )
244
249
  return self.session_table
@@ -247,7 +252,6 @@ class MySQLDb(BaseDb):
247
252
  self.memory_table = self._get_or_create_table(
248
253
  table_name=self.memory_table_name,
249
254
  table_type="memories",
250
- db_schema=self.db_schema,
251
255
  create_table_if_not_found=create_table_if_not_found,
252
256
  )
253
257
  return self.memory_table
@@ -256,7 +260,6 @@ class MySQLDb(BaseDb):
256
260
  self.metrics_table = self._get_or_create_table(
257
261
  table_name=self.metrics_table_name,
258
262
  table_type="metrics",
259
- db_schema=self.db_schema,
260
263
  create_table_if_not_found=create_table_if_not_found,
261
264
  )
262
265
  return self.metrics_table
@@ -265,7 +268,6 @@ class MySQLDb(BaseDb):
265
268
  self.eval_table = self._get_or_create_table(
266
269
  table_name=self.eval_table_name,
267
270
  table_type="evals",
268
- db_schema=self.db_schema,
269
271
  create_table_if_not_found=create_table_if_not_found,
270
272
  )
271
273
  return self.eval_table
@@ -274,7 +276,6 @@ class MySQLDb(BaseDb):
274
276
  self.knowledge_table = self._get_or_create_table(
275
277
  table_name=self.knowledge_table_name,
276
278
  table_type="knowledge",
277
- db_schema=self.db_schema,
278
279
  create_table_if_not_found=create_table_if_not_found,
279
280
  )
280
281
  return self.knowledge_table
@@ -283,7 +284,6 @@ class MySQLDb(BaseDb):
283
284
  self.culture_table = self._get_or_create_table(
284
285
  table_name=self.culture_table_name,
285
286
  table_type="culture",
286
- db_schema=self.db_schema,
287
287
  create_table_if_not_found=create_table_if_not_found,
288
288
  )
289
289
  return self.culture_table
@@ -292,7 +292,6 @@ class MySQLDb(BaseDb):
292
292
  self.versions_table = self._get_or_create_table(
293
293
  table_name=self.versions_table_name,
294
294
  table_type="versions",
295
- db_schema=self.db_schema,
296
295
  create_table_if_not_found=create_table_if_not_found,
297
296
  )
298
297
  return self.versions_table
@@ -300,7 +299,7 @@ class MySQLDb(BaseDb):
300
299
  raise ValueError(f"Unknown table type: {table_type}")
301
300
 
302
301
  def _get_or_create_table(
303
- self, table_name: str, table_type: str, db_schema: str, create_table_if_not_found: Optional[bool] = False
302
+ self, table_name: str, table_type: str, create_table_if_not_found: Optional[bool] = False
304
303
  ) -> Optional[Table]:
305
304
  """
306
305
  Check if the table exists and is valid, else create it.
@@ -308,25 +307,19 @@ class MySQLDb(BaseDb):
308
307
  Args:
309
308
  table_name (str): Name of the table to get or create
310
309
  table_type (str): Type of table (used to get schema definition)
311
- db_schema (str): Database schema name
312
310
 
313
311
  Returns:
314
312
  Table: SQLAlchemy Table object representing the schema.
315
313
  """
316
314
 
317
315
  with self.Session() as sess, sess.begin():
318
- table_is_available = is_table_available(session=sess, table_name=table_name, db_schema=db_schema)
316
+ table_is_available = is_table_available(session=sess, table_name=table_name, db_schema=self.db_schema)
319
317
 
320
318
  if not table_is_available:
321
319
  if not create_table_if_not_found:
322
320
  return None
323
321
 
324
- created_table = self._create_table(table_name=table_name, table_type=table_type, db_schema=db_schema)
325
-
326
- if table_name != self.versions_table_name:
327
- # Also store the schema version for the created table
328
- latest_schema_version = MigrationManager(self).latest_schema_version
329
- self.upsert_schema_version(table_name=table_name, version=latest_schema_version.public)
322
+ created_table = self._create_table(table_name=table_name, table_type=table_type)
330
323
 
331
324
  return created_table
332
325
 
@@ -334,17 +327,16 @@ class MySQLDb(BaseDb):
334
327
  db_engine=self.db_engine,
335
328
  table_name=table_name,
336
329
  table_type=table_type,
337
- db_schema=db_schema,
330
+ db_schema=self.db_schema,
338
331
  ):
339
- raise ValueError(f"Table {db_schema}.{table_name} has an invalid schema")
332
+ raise ValueError(f"Table {self.db_schema}.{table_name} has an invalid schema")
340
333
 
341
334
  try:
342
- table = Table(table_name, self.metadata, schema=db_schema, autoload_with=self.db_engine)
343
- log_debug(f"Loaded existing table {db_schema}.{table_name}")
335
+ table = Table(table_name, self.metadata, schema=self.db_schema, autoload_with=self.db_engine)
344
336
  return table
345
337
 
346
338
  except Exception as e:
347
- log_error(f"Error loading existing table {db_schema}.{table_name}: {e}")
339
+ log_error(f"Error loading existing table {self.db_schema}.{table_name}: {e}")
348
340
  raise
349
341
 
350
342
  def get_latest_schema_version(self, table_name: str) -> str:
@@ -513,7 +505,7 @@ class MySQLDb(BaseDb):
513
505
  Args:
514
506
  session_type (Optional[SessionType]): The type of sessions to get.
515
507
  user_id (Optional[str]): The ID of the user to filter by.
516
- entity_id (Optional[str]): The ID of the agent / workflow to filter by.
508
+ component_id (Optional[str]): The ID of the agent / workflow to filter by.
517
509
  start_timestamp (Optional[int]): The start timestamp to filter by.
518
510
  end_timestamp (Optional[int]): The end timestamp to filter by.
519
511
  session_name (Optional[str]): The name of the session to filter by.
@@ -522,7 +514,6 @@ class MySQLDb(BaseDb):
522
514
  sort_by (Optional[str]): The field to sort by. Defaults to None.
523
515
  sort_order (Optional[str]): The sort order. Defaults to None.
524
516
  deserialize (Optional[bool]): Whether to serialize the sessions. Defaults to True.
525
- create_table_if_not_found (Optional[bool]): Whether to create the table if it doesn't exist.
526
517
 
527
518
  Returns:
528
519
  Union[List[Session], Tuple[List[Dict], int]]:
@@ -1254,7 +1245,7 @@ class MySQLDb(BaseDb):
1254
1245
  log_error(f"Exception clearing user memories: {e}")
1255
1246
 
1256
1247
  def get_user_memory_stats(
1257
- self, limit: Optional[int] = None, page: Optional[int] = None
1248
+ self, limit: Optional[int] = None, page: Optional[int] = None, user_id: Optional[str] = None
1258
1249
  ) -> Tuple[List[Dict[str, Any]], int]:
1259
1250
  """Get user memories stats.
1260
1251
 
@@ -1283,17 +1274,20 @@ class MySQLDb(BaseDb):
1283
1274
  return [], 0
1284
1275
 
1285
1276
  with self.Session() as sess, sess.begin():
1286
- stmt = (
1287
- select(
1288
- table.c.user_id,
1289
- func.count(table.c.memory_id).label("total_memories"),
1290
- func.max(table.c.updated_at).label("last_memory_updated_at"),
1291
- )
1292
- .where(table.c.user_id.is_not(None))
1293
- .group_by(table.c.user_id)
1294
- .order_by(func.max(table.c.updated_at).desc())
1277
+ stmt = select(
1278
+ table.c.user_id,
1279
+ func.count(table.c.memory_id).label("total_memories"),
1280
+ func.max(table.c.updated_at).label("last_memory_updated_at"),
1295
1281
  )
1296
1282
 
1283
+ if user_id is not None:
1284
+ stmt = stmt.where(table.c.user_id == user_id)
1285
+ else:
1286
+ stmt = stmt.where(table.c.user_id.is_not(None))
1287
+
1288
+ stmt = stmt.group_by(table.c.user_id)
1289
+ stmt = stmt.order_by(func.max(table.c.updated_at).desc())
1290
+
1297
1291
  count_stmt = select(func.count()).select_from(stmt.alias())
1298
1292
  total_count = sess.execute(count_stmt).scalar()
1299
1293