agno 2.1.5__py3-none-any.whl → 2.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1193 @@
1
+ from datetime import date, datetime, timedelta, timezone
2
+ from textwrap import dedent
3
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
4
+
5
+ from agno.db.base import BaseDb, SessionType
6
+ from agno.db.postgres.utils import (
7
+ get_dates_to_calculate_metrics_for,
8
+ )
9
+ from agno.db.schemas import UserMemory
10
+ from agno.db.schemas.evals import EvalFilterType, EvalRunRecord, EvalType
11
+ from agno.db.schemas.knowledge import KnowledgeRow
12
+ from agno.db.surrealdb import utils
13
+ from agno.db.surrealdb.metrics import (
14
+ bulk_upsert_metrics,
15
+ calculate_date_metrics,
16
+ fetch_all_sessions_data,
17
+ get_all_sessions_for_metrics_calculation,
18
+ get_metrics_calculation_starting_date,
19
+ )
20
+ from agno.db.surrealdb.models import (
21
+ TableType,
22
+ deserialize_eval_run_record,
23
+ deserialize_knowledge_row,
24
+ deserialize_session,
25
+ deserialize_sessions,
26
+ deserialize_user_memories,
27
+ deserialize_user_memory,
28
+ desurrealize_eval_run_record,
29
+ desurrealize_session,
30
+ desurrealize_user_memory,
31
+ get_schema,
32
+ get_session_type,
33
+ serialize_eval_run_record,
34
+ serialize_knowledge_row,
35
+ serialize_session,
36
+ serialize_user_memory,
37
+ )
38
+ from agno.db.surrealdb.queries import COUNT_QUERY, WhereClause, order_limit_start
39
+ from agno.db.surrealdb.utils import build_client
40
+ from agno.session import Session
41
+ from agno.utils.log import log_debug, log_error, log_info
42
+ from agno.utils.string import generate_id
43
+
44
+ try:
45
+ from surrealdb import BlockingHttpSurrealConnection, BlockingWsSurrealConnection, RecordID
46
+ except ImportError:
47
+ raise ImportError("The `surrealdb` package is not installed. Please install it via `pip install surrealdb`.")
48
+
49
+
50
+ class SurrealDb(BaseDb):
51
+ def __init__(
52
+ self,
53
+ client: Optional[Union[BlockingWsSurrealConnection, BlockingHttpSurrealConnection]],
54
+ db_url: str,
55
+ db_creds: dict[str, str],
56
+ db_ns: str,
57
+ db_db: str,
58
+ session_table: Optional[str] = None,
59
+ memory_table: Optional[str] = None,
60
+ metrics_table: Optional[str] = None,
61
+ eval_table: Optional[str] = None,
62
+ knowledge_table: Optional[str] = None,
63
+ id: Optional[str] = None,
64
+ ):
65
+ """
66
+ Interface for interacting with a SurrealDB database.
67
+
68
+ Args:
69
+ client: A blocking connection, either HTTP or WS
70
+ """
71
+ if id is None:
72
+ base_seed = db_url
73
+ seed = f"{base_seed}#{db_db}"
74
+ id = generate_id(seed)
75
+
76
+ super().__init__(
77
+ id=id,
78
+ session_table=session_table,
79
+ memory_table=memory_table,
80
+ metrics_table=metrics_table,
81
+ eval_table=eval_table,
82
+ knowledge_table=knowledge_table,
83
+ )
84
+ self._client = client
85
+ self._db_url = db_url
86
+ self._db_creds = db_creds
87
+ self._db_ns = db_ns
88
+ self._db_db = db_db
89
+ self._users_table_name: str = "agno_users"
90
+ self._agents_table_name: str = "agno_agents"
91
+ self._teams_table_name: str = "agno_teams"
92
+ self._workflows_table_name: str = "agno_workflows"
93
+
94
+ @property
95
+ def client(self) -> Union[BlockingWsSurrealConnection, BlockingHttpSurrealConnection]:
96
+ if self._client is None:
97
+ self._client = build_client(self._db_url, self._db_creds, self._db_ns, self._db_db)
98
+ return self._client
99
+
100
+ @property
101
+ def table_names(self) -> dict[TableType, str]:
102
+ return {
103
+ "agents": self._agents_table_name,
104
+ "evals": self.eval_table_name,
105
+ "knowledge": self.knowledge_table_name,
106
+ "memories": self.memory_table_name,
107
+ "sessions": self.session_table_name,
108
+ "teams": self._teams_table_name,
109
+ "users": self._users_table_name,
110
+ "workflows": self._workflows_table_name,
111
+ }
112
+
113
+ def _table_exists(self, table_name: str) -> bool:
114
+ response = self._query_one("INFO FOR DB", {}, dict)
115
+ if response is None:
116
+ raise Exception("Failed to retrieve database information")
117
+ return table_name in response.get("tables", [])
118
+
119
+ def _create_table(self, table_type: TableType, table_name: str):
120
+ query = get_schema(table_type, table_name)
121
+ self.client.query(query)
122
+
123
+ def _get_table(self, table_type: TableType, create_table_if_not_found: bool = True):
124
+ if table_type == "sessions":
125
+ table_name = self.session_table_name
126
+ elif table_type == "memories":
127
+ table_name = self.memory_table_name
128
+ elif table_type == "knowledge":
129
+ table_name = self.knowledge_table_name
130
+ elif table_type == "users":
131
+ table_name = self._users_table_name
132
+ elif table_type == "agents":
133
+ table_name = self._agents_table_name
134
+ elif table_type == "teams":
135
+ table_name = self._teams_table_name
136
+ elif table_type == "workflows":
137
+ table_name = self._workflows_table_name
138
+ elif table_type == "evals":
139
+ table_name = self.eval_table_name
140
+ elif table_type == "metrics":
141
+ table_name = self.metrics_table_name
142
+ else:
143
+ raise NotImplementedError(f"Unknown table type: {table_type}")
144
+
145
+ if create_table_if_not_found and not self._table_exists(table_name):
146
+ self._create_table(table_type, table_name)
147
+
148
+ return table_name
149
+
150
+ def _query(
151
+ self,
152
+ query: str,
153
+ vars: dict[str, Any],
154
+ record_type: type[utils.RecordType],
155
+ ) -> Sequence[utils.RecordType]:
156
+ return utils.query(self.client, query, vars, record_type)
157
+
158
+ def _query_one(
159
+ self,
160
+ query: str,
161
+ vars: dict[str, Any],
162
+ record_type: type[utils.RecordType],
163
+ ) -> Optional[utils.RecordType]:
164
+ return utils.query_one(self.client, query, vars, record_type)
165
+
166
+ def _count(self, table: str, where_clause: str, where_vars: dict[str, Any], group_by: Optional[str] = None) -> int:
167
+ total_count_query = COUNT_QUERY.format(
168
+ table=table,
169
+ where_clause=where_clause,
170
+ group_clause="GROUP ALL" if group_by is None else f"GROUP BY {group_by}",
171
+ group_fields="" if group_by is None else f", {group_by}",
172
+ )
173
+ count_result = self._query_one(total_count_query, where_vars, dict)
174
+ total_count = count_result.get("count") if count_result else 0
175
+ assert isinstance(total_count, int), f"Expected int, got {type(total_count)}"
176
+ total_count = int(total_count)
177
+ return total_count
178
+
179
+ # --- Sessions ---
180
+ def clear_sessions(self) -> None:
181
+ """Delete all session rows from the database.
182
+
183
+ Raises:
184
+ Exception: If an error occurs during deletion.
185
+ """
186
+ table = self._get_table("sessions")
187
+ _ = self.client.delete(table)
188
+
189
+ def delete_session(self, session_id: str) -> bool:
190
+ table = self._get_table(table_type="sessions")
191
+ if table is None:
192
+ return False
193
+ res = self.client.delete(RecordID(table, session_id))
194
+ return bool(res)
195
+
196
+ def delete_sessions(self, session_ids: list[str]) -> None:
197
+ table = self._get_table(table_type="sessions")
198
+ if table is None:
199
+ return
200
+
201
+ records = [RecordID(table, id) for id in session_ids]
202
+ self.client.query(f"DELETE FROM {table} WHERE id IN $records", {"records": records})
203
+
204
+ def get_session(
205
+ self,
206
+ session_id: str,
207
+ session_type: SessionType,
208
+ user_id: Optional[str] = None,
209
+ deserialize: Optional[bool] = True,
210
+ ) -> Optional[Union[Session, Dict[str, Any]]]:
211
+ r"""
212
+ Read a session from the database.
213
+
214
+ Args:
215
+ session_id (str): ID of the session to read.
216
+ session_type (SessionType): Type of session to get.
217
+ user_id (Optional[str]): User ID to filter by. Defaults to None.
218
+ deserialize (Optional[bool]): Whether to serialize the session. Defaults to True.
219
+
220
+ Returns:
221
+ Optional[Union[Session, Dict[str, Any]]]:
222
+ - When deserialize=True: Session object
223
+ - When deserialize=False: Session dictionary
224
+
225
+ Raises:
226
+ Exception: If an error occurs during retrieval.
227
+ """
228
+ sessions_table = self._get_table("sessions")
229
+ record = RecordID(sessions_table, session_id)
230
+ where = WhereClause()
231
+ if user_id is not None:
232
+ where = where.and_("user_id", user_id)
233
+ if session_type == SessionType.AGENT:
234
+ where = where.and_("agent", None, "!=")
235
+ elif session_type == SessionType.TEAM:
236
+ where = where.and_("team", None, "!=")
237
+ elif session_type == SessionType.WORKFLOW:
238
+ where = where.and_("workflow", None, "!=")
239
+ where_clause, where_vars = where.build()
240
+ query = dedent(f"""
241
+ SELECT *
242
+ FROM ONLY $record
243
+ {where_clause}
244
+ """)
245
+ vars = {"record": record, **where_vars}
246
+ raw = self._query_one(query, vars, dict)
247
+ if raw is None or not deserialize:
248
+ return raw
249
+
250
+ return deserialize_session(session_type, raw)
251
+
252
+ def get_sessions(
253
+ self,
254
+ session_type: Optional[SessionType] = None,
255
+ user_id: Optional[str] = None,
256
+ component_id: Optional[str] = None,
257
+ session_name: Optional[str] = None,
258
+ start_timestamp: Optional[int] = None,
259
+ end_timestamp: Optional[int] = None,
260
+ limit: Optional[int] = None,
261
+ page: Optional[int] = None,
262
+ sort_by: Optional[str] = None,
263
+ sort_order: Optional[str] = None,
264
+ deserialize: Optional[bool] = True,
265
+ ) -> Union[List[Session], Tuple[List[Dict[str, Any]], int]]:
266
+ r"""
267
+ Get all sessions in the given table. Can filter by user_id and entity_id.
268
+
269
+ Args:
270
+ session_type (SessionType): The type of session to get.
271
+ user_id (Optional[str]): The ID of the user to filter by.
272
+ component_id (Optional[str]): The ID of the agent / team / workflow to filter by.
273
+ session_name (Optional[str]): The name of the session to filter by.
274
+ start_timestamp (Optional[int]): The start timestamp to filter by.
275
+ end_timestamp (Optional[int]): The end timestamp to filter by.
276
+ limit (Optional[int]): The maximum number of sessions to return. Defaults to None.
277
+ page (Optional[int]): The page number to return. Defaults to None.
278
+ sort_by (Optional[str]): The field to sort by. Defaults to None.
279
+ sort_order (Optional[str]): The sort order. Defaults to None.
280
+ deserialize (Optional[bool]): Whether to serialize the sessions. Defaults to True.
281
+
282
+ Returns:
283
+ Union[List[Session], Tuple[List[Dict], int]]:
284
+ - When deserialize=True: List of Session objects
285
+ - When deserialize=False: Tuple of (session dictionaries, total count)
286
+
287
+ Raises:
288
+ Exception: If an error occurs during retrieval.
289
+ """
290
+ table = self._get_table("sessions")
291
+ # users_table = self._get_table("users", False) # Not used, commenting out for now.
292
+ agents_table = self._get_table("agents", False)
293
+ teams_table = self._get_table("teams", False)
294
+ workflows_table = self._get_table("workflows", False)
295
+
296
+ # -- Filters
297
+ where = WhereClause()
298
+
299
+ # user_id
300
+ if user_id is not None:
301
+ where = where.and_("user_id", user_id)
302
+
303
+ # component_id
304
+ if component_id is not None:
305
+ if session_type == SessionType.AGENT:
306
+ where = where.and_("agent", RecordID(agents_table, component_id))
307
+ elif session_type == SessionType.TEAM:
308
+ where = where.and_("team", RecordID(teams_table, component_id))
309
+ elif session_type == SessionType.WORKFLOW:
310
+ where = where.and_("workflow", RecordID(workflows_table, component_id))
311
+
312
+ # session_name
313
+ if session_name is not None:
314
+ where = where.and_("session_name", session_name, "~")
315
+
316
+ # start_timestamp
317
+ if start_timestamp is not None:
318
+ where = where.and_("start_timestamp", start_timestamp, ">=")
319
+
320
+ # end_timestamp
321
+ if end_timestamp is not None:
322
+ where = where.and_("end_timestamp", end_timestamp, "<=")
323
+
324
+ where_clause, where_vars = where.build()
325
+
326
+ # Total count
327
+ total_count = self._count(table, where_clause, where_vars)
328
+
329
+ # Query
330
+ order_limit_start_clause = order_limit_start(sort_by, sort_order, limit, page)
331
+ query = dedent(f"""
332
+ SELECT *
333
+ FROM {table}
334
+ {where_clause}
335
+ {order_limit_start_clause}
336
+ """)
337
+ sessions_raw = self._query(query, where_vars, dict)
338
+ converted_sessions_raw = [desurrealize_session(session, session_type) for session in sessions_raw]
339
+
340
+ if not deserialize:
341
+ return list(converted_sessions_raw), total_count
342
+
343
+ if session_type is None:
344
+ raise ValueError("session_type is required when deserialize=True")
345
+
346
+ return deserialize_sessions(session_type, list(sessions_raw))
347
+
348
+ def rename_session(
349
+ self, session_id: str, session_type: SessionType, session_name: str, deserialize: Optional[bool] = True
350
+ ) -> Optional[Union[Session, Dict[str, Any]]]:
351
+ """
352
+ Rename a session in the database.
353
+
354
+ Args:
355
+ session_id (str): The ID of the session to rename.
356
+ session_type (SessionType): The type of session to rename.
357
+ session_name (str): The new name for the session.
358
+ deserialize (Optional[bool]): Whether to serialize the session. Defaults to True.
359
+
360
+ Returns:
361
+ Optional[Union[Session, Dict[str, Any]]]:
362
+ - When deserialize=True: Session object
363
+ - When deserialize=False: Session dictionary
364
+
365
+ Raises:
366
+ Exception: If an error occurs during renaming.
367
+ """
368
+ table = self._get_table("sessions")
369
+ vars = {"record": RecordID(table, session_id), "name": session_name}
370
+
371
+ # Query
372
+ query = dedent("""
373
+ UPDATE ONLY $record
374
+ SET session_name = $name
375
+ """)
376
+ session_raw = self._query_one(query, vars, dict)
377
+
378
+ if session_raw is None or not deserialize:
379
+ return session_raw
380
+ return deserialize_session(session_type, session_raw)
381
+
382
+ def upsert_session(
383
+ self, session: Session, deserialize: Optional[bool] = True
384
+ ) -> Optional[Union[Session, Dict[str, Any]]]:
385
+ """
386
+ Insert or update a session in the database.
387
+
388
+ Args:
389
+ session (Session): The session data to upsert.
390
+ deserialize (Optional[bool]): Whether to deserialize the session. Defaults to True.
391
+
392
+ Returns:
393
+ Optional[Union[Session, Dict[str, Any]]]:
394
+ - When deserialize=True: Session object
395
+ - When deserialize=False: Session dictionary
396
+
397
+ Raises:
398
+ Exception: If an error occurs during upsert.
399
+ """
400
+ session_type = get_session_type(session)
401
+ table = self._get_table("sessions")
402
+ session_raw = self._query_one(
403
+ "UPSERT ONLY $record CONTENT $content",
404
+ {
405
+ "record": RecordID(table, session.session_id),
406
+ "content": serialize_session(session, self.table_names),
407
+ },
408
+ dict,
409
+ )
410
+ if session_raw is None or not deserialize:
411
+ return session_raw
412
+
413
+ return deserialize_session(session_type, session_raw)
414
+
415
+ def upsert_sessions(
416
+ self, sessions: List[Session], deserialize: Optional[bool] = True
417
+ ) -> List[Union[Session, Dict[str, Any]]]:
418
+ """
419
+ Bulk insert or update multiple sessions.
420
+
421
+ Args:
422
+ sessions (List[Session]): The list of session data to upsert.
423
+ deserialize (Optional[bool]): Whether to deserialize the sessions. Defaults to True.
424
+
425
+ Returns:
426
+ List[Union[Session, Dict[str, Any]]]: List of upserted sessions
427
+
428
+ Raises:
429
+ Exception: If an error occurs during bulk upsert.
430
+ """
431
+ if not sessions:
432
+ return []
433
+ session_type = get_session_type(sessions[0])
434
+ table = self._get_table("sessions")
435
+ sessions_raw: List[Dict[str, Any]] = []
436
+ for session in sessions:
437
+ # UPSERT does only work for one record at a time
438
+ session_raw = self._query_one(
439
+ "UPSERT ONLY $record CONTENT $content",
440
+ {
441
+ "record": RecordID(table, session.session_id),
442
+ "content": serialize_session(session, self.table_names),
443
+ },
444
+ dict,
445
+ )
446
+ if session_raw:
447
+ sessions_raw.append(session_raw)
448
+ if not deserialize:
449
+ return list(sessions_raw)
450
+
451
+ # wrapping with list because of:
452
+ # Type "List[Session]" is not assignable to return type "List[Session | Dict[str, Any]]"
453
+ # Consider switching from "list" to "Sequence" which is covariant
454
+ return list(deserialize_sessions(session_type, sessions_raw))
455
+
456
+ # --- Memory ---
457
+ def clear_memories(self) -> None:
458
+ """Delete all memories from the database.
459
+
460
+ Raises:
461
+ Exception: If an error occurs during deletion.
462
+ """
463
+ table = self._get_table("memories")
464
+ _ = self.client.delete(table)
465
+
466
+ def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None) -> None:
467
+ """Delete a user memory from the database.
468
+
469
+ Args:
470
+ memory_id (str): The ID of the memory to delete.
471
+ user_id (Optional[str]): The ID of the user to filter by. Defaults to None.
472
+
473
+ Returns:
474
+ bool: True if deletion was successful, False otherwise.
475
+
476
+ Raises:
477
+ Exception: If an error occurs during deletion.
478
+ """
479
+ table = self._get_table("memories")
480
+ mem_rec_id = RecordID(table, memory_id)
481
+ if user_id is None:
482
+ self.client.delete(mem_rec_id)
483
+ else:
484
+ user_rec_id = RecordID(self._get_table("users"), user_id)
485
+ self.client.query(
486
+ f"DELETE FROM {table} WHERE user = $user AND id = $memory",
487
+ {"user": user_rec_id, "memory": mem_rec_id},
488
+ )
489
+
490
+ def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
491
+ """Delete user memories from the database.
492
+
493
+ Args:
494
+ memory_ids (List[str]): The IDs of the memories to delete.
495
+ user_id (Optional[str]): The ID of the user to filter by. Defaults to None.
496
+
497
+ Raises:
498
+ Exception: If an error occurs during deletion.
499
+ """
500
+ table = self._get_table("memories")
501
+ records = [RecordID(table, memory_id) for memory_id in memory_ids]
502
+ if user_id is None:
503
+ _ = self.client.query(f"DELETE FROM {table} WHERE id IN $records", {"records": records})
504
+ else:
505
+ user_rec_id = RecordID(self._get_table("users"), user_id)
506
+ _ = self.client.query(
507
+ f"DELETE FROM {table} WHERE id IN $records AND user = $user", {"records": records, "user": user_rec_id}
508
+ )
509
+
510
+ def get_all_memory_topics(self, user_id: Optional[str] = None) -> List[str]:
511
+ """Get all memory topics from the database.
512
+
513
+ Args:
514
+ user_id (Optional[str]): The ID of the user to filter by. Defaults to None.
515
+
516
+ Returns:
517
+ List[str]: List of memory topics.
518
+ """
519
+ table = self._get_table("memories")
520
+ vars: dict[str, Any] = {}
521
+
522
+ # Query
523
+ if user_id is None:
524
+ query = dedent(f"""
525
+ RETURN (
526
+ SELECT
527
+ array::flatten(topics) as topics
528
+ FROM ONLY {table}
529
+ GROUP ALL
530
+ ).topics.distinct();
531
+ """)
532
+ else:
533
+ query = dedent(f"""
534
+ RETURN (
535
+ SELECT
536
+ array::flatten(topics) as topics
537
+ FROM ONLY {table}
538
+ WHERE user = $user
539
+ GROUP ALL
540
+ ).topics.distinct();
541
+ """)
542
+ vars["user"] = RecordID(self._get_table("users"), user_id)
543
+
544
+ result = self._query(query, vars, str)
545
+ return list(result)
546
+
547
+ def get_user_memory(
548
+ self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
549
+ ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
550
+ """Get a memory from the database.
551
+
552
+ Args:
553
+ memory_id (str): The ID of the memory to get.
554
+ deserialize (Optional[bool]): Whether to serialize the memory. Defaults to True.
555
+ user_id (Optional[str]): The ID of the user to filter by. Defaults to None.
556
+
557
+ Returns:
558
+ Optional[Union[UserMemory, Dict[str, Any]]]:
559
+ - When deserialize=True: UserMemory object
560
+ - When deserialize=False: UserMemory dictionary
561
+
562
+ Raises:
563
+ Exception: If an error occurs during retrieval.
564
+ """
565
+ table_name = self._get_table("memories")
566
+ record = RecordID(table_name, memory_id)
567
+ vars = {"record": record}
568
+
569
+ if user_id is None:
570
+ query = "SELECT * FROM ONLY $record"
571
+ else:
572
+ query = "SELECT * FROM ONLY $record WHERE user = $user"
573
+ vars["user"] = RecordID(self._get_table("users"), user_id)
574
+
575
+ result = self._query_one(query, vars, dict)
576
+ if result is None or not deserialize:
577
+ return result
578
+ return deserialize_user_memory(result)
579
+
580
+ def get_user_memories(
581
+ self,
582
+ user_id: Optional[str] = None,
583
+ agent_id: Optional[str] = None,
584
+ team_id: Optional[str] = None,
585
+ topics: Optional[List[str]] = None,
586
+ search_content: Optional[str] = None,
587
+ limit: Optional[int] = None,
588
+ page: Optional[int] = None,
589
+ sort_by: Optional[str] = None,
590
+ sort_order: Optional[str] = None,
591
+ deserialize: Optional[bool] = True,
592
+ ) -> Union[List[UserMemory], Tuple[List[Dict[str, Any]], int]]:
593
+ """Get all memories from the database as UserMemory objects.
594
+
595
+ Args:
596
+ user_id (Optional[str]): The ID of the user to filter by.
597
+ agent_id (Optional[str]): The ID of the agent to filter by.
598
+ team_id (Optional[str]): The ID of the team to filter by.
599
+ topics (Optional[List[str]]): The topics to filter by.
600
+ search_content (Optional[str]): The content to search for.
601
+ limit (Optional[int]): The maximum number of memories to return.
602
+ page (Optional[int]): The page number.
603
+ sort_by (Optional[str]): The column to sort by.
604
+ sort_order (Optional[str]): The order to sort by.
605
+ deserialize (Optional[bool]): Whether to serialize the memories. Defaults to True.
606
+
607
+
608
+ Returns:
609
+ Union[List[UserMemory], Tuple[List[Dict[str, Any]], int]]:
610
+ - When deserialize=True: List of UserMemory objects
611
+ - When deserialize=False: Tuple of (memory dictionaries, total count)
612
+
613
+ Raises:
614
+ Exception: If an error occurs during retrieval.
615
+ """
616
+ table = self._get_table("memories")
617
+ where = WhereClause()
618
+ if user_id is not None:
619
+ rec_id = RecordID(self._get_table("users"), user_id)
620
+ where.and_("user", rec_id)
621
+ if agent_id is not None:
622
+ rec_id = RecordID(self._get_table("agents"), agent_id)
623
+ where.and_("agent", rec_id)
624
+ if team_id is not None:
625
+ rec_id = RecordID(self._get_table("teams"), team_id)
626
+ where.and_("team", rec_id)
627
+ if topics is not None:
628
+ where.and_("topics", topics, "CONTAINSANY")
629
+ if search_content is not None:
630
+ where.and_("memory", search_content, "~")
631
+ where_clause, where_vars = where.build()
632
+
633
+ # Total count
634
+ total_count = self._count(table, where_clause, where_vars)
635
+
636
+ # Query
637
+ order_limit_start_clause = order_limit_start(sort_by, sort_order, limit, page)
638
+ query = dedent(f"""
639
+ SELECT *
640
+ FROM {table}
641
+ {where_clause}
642
+ {order_limit_start_clause}
643
+ """)
644
+ result = self._query(query, where_vars, dict)
645
+ if deserialize:
646
+ return deserialize_user_memories(result)
647
+ return [desurrealize_user_memory(x) for x in result], total_count
648
+
649
+ def get_user_memory_stats(
650
+ self,
651
+ limit: Optional[int] = None,
652
+ page: Optional[int] = None,
653
+ user_id: Optional[str] = None,
654
+ ) -> Tuple[List[Dict[str, Any]], int]:
655
+ """Get user memories stats.
656
+
657
+ Args:
658
+ limit (Optional[int]): The maximum number of user stats to return.
659
+ page (Optional[int]): The page number.
660
+ user_id (Optional[str]): The ID of the user to filter by. Defaults to None.
661
+
662
+ Returns:
663
+ Tuple[List[Dict[str, Any]], int]: A list of dictionaries containing user stats and total count.
664
+
665
+ Example:
666
+ (
667
+ [
668
+ {
669
+ "user_id": "123",
670
+ "total_memories": 10,
671
+ "last_memory_updated_at": 1714560000,
672
+ },
673
+ ],
674
+ total_count: 1,
675
+ )
676
+ """
677
+ memories_table_name = self._get_table("memories")
678
+ where = WhereClause()
679
+
680
+ if user_id is None:
681
+ where.and_("!!user", True, "=") # this checks that user is not falsy
682
+ else:
683
+ where.and_("user", RecordID(self._get_table("users"), user_id), "=")
684
+
685
+ where_clause, where_vars = where.build()
686
+ # Group
687
+ group_clause = "GROUP BY user"
688
+ # Order
689
+ order_limit_start_clause = order_limit_start("last_memory_updated_at", "DESC", limit, page)
690
+ # Total count
691
+ total_count = (
692
+ self._query_one(f"(SELECT user FROM {memories_table_name} GROUP BY user).map(|$x| $x.user).len()", {}, int)
693
+ or 0
694
+ )
695
+ # Query
696
+ query = dedent(f"""
697
+ SELECT
698
+ user,
699
+ count(id) AS total_memories,
700
+ time::max(updated_at) AS last_memory_updated_at
701
+ FROM {memories_table_name}
702
+ {where_clause}
703
+ {group_clause}
704
+ {order_limit_start_clause}
705
+ """)
706
+ result = self._query(query, where_vars, dict)
707
+
708
+ # deserialize dates and RecordIDs
709
+ for row in result:
710
+ row["user_id"] = row["user"].id
711
+ del row["user"]
712
+ row["last_memory_updated_at"] = row["last_memory_updated_at"].timestamp()
713
+ row["last_memory_updated_at"] = int(row["last_memory_updated_at"])
714
+
715
+ return list(result), total_count
716
+
717
+ def upsert_user_memory(
718
+ self, memory: UserMemory, deserialize: Optional[bool] = True
719
+ ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
720
+ """Upsert a user memory in the database.
721
+
722
+ Args:
723
+ memory (UserMemory): The user memory to upsert.
724
+ deserialize (Optional[bool]): Whether to serialize the memory. Defaults to True.
725
+
726
+ Returns:
727
+ Optional[Union[UserMemory, Dict[str, Any]]]:
728
+ - When deserialize=True: UserMemory object
729
+ - When deserialize=False: UserMemory dictionary
730
+
731
+ Raises:
732
+ Exception: If an error occurs during upsert.
733
+ """
734
+ table = self._get_table("memories")
735
+ user_table = self._get_table("users")
736
+ if memory.memory_id:
737
+ record = RecordID(table, memory.memory_id)
738
+ query = "UPSERT ONLY $record CONTENT $content"
739
+ result = self._query_one(
740
+ query, {"record": record, "content": serialize_user_memory(memory, table, user_table)}, dict
741
+ )
742
+ else:
743
+ query = f"CREATE ONLY {table} CONTENT $content"
744
+ result = self._query_one(query, {"content": serialize_user_memory(memory, table, user_table)}, dict)
745
+ if result is None:
746
+ return None
747
+ elif not deserialize:
748
+ return desurrealize_user_memory(result)
749
+ return deserialize_user_memory(result)
750
+
751
+ def upsert_memories(
752
+ self, memories: List[UserMemory], deserialize: Optional[bool] = True
753
+ ) -> List[Union[UserMemory, Dict[str, Any]]]:
754
+ """
755
+ Bulk insert or update multiple memories in the database for improved performance.
756
+
757
+ Args:
758
+ memories (List[UserMemory]): The list of memories to upsert.
759
+ deserialize (Optional[bool]): Whether to deserialize the memories. Defaults to True.
760
+
761
+ Returns:
762
+ List[Union[UserMemory, Dict[str, Any]]]: List of upserted memories
763
+
764
+ Raises:
765
+ Exception: If an error occurs during bulk upsert.
766
+ """
767
+ if not memories:
768
+ return []
769
+ table = self._get_table("memories")
770
+ user_table_name = self._get_table("users")
771
+ raw: list[dict] = []
772
+ for memory in memories:
773
+ if memory.memory_id:
774
+ # UPSERT does only work for one record at a time
775
+ session_raw = self._query_one(
776
+ "UPSERT ONLY $record CONTENT $content",
777
+ {
778
+ "record": RecordID(table, memory.memory_id),
779
+ "content": serialize_user_memory(memory, table, user_table_name),
780
+ },
781
+ dict,
782
+ )
783
+ else:
784
+ session_raw = self._query_one(
785
+ f"CREATE ONLY {table} CONTENT $content",
786
+ {"content": serialize_user_memory(memory, table, user_table_name)},
787
+ dict,
788
+ )
789
+ if session_raw is not None:
790
+ raw.append(session_raw)
791
+ if raw is None or not deserialize:
792
+ return [desurrealize_user_memory(x) for x in raw]
793
+ # wrapping with list because of:
794
+ # Type "List[Session]" is not assignable to return type "List[Session | Dict[str, Any]]"
795
+ # Consider switching from "list" to "Sequence" which is covariant
796
+ return list(deserialize_user_memories(raw))
797
+
798
+ # --- Metrics ---
799
+ def get_metrics(
800
+ self,
801
+ starting_date: Optional[date] = None,
802
+ ending_date: Optional[date] = None,
803
+ ) -> Tuple[List[Dict[str, Any]], Optional[int]]:
804
+ """Get all metrics matching the given date range.
805
+
806
+ Args:
807
+ starting_date (Optional[date]): The starting date to filter metrics by.
808
+ ending_date (Optional[date]): The ending date to filter metrics by.
809
+
810
+ Returns:
811
+ Tuple[List[dict], Optional[int]]: A tuple containing the metrics and the timestamp of the latest update.
812
+
813
+ Raises:
814
+ Exception: If an error occurs during retrieval.
815
+ """
816
+ table = self._get_table("metrics")
817
+
818
+ where = WhereClause()
819
+
820
+ # starting_date - need to convert date to datetime for comparison
821
+ if starting_date is not None:
822
+ starting_datetime = datetime.combine(starting_date, datetime.min.time()).replace(tzinfo=timezone.utc)
823
+ where = where.and_("date", starting_datetime, ">=")
824
+
825
+ # ending_date - need to convert date to datetime for comparison
826
+ if ending_date is not None:
827
+ ending_datetime = datetime.combine(ending_date, datetime.min.time()).replace(tzinfo=timezone.utc)
828
+ where = where.and_("date", ending_datetime, "<=")
829
+
830
+ where_clause, where_vars = where.build()
831
+
832
+ # Query
833
+ query = dedent(f"""
834
+ SELECT *
835
+ FROM {table}
836
+ {where_clause}
837
+ ORDER BY date ASC
838
+ """)
839
+
840
+ results = self._query(query, where_vars, dict)
841
+
842
+ # Get the latest updated_at from all results
843
+ latest_update = None
844
+ if results:
845
+ # Find the maximum updated_at timestamp
846
+ latest_update = max(int(r["updated_at"].timestamp()) for r in results)
847
+
848
+ # Transform results to match expected format
849
+ transformed_results = []
850
+ for r in results:
851
+ transformed = dict(r)
852
+
853
+ # Convert RecordID to string
854
+ if hasattr(transformed.get("id"), "id"):
855
+ transformed["id"] = transformed["id"].id
856
+ elif isinstance(transformed.get("id"), RecordID):
857
+ transformed["id"] = str(transformed["id"].id)
858
+
859
+ # Convert datetime objects to Unix timestamps
860
+ if isinstance(transformed.get("created_at"), datetime):
861
+ transformed["created_at"] = int(transformed["created_at"].timestamp())
862
+ if isinstance(transformed.get("updated_at"), datetime):
863
+ transformed["updated_at"] = int(transformed["updated_at"].timestamp())
864
+ if isinstance(transformed.get("date"), datetime):
865
+ transformed["date"] = int(transformed["date"].timestamp())
866
+
867
+ transformed_results.append(transformed)
868
+
869
+ return transformed_results, latest_update
870
+
871
+ return [], latest_update
872
+
873
def calculate_metrics(self) -> Optional[List[Dict[str, Any]]]:
    """Calculate and store metrics for every date that does not have complete metrics yet.

    The first date to (re)calculate comes from ``get_metrics_calculation_starting_date``
    (given the metrics table and ``self.get_sessions``). Sessions in the window from
    midnight UTC of the first date to midnight UTC of the day after the last date are
    fetched once, grouped per date by ``fetch_all_sessions_data``, aggregated with
    ``calculate_date_metrics`` and written with ``bulk_upsert_metrics``.

    Returns:
        Optional[List[Dict[str, Any]]]: The result of the bulk upsert, an empty list
        when no date had any sessions, or None when there is no session data or
        nothing left to calculate.

    Raises:
        Exception: Any error raised while calculating or storing metrics is logged
        and re-raised unchanged.
    """
    try:
        table = self._get_table("metrics")

        starting_date = get_metrics_calculation_starting_date(self.client, table, self.get_sessions)
        if starting_date is None:
            log_info("No session data found. Won't calculate metrics.")
            return None

        dates_to_process = get_dates_to_calculate_metrics_for(starting_date)
        if not dates_to_process:
            log_info("Metrics already calculated for all relevant dates.")
            return None

        # Window bounds as timezone-aware UTC datetimes: start of the first date and
        # start of the day after the last date, so the last date is fully covered.
        start_timestamp = datetime.combine(dates_to_process[0], datetime.min.time()).replace(tzinfo=timezone.utc)
        end_timestamp = datetime.combine(dates_to_process[-1] + timedelta(days=1), datetime.min.time()).replace(
            tzinfo=timezone.utc
        )

        sessions = get_all_sessions_for_metrics_calculation(
            self.client, self._get_table("sessions"), start_timestamp, end_timestamp
        )

        all_sessions_data = fetch_all_sessions_data(
            sessions=sessions,
            dates_to_process=dates_to_process,
            # fetch_all_sessions_data is given the window start as integer Unix seconds
            start_timestamp=int(start_timestamp.timestamp()),
        )
        if not all_sessions_data:
            log_info("No new session data found. Won't calculate metrics.")
            return None

        metrics_records = []
        for date_to_process in dates_to_process:
            # Grouped session data is looked up by the ISO-formatted date
            sessions_for_date = all_sessions_data.get(date_to_process.isoformat(), {})

            # Skip dates whose session buckets are all empty
            if not any(len(bucket) > 0 for bucket in sessions_for_date.values()):
                continue

            metrics_records.append(calculate_date_metrics(date_to_process, sessions_for_date))

        results = bulk_upsert_metrics(self.client, table, metrics_records) if metrics_records else []

        log_debug("Updated metrics calculations")
        return results

    except Exception as e:
        log_error(f"Exception refreshing metrics: {e}")
        # Bare raise re-raises the active exception with its original traceback
        raise
937
+
938
+ # --- Knowledge ---
939
def clear_knowledge(self) -> None:
    """Remove every row from the knowledge table.

    The table name is resolved through ``self._get_table`` and handed to the
    client's ``delete``; the client's return value is discarded.

    Raises:
        Exception: Propagated from the client if the deletion fails.
    """
    knowledge_table = self._get_table("knowledge")
    self.client.delete(knowledge_table)
947
+
948
def delete_knowledge_content(self, id: str):
    """Remove a single knowledge row.

    Args:
        id (str): ID of the knowledge row; combined with the knowledge table name
            into a SurrealDB ``RecordID`` that is passed to the client's ``delete``.
    """
    self.client.delete(RecordID(self._get_table("knowledge"), id))
956
+
957
def get_knowledge_content(self, id: str) -> Optional[KnowledgeRow]:
    """Fetch one knowledge row by its ID.

    Args:
        id (str): ID of the knowledge row to fetch.

    Returns:
        Optional[KnowledgeRow]: The deserialized row, or None when the query
        result is empty/falsy.
    """
    params = {"record_id": RecordID(self._get_table("knowledge"), id)}
    row = self._query_one("SELECT * FROM ONLY $record_id", params, dict)
    if not row:
        return None
    return deserialize_knowledge_row(row)
970
+
971
def get_knowledge_contents(
    self,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
) -> Tuple[List[KnowledgeRow], int]:
    """Get a page of knowledge rows together with the total row count.

    No conditions are added to the WhereClause, so the count and the select
    share the same unfiltered clause. Sorting and paging arguments are turned
    into a clause by ``order_limit_start``.

    Args:
        limit (Optional[int]): Maximum number of rows to return.
        page (Optional[int]): Page number to return.
        sort_by (Optional[str]): Column to sort by.
        sort_order (Optional[str]): Sort direction.

    Returns:
        Tuple[List[KnowledgeRow], int]: The deserialized rows and the total count.

    Raises:
        Exception: If an error occurs during retrieval.
    """
    table = self._get_table("knowledge")
    where_clause, where_vars = WhereClause().build()

    total_count = self._count(table, where_clause, where_vars)

    paging_clause = order_limit_start(sort_by, sort_order, limit, page)
    query = dedent(f"""
        SELECT *
        FROM {table}
        {where_clause}
        {paging_clause}
    """)
    rows = self._query(query, where_vars, dict)
    return [deserialize_knowledge_row(row) for row in rows], total_count
1009
+
1010
def upsert_knowledge_content(self, knowledge_row: KnowledgeRow) -> Optional[KnowledgeRow]:
    """Create or overwrite a knowledge row keyed by ``knowledge_row.id``.

    Runs ``UPSERT ONLY ... CONTENT`` with the serialized row as the content.

    Args:
        knowledge_row (KnowledgeRow): The row to write.

    Returns:
        Optional[KnowledgeRow]: The stored row deserialized from the query result,
        or None when the result is empty/falsy.
    """
    table_name = self._get_table("knowledge")
    params = {
        "record": RecordID(table_name, knowledge_row.id),
        "content": serialize_knowledge_row(knowledge_row, table_name),
    }
    upserted = self._query_one("UPSERT ONLY $record CONTENT $content", params, dict)
    if not upserted:
        return None
    return deserialize_knowledge_row(upserted)
1026
+
1027
+ # --- Evals ---
1028
def clear_evals(self) -> None:
    """Remove every row from the evals table.

    The table name is resolved through ``self._get_table`` and handed to the
    client's ``delete``; the client's return value is discarded.

    Raises:
        Exception: Propagated from the client if the deletion fails.
    """
    evals_table = self._get_table("evals")
    self.client.delete(evals_table)
1036
+
1037
def create_eval_run(self, eval_run: EvalRunRecord) -> Optional[EvalRunRecord]:
    """Insert a new eval run whose record ID is ``eval_run.run_id``.

    Args:
        eval_run (EvalRunRecord): The eval run to store.

    Returns:
        Optional[EvalRunRecord]: The created record deserialized from the query
        result, or None when the result is empty/falsy.

    Raises:
        Exception: If an error occurs during creation.
    """
    params = {
        "record": RecordID(self._get_table("evals"), eval_run.run_id),
        "content": serialize_eval_run_record(eval_run, self.table_names),
    }
    created = self._query_one("CREATE ONLY $record CONTENT $content", params, dict)
    if not created:
        return None
    return deserialize_eval_run_record(created)
1056
+
1057
def delete_eval_runs(self, eval_run_ids: List[str]) -> None:
    """Delete the eval runs with the given IDs in a single query.

    Args:
        eval_run_ids (List[str]): IDs of the eval runs to delete. Each is turned
            into a ``RecordID`` on the evals table and matched with ``id IN``.
    """
    table = self._get_table("evals")
    query_vars = {"records": [RecordID(table, run_id) for run_id in eval_run_ids]}
    self.client.query(f"DELETE FROM {table} WHERE id IN $records", query_vars)
1066
+
1067
def get_eval_run(
    self, eval_run_id: str, deserialize: Optional[bool] = True
) -> Optional[Union[EvalRunRecord, Dict[str, Any]]]:
    """Look up a single eval run by its ID.

    Args:
        eval_run_id (str): ID of the eval run to fetch.
        deserialize (Optional[bool]): When truthy, return an EvalRunRecord;
            otherwise return the plain dict produced by
            ``desurrealize_eval_run_record``. Defaults to True.

    Returns:
        Optional[Union[EvalRunRecord, Dict[str, Any]]]:
            - deserialize=True: EvalRunRecord object
            - deserialize=False: eval run dictionary
            - None when the query returns None

    Raises:
        Exception: If an error occurs during retrieval.
    """
    lookup = {"record": RecordID(self._get_table("evals"), eval_run_id)}
    raw = self._query_one("SELECT * FROM ONLY $record", lookup, dict)
    if raw is None:
        return None
    if raw and deserialize:
        return deserialize_eval_run_record(raw)
    # Non-deserialized (or empty) results are returned as plain dicts
    return desurrealize_eval_run_record(raw)
1090
+
1091
def get_eval_runs(
    self,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    model_id: Optional[str] = None,
    filter_type: Optional[EvalFilterType] = None,
    eval_type: Optional[List[EvalType]] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[EvalRunRecord], Tuple[List[Dict[str, Any]], int]]:
    """Get eval runs from the database, optionally filtered, sorted and paginated.

    Args:
        limit (Optional[int]): The maximum number of eval runs to return.
        page (Optional[int]): The page number to return.
        sort_by (Optional[str]): The field to sort by.
        sort_order (Optional[str]): The order to sort by.
        agent_id (Optional[str]): Only return runs linked to this agent.
        team_id (Optional[str]): Only return runs linked to this team.
        workflow_id (Optional[str]): Only return runs linked to this workflow.
        model_id (Optional[str]): Only return runs for this model.
        filter_type (Optional[EvalFilterType]): Forces the agent/team/workflow
            filter for the selected entity to be applied using the corresponding
            ID argument, even when that ID is None (kept for backward compatibility).
        eval_type (Optional[List[EvalType]]): Only return runs whose eval type is
            one of these values. None or an empty list applies no eval-type filter.
        deserialize (Optional[bool]): Whether to deserialize the eval runs. Defaults to True.

    Returns:
        Union[List[EvalRunRecord], Tuple[List[Dict[str, Any]], int]]:
            - deserialize=True: list of EvalRunRecord objects
            - deserialize=False: list of eval run dictionaries (same shape as
              ``get_eval_run(..., deserialize=False)``) and the total number of
              runs matching the filters

    Raises:
        Exception: If there is an error getting the eval runs.
    """
    table = self._get_table("evals")

    where = WhereClause()
    # Entity filters apply whenever an ID is supplied; filter_type additionally
    # forces the matching filter, as it did before.
    if agent_id is not None or filter_type == EvalFilterType.AGENT:
        where.and_("agent", RecordID(self._get_table("agents"), agent_id))
    if team_id is not None or filter_type == EvalFilterType.TEAM:
        where.and_("team", RecordID(self._get_table("teams"), team_id))
    if workflow_id is not None or filter_type == EvalFilterType.WORKFLOW:
        where.and_("workflow", RecordID(self._get_table("workflows"), workflow_id))
    if model_id is not None:
        where.and_("model_id", model_id)
    if eval_type:
        # eval_type is a list of accepted values: match by membership, not by
        # equality with the whole list
        where.and_("eval_type", eval_type, "IN")
    where_clause, where_vars = where.build()

    order_limit_start_clause = order_limit_start(sort_by, sort_order, limit, page)
    query = dedent(f"""
        SELECT *
        FROM {table}
        {where_clause}
        {order_limit_start_clause}
    """)
    rows = self._query(query, where_vars, dict)

    if deserialize:
        return [deserialize_eval_run_record(row) for row in rows]

    # The total count is only part of the non-deserialized return value, so the
    # extra COUNT query runs only in that case.
    total_count = self._count(table, where_clause, where_vars)
    return [desurrealize_eval_run_record(row) for row in rows], total_count
1162
+
1163
def rename_eval_run(
    self, eval_run_id: str, name: str, deserialize: Optional[bool] = True
) -> Optional[Union[EvalRunRecord, Dict[str, Any]]]:
    """Update the name of an eval run in the database.

    Args:
        eval_run_id (str): The ID of the eval run to update.
        name (str): The new name of the eval run.
        deserialize (Optional[bool]): Whether to deserialize the eval run. Defaults to True.

    Returns:
        Optional[Union[EvalRunRecord, Dict[str, Any]]]:
            - deserialize=True: EvalRunRecord object
            - deserialize=False: eval run dictionary, in the same shape
              ``get_eval_run(..., deserialize=False)`` returns
            - None when the UPDATE yields no record

    Raises:
        Exception: If there is an error updating the eval run.
    """
    params = {"record": RecordID(self._get_table("evals"), eval_run_id), "name": name}

    query = dedent("""
        UPDATE ONLY $record
        SET name = $name
    """)
    # SurrealQL UPDATE returns the record after the change by default
    raw = self._query_one(query, params, dict)

    if not raw:
        return None
    if not deserialize:
        # Convert SurrealDB-specific values the same way get_eval_run does,
        # instead of leaking the raw query row to the caller.
        return desurrealize_eval_run_record(raw)
    return deserialize_eval_run_record(raw)